diff --git a/dependencies/requirements/base_requirements/requirements.txt b/dependencies/requirements/base_requirements/requirements.txt new file mode 100644 index 00000000..796254aa --- /dev/null +++ b/dependencies/requirements/base_requirements/requirements.txt @@ -0,0 +1,42 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +absl-py +aqtp +datasets +einops +flax +ftfy +google-cloud-storage +grain +hf_transfer +huggingface_hub +imageio-ffmpeg +imageio +jax +jaxlib +Jinja2 +opencv-python-headless +optax +orbax-checkpoint +parameterized +Pillow +pyink +pylint +pytest +ruff +scikit-image +sentencepiece +tensorboard-plugin-profile +tensorboard +tensorboardx +tensorflow-datasets +tensorflow +tokamax +tokenizers +transformers + +# pinning torch and torchvision to specific versions to avoid +# installing GPU versions from PyPI when running seed-env +torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl +torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl +qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip + diff --git a/dependencies/requirements/generated_requirements/requirements.txt b/dependencies/requirements/generated_requirements/requirements.txt new file mode 100644 index 00000000..f9da2e47 --- /dev/null +++ b/dependencies/requirements/generated_requirements/requirements.txt @@ -0,0 +1,191 @@ +# Generated by seed-env. Do not edit manually. +# If you need to modify dependencies, please do so in the host requirements file and run seed-env again. + +absl-py>=2.3.1 +aiofiles>=25.1.0 +aiohappyeyeballs>=2.6.1 +aiohttp>=3.13.3 +aiosignal>=1.4.0 +aqtp>=0.9.0 +array-record>=0.8.3 ; sys_platform != 'win32' +astroid>=4.0.4 +astunparse>=1.6.3 +attrs>=25.4.0 +auditwheel>=6.6.0 +black>=25.12.0 +build>=1.4.0 +certifi>=2026.1.4 +cffi>=2.0.0 ; platform_python_implementation != 'PyPy' +charset-normalizer>=3.4.4 +cheroot>=11.1.2 +chex>=0.1.91 +click>=8.3.1 +cloudpickle>=3.1.2 +colorama>=0.4.6 +contourpy>=1.3.3 +cryptography>=46.0.5 +cycler>=0.12.1 +dataclasses-json>=0.6.7 +datasets>=2.14.4 +decorator>=5.2.1 +dill>=0.3.7 +dm-tree>=0.1.9 +docstring-parser>=0.17.0 +einops>=0.8.2 +etils>=1.13.0 +execnet>=2.1.2 +filelock>=3.20.3 +flatbuffers>=25.12.19 +flax>=0.12.4 +fonttools>=4.61.1 +frozenlist>=1.8.0 +fsspec>=2026.1.0 +ftfy>=6.3.1 +gast>=0.7.0 +gcsfs>=2026.1.0 +google-api-core>=2.29.0 +google-auth-oauthlib>=1.2.4 +google-auth>=2.48.0 +google-cloud-core>=2.5.0 +google-cloud-storage-control>=1.10.0 +google-cloud-storage>=3.9.0 +google-crc32c>=1.8.0 +google-pasta>=0.2.0 +google-resumable-media>=2.8.0 +googleapis-common-protos>=1.72.0 +grain>=0.2.15 +grpc-google-iam-v1>=0.14.3 +grpcio-status>=1.76.0 +grpcio>=1.76.0 +gviz-api>=1.10.0 +h5py>=3.15.1 +hf-transfer>=0.1.9 +hf-xet>=1.2.1 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' +huggingface-hub>=0.36.2 +humanize>=4.15.0 +hypothesis>=6.142.1 +idna>=3.11 +imageio-ffmpeg>=0.6.0 +imageio>=2.37.2 +immutabledict>=4.3.0 +importlib-resources>=6.5.2 +iniconfig>=2.3.0 +isort>=7.0.0 +jaraco-functools>=4.4.0 +jax>=0.9.0 +jaxlib>=0.9.0 +jaxtyping>=0.3.7 +jinja2>=3.1.6 +keras>=3.13.1 +kiwisolver>=1.4.9 +lazy-loader>=0.4 +libclang>=18.1.1 +libtpu>=0.0.34 ; platform_machine == 'x86_64' and sys_platform == 'linux' +markdown-it-py>=4.0.0 +markdown>=3.10.1 +markupsafe>=3.0.3 +marshmallow>=3.26.2 +matplotlib>=3.10.8 +mccabe>=0.7.0 +mdurl>=0.1.2 +ml-dtypes>=0.5.4 +more-itertools>=10.8.0 +mpmath>=1.3.0 +msgpack>=1.1.2 +multidict>=6.7.1 +multiprocess>=0.70.15 +mypy-extensions>=1.1.0 +namex>=0.1.0 +nest-asyncio>=1.6.0 +networkx>=3.6.1 +numpy-typing-compat>=20251206.2.0 +numpy>=2.0.2 +nvidia-cuda-cccl>=13.1.115 +oauthlib>=3.3.1 +opencv-python-headless>=4.13.0.92 +opt-einsum>=3.4.0 +optax>=0.2.6 +optree>=0.18.0 +optype>=0.15.0 +orbax-checkpoint>=0.11.32 +orbax-export>=0.0.8 +packaging>=26.0 +pandas>=3.0.0 +parameterized>=0.9.0 +pathspec>=1.0.4 +pillow>=12.1.0 +platformdirs>=4.7.1 +pluggy>=1.6.0 +portpicker>=1.6.0 +promise>=2.3 +propcache>=0.4.1 +proto-plus>=1.27.1 +protobuf>=6.33.5 +psutil>=7.2.1 +pyarrow>=23.0.0 +pyasn1-modules>=0.4.2 +pyasn1>=0.6.2 +pycparser>=3.0 ; implementation_name != 'PyPy' and platform_python_implementation != 'PyPy' +pyelftools>=0.32 +pygments>=2.19.2 +pyink>=25.12.0 +pylint>=4.0.4 +pyparsing>=3.3.2 +pyproject-hooks>=1.2.0 +pytest-xdist>=3.8.0 +pytest>=8.4.2 +python-dateutil>=2.9.0.post0 +pytokens>=0.4.1 +pyyaml>=6.0.3 +qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip +regex>=2026.1.15 +requests-oauthlib>=2.0.0 +requests>=2.32.5 +rich>=14.2.0 +rsa>=4.9.1 +ruff>=0.15.1 +safetensors>=0.7.0 +scikit-image>=0.26.0 +scipy-stubs>=1.17.0.1 +scipy>=1.17.0 +sentencepiece>=0.2.1 +setuptools>=80.10.1 +simple-parsing>=0.1.8 +simplejson>=3.20.2 +six>=1.17.0 +sortedcontainers>=2.4.0 +sympy>=1.14.0 +tensorboard-data-server>=0.7.2 +tensorboard-plugin-profile>=2.21.6 +tensorboard>=2.20.0 +tensorboardx>=2.6.4 +tensorflow-datasets>=4.9.9 +tensorflow-metadata>=1.17.3 +tensorflow>=2.20.0 +tensorstore>=0.1.80 +termcolor>=3.3.0 +tifffile>=2026.1.28 +tokamax>=0.1.0 +tokenizers>=0.22.2 +toml>=0.10.2 +tomlkit>=0.14.0 +toolz>=1.1.0 +torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl +torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl +tqdm>=4.67.3 +transformers>=4.57.6 +treescope>=0.1.10 +typing-extensions>=4.15.0 +typing-inspect>=0.9.0 +tzdata>=2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32' +urllib3>=2.6.3 +wadler-lindig>=0.1.7 +wcwidth>=0.6.0 +werkzeug>=3.1.5 +wheel>=0.46.2 +wrapt>=2.1.1 +xprof>=2.21.6 +xxhash>=3.6.0 +yarl>=1.22.0 +zipp>=3.23.0 +zstandard>=0.25.0 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b631a3f3..f3fa6e6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,55 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[build-system] +requires = ["hatchling", "hatch-requirements-txt"] +build-backend = "hatchling.build" + +[tool.hatch.version] +path = "src/maxdiffusion/__init__.py" + +[project] +name = "maxdiffusion" +dynamic = ["version", "optional-dependencies"] +requires-python = ">=3.12" +readme = "README.md" +license = "Apache-2.0" +classifiers = [ + "Programming Language :: Python", +] +dependencies = [] + +[tool.hatch.metadata.hooks.requirements_txt.optional-dependencies] +tpu = ["dependencies/requirements/generated_requirements/tpu-requirements.txt"] +cuda12 = ["dependencies/requirements/generated_requirements/cuda12-requirements.txt"] + +[project.urls] +Repository = "https://github.com/AI-Hypercomputer/maxdiffusion.git" +"Bug Tracker" = "https://github.com/AI-Hypercomputer/maxdiffusion/issues" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/maxdiffusion", "src/install_maxdiffusion_extra_deps"] + +[tool.hatch.build.targets.wheel.hooks.custom] +path = "build_hooks.py" + +[project.scripts] +install_maxdiffusion_github_deps = "install_maxdiffusion_extra_deps.install_github_deps:main" + [tool.ruff] # Never enforce `E501` (line length violations). ignore = ["C901", "E501", "E741", "F402", "F823", "E402", "I001"] diff --git a/setup.sh b/setup.sh index ab81c4a7..63993b26 100644 --- a/setup.sh +++ b/setup.sh @@ -35,13 +35,14 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th if [[ $REPLY =~ ^[Yy]$ ]]; then # Check if uv is installed first; if not, install uv if ! command -v uv &> /dev/null; then - echo -e "\n'uv' command not found. Installing it now via the official installer..." - curl -LsSf https://astral.sh/uv/install.sh | sh + # echo -e "\n'uv' command not found. Installing it now via the official installer..." + # curl -LsSf https://astral.sh/uv/install.sh | sh - echo -e "\n\e[33m'uv' has been installed.\e[0m" - echo "The installer likely printed instructions to update your shell's PATH." - echo "Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script." - exit 1 + # echo -e "\n\e[33m'uv' has been installed.\e[0m" + # echo "The installer likely printed instructions to update your shell's PATH." + # echo "Please open a NEW terminal session (or 'source ~/.bashrc') and re-run this script." + # exit 1 + pip install uv fi maxdiffusion_dir=$(pwd) cd @@ -53,7 +54,7 @@ if ! python3 -c 'import sys; assert sys.version_info >= (3, 12)' 2>/dev/null; th echo "No name provided. Using default name: '$venv_name'" fi echo "Creating virtual environment '$venv_name' with Python 3.12..." - uv venv --python 3.12 "$venv_name" --seed + python3 -m uv venv --python 3.12 "$venv_name" --seed printf '%s\n' "$(realpath -- "$venv_name")" >> /tmp/venv_created echo -e "\n\e[32mVirtual environment '$venv_name' created successfully!\e[0m" echo "To activate it, run the following command:" @@ -81,6 +82,8 @@ apt update -y && apt -y install gcsfuse rm -rf /var/lib/apt/lists/* EOF +python3 -m pip install -U setuptools wheel uv + # Set environment variables from command line arguments for ARGUMENT in "$@"; do IFS='=' read -r KEY VALUE <<< "$ARGUMENT" @@ -104,7 +107,7 @@ if [[ -n $JAX_VERSION && ! ($MODE == "stable" || -z $MODE) ]]; then fi # Install dependencies from requirements.txt first -pip3 install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2 +python3 -m uv pip install -U -r requirements.txt || echo "Failed to install dependencies in the requirements" >&2 # Install JAX and JAXlib based on the specified mode if [[ "$MODE" == "stable" || ! -v MODE ]]; then @@ -113,23 +116,23 @@ if [[ "$MODE" == "stable" || ! -v MODE ]]; then echo "Installing stable jax, jaxlib for tpu" if [[ -n "$JAX_VERSION" ]]; then echo "Installing stable jax, jaxlib, libtpu version ${JAX_VERSION}" - pip3 install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + python3 -m uv pip install "jax[tpu]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html else echo "Installing stable jax, jaxlib, libtpu for tpu" - pip3 install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + python3 -m uv pip install 'jax[tpu]>0.4' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html fi elif [[ $DEVICE == "gpu" ]]; then echo "Installing stable jax, jaxlib for NVIDIA gpu" if [[ -n "$JAX_VERSION" ]]; then echo "Installing stable jax, jaxlib ${JAX_VERSION}" - pip3 install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + python3 -m uv pip install -U "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html else echo "Installing stable jax, jaxlib, libtpu for NVIDIA gpu" - pip3 install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + python3 -m uv pip install "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html fi export NVTE_FRAMEWORK=jax - pip3 install transformer_engine[jax]==2.1.0 + python3 -m uv pip install transformer_engine[jax]==2.1.0 fi elif [[ $MODE == "nightly" ]]; then @@ -140,22 +143,22 @@ elif [[ $MODE == "nightly" ]]; then pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # Install Transformer Engine export NVTE_FRAMEWORK=jax - pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable + python3 -m uv pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable elif [[ $DEVICE == "tpu" ]]; then echo "Installing jax-nightly,jaxlib-nightly" # Install jax-nightly - pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html + python3 -m uv pip install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # Install jaxlib-nightly - pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html + python3 -m uv pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html # Install libtpu-nightly - pip3 install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + python3 -m uv pip install --pre -U libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html fi echo "Installing nightly tensorboard plugin profile" - pip3 install tbp-nightly --upgrade + python3 -m uv pip install tbp-nightly --upgrade else echo -e "\n\nError: You can only set MODE to [stable,nightly].\n\n" exit 1 fi # Install maxdiffusion -pip3 install -U . || echo "Failed to install maxdiffusion" >&2 \ No newline at end of file +python3 -m uv pip install -U . || echo "Failed to install maxdiffusion" >&2 \ No newline at end of file diff --git a/src/install_maxdiffusion_extra_deps/__init__.py b/src/install_maxdiffusion_extra_deps/__init__.py new file mode 100644 index 00000000..b6c442c7 --- /dev/null +++ b/src/install_maxdiffusion_extra_deps/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/install_maxdiffusion_extra_deps/extra_deps_from_github.txt b/src/install_maxdiffusion_extra_deps/extra_deps_from_github.txt new file mode 100644 index 00000000..130c41ae --- /dev/null +++ b/src/install_maxdiffusion_extra_deps/extra_deps_from_github.txt @@ -0,0 +1,3 @@ +torch @ https://download.pytorch.org/whl/cpu/torch-2.10.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl +torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.25.0%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl +qwix @ https://github.com/google/qwix/archive/408a0f48f988b6c5b180e07f0cb1d05997bf0dcc.zip diff --git a/src/install_maxdiffusion_extra_deps/install_github_deps.py b/src/install_maxdiffusion_extra_deps/install_github_deps.py new file mode 100644 index 00000000..7b1c288c --- /dev/null +++ b/src/install_maxdiffusion_extra_deps/install_github_deps.py @@ -0,0 +1,91 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Installs extra dependencies from a requirements file using uv. + +This script is designed to be run to install dependencies specified in +'extra_deps_from_github.txt', which is expected to be in the same directory. +It first ensures 'uv' is installed and then uses it to install the packages +listed in the requirements file. +""" + +import subprocess +import sys +from pathlib import Path + + +def main(): + """ + Installs extra dependencies specified in extra_deps_from_github.txt using uv. + + This script looks for 'extra_deps_from_github.txt' relative to its own location. + It executes 'uv pip install -r --resolution=lowest'. + """ + script_dir = Path(__file__).resolve().parent + + # Adjust this path if your extra_deps_from_github.txt is in a different location, + # e.g., script_dir / "data" / "extra_deps_from_github.txt" + extra_deps_file = script_dir / "extra_deps_from_github.txt" + + if not extra_deps_file.exists(): + print(f"Error: '{extra_deps_file}' not found.") + print("Please ensure 'extra_deps_from_github.txt' is in the correct location relative to the script.") + sys.exit(1) + # Check if 'uv' is available in the environment + try: + subprocess.run([sys.executable, "-m", "pip", "install", "uv"], check=True, capture_output=True) + subprocess.run([sys.executable, "-m", "uv", "--version"], check=True, capture_output=True) + except subprocess.CalledProcessError as e: + print(f"Error checking uv version: {e}") + print(f"Stderr: {e.stderr.decode()}") + sys.exit(1) + + command = [ + sys.executable, # Use the current Python executable's pip to ensure the correct environment + "-m", + "uv", + "pip", + "install", + "-r", + str(extra_deps_file), + "--no-deps", + ] + + print(f"Installing extra dependencies from '{extra_deps_file}' using uv...") + print(f"Running command: {' '.join(command)}") + + try: + # Run the command + process = subprocess.run(command, check=True, capture_output=True, text=True) + print("Extra dependencies installed successfully!") + print("--- Output from uv ---") + print(process.stdout) + if process.stderr: + print("--- Errors/Warnings from uv (if any) ---") + print(process.stderr) + except subprocess.CalledProcessError as e: + print("Failed to install extra dependencies.") + print(f"Command '{' '.join(e.cmd)}' returned non-zero exit status {e.returncode}.") + print("--- Stderr ---") + print(e.stderr) + print("--- Stdout ---") + print(e.stdout) + sys.exit(e.returncode) + except (OSError, FileNotFoundError) as e: + print(f"An OS-level error occurred while trying to run uv: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main()