Skip to content

Commit

Permalink
Refactor build-jax.sh and pin orbax-checkpoint (#1211)
Browse files Browse the repository at this point in the history
  • Loading branch information
DwarKapex authored Dec 21, 2024
1 parent e301808 commit 92cc80b
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 9 deletions.
11 changes: 10 additions & 1 deletion .github/container/Dockerfile.jax
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,21 @@ ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
RUN mkdir -p /opt/pip-tools.d

## Editable installations of jax and jaxlib
## For 25.01 release we also pin several packages obtained
## from https://github.com/jax-ml/jax-ai-stack
RUN <<"EOF" bash -ex
# Append the jax/jaxlib requirements to the pip-tools input file:
#   1. one editable entry per item found directly under BUILD_PATH_JAXLIB,
#   2. an editable entry for the jax source tree (SRC_PATH_JAX),
#   3. the numpy upper bound and the pinned package versions listed below.
requirements_file=/opt/pip-tools.d/requirements-jax.in
{
  # Use a quoted glob rather than parsing `ls` output, so entries are not
  # subject to word-splitting or pathname expansion.
  for component_path in "${BUILD_PATH_JAXLIB}"/*; do
    # An unmatched glob stays literal; skip it so that an empty or missing
    # directory contributes no entries.
    [ -e "${component_path}" ] || continue
    printf '%s\n' "-e file://${component_path}"
  done
  printf '%s\n' "-e file://${SRC_PATH_JAX}"
  printf '%s\n' \
    "numpy<2.0.0" \
    "ml_dtypes==0.4.0" \
    "optax==0.2.4" \
    "orbax-checkpoint==0.10.2" \
    "orbax-export==0.0.6"
} >> "${requirements_file}"
EOF

## Flax
Expand Down
16 changes: 16 additions & 0 deletions .github/container/Dockerfile.maxtext
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ ARG URLREF_MAXTEXT=https://github.com/google/maxtext.git#main
ARG URLREF_TFTEXT=https://github.com/tensorflow/text.git#master
ARG SRC_PATH_MAXTEXT=/opt/maxtext
ARG SRC_PATH_TFTEXT=/opt/tensorflow-text
ARG URLREF_JETSTREAM=https://github.com/google/jetstream.git#main
ARG SRC_PATH_JETSTREAM=/opt/jetstream

###############################################################################
## build tensorflow-text and lingvo, which do not have working arm64 pip wheels
Expand Down Expand Up @@ -56,6 +58,7 @@ RUN echo "tensorflow-text @ file://$(ls /opt/tensorflow_text*.whl)" >> /opt/pip-

RUN <<"EOF" bash -ex
git-clone.sh ${URLREF_MAXTEXT} ${SRC_PATH_MAXTEXT}
sed -i '/google-jetstream/d' ${SRC_PATH_MAXTEXT}/requirements.txt
echo "-r ${SRC_PATH_MAXTEXT}/requirements.txt" >> /opt/pip-tools.d/requirements-maxtext.in

# specify some restrictions to speed up the build and
Expand All @@ -64,6 +67,7 @@ for pattern in \
"s|absl-py|absl-py>=2.1.0|g" \
"s|protobuf==3.20.3|protobuf>=3.19.0|g" \
"s|tensorflow-datasets|tensorflow-datasets>=4.8.0|g" \
"s|grain-nightly|grain|g" \
; do
sed -i "${pattern}" ${SRC_PATH_MAXTEXT}/requirements.txt;
done
Expand All @@ -76,6 +80,18 @@ EOF

ADD test-maxtext.sh /usr/local/bin

###############################################################################
## Add JetStream
###############################################################################

ARG URLREF_JETSTREAM
ARG SRC_PATH_JETSTREAM

RUN <<"EOF" bash -ex
# Clone JetStream at the requested URL/ref into its source directory.
# Expansions are quoted so the values are passed as single arguments
# without word-splitting or pathname expansion.
git-clone.sh "${URLREF_JETSTREAM}" "${SRC_PATH_JETSTREAM}"
# Register the checkout as an editable requirement in its own pip-tools input file.
printf '%s\n' "-e file://${SRC_PATH_JETSTREAM}" >> /opt/pip-tools.d/requirements-jetstream.in
EOF

###############################################################################
## Install accumulated packages from the base image and the previous stage
###############################################################################
Expand Down
1 change: 1 addition & 0 deletions .github/container/Dockerfile.t5x
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ echo "seqio-nightly>=0.0.18.dev20240714" >> /opt/pip-tools.d/requirements-t5x.in
# 2. Remove head-of-tree specs from select dependencies
pushd ${SRC_PATH_T5X}
sed -i "s| @ git+https://github.com/google/flax#egg=flax||g" setup.py
sed -i "s| @ git+https://github.com/deepmind/optax#egg=optax||g" setup.py

# for ARM64 build
if [[ "$(dpkg --print-architecture)" == "arm64" ]]; then
Expand Down
14 changes: 9 additions & 5 deletions .github/container/build-jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,17 @@ else
fi

# install jax and jaxlib
pip --disable-pip-version-check install -e ${BUILD_PATH_JAXLIB}/jaxlib -e ${BUILD_PATH_JAXLIB}/jax-cuda-pjrt -e ${BUILD_PATH_JAXLIB}/jax-cuda-plugin -e "${SRC_PATH_JAX}"
# Install the jaxlib-side packages first so the installed jaxlib version can be queried.
pip --disable-pip-version-check install -e "${BUILD_PATH_JAXLIB}/jaxlib" -e "${BUILD_PATH_JAXLIB}/jax-cuda-pjrt" -e "${BUILD_PATH_JAXLIB}/jax-cuda-plugin"

# Take the value of the "Version:" field only. Anchoring on the field name
# avoids picking up other lines that merely contain the word "Version", and
# stripping the whitespace after the colon keeps a leading space out of the
# value written into setup.py.
jaxlib_version=$(pip show jaxlib | sed -n 's/^Version:[[:space:]]*//p')
if [[ -z "${jaxlib_version}" ]]; then
    echo "ERROR: could not determine the installed jaxlib version" >&2
    exit 1
fi

# Patch the jax setup.py in the same source tree that is installed below
# (SRC_PATH_JAX) rather than a hard-coded path, then install jax itself.
sed -i "s|^_current_jaxlib_version.*|_current_jaxlib_version = '${jaxlib_version}'|" "${SRC_PATH_JAX}/setup.py"
sed -i "s| f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',| f'jaxlib>=0.4.30',|" "${SRC_PATH_JAX}/setup.py"
pip --disable-pip-version-check install -e "${SRC_PATH_JAX}"

## after installation (example)
# jax 0.4.36.dev20241125+f828f2d7d /opt/jax
# jax-cuda12-pjrt 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241125 /opt/jaxlibs/jax-cuda-plugin
# jaxlib 0.4.36.dev20241125 /opt/jaxlibs/jaxlib
# jax 0.4.36.dev20241220+f828f2d7d /opt/jax
# jax-cuda12-pjrt 0.4.36.dev20241220 /opt/jaxlibs/jax-cuda-pjrt
# jax-cuda12-plugin 0.4.36.dev20241220 /opt/jaxlibs/jax-cuda-plugin
# jaxlib 0.4.36.dev20241220 /opt/jaxlibs/jaxlib
pip list | grep jax

# Ensure directories are readable by all for non-root users
Expand Down
6 changes: 3 additions & 3 deletions .github/container/manifest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ flax:
url: https://github.com/google/flax.git
mirror_url: https://github.com/nvjax-svc-0/flax.git
tracking_ref: main
latest_verified_commit: 718aa8ccb12c3fdefcf3d196874e4fc667b3ade5
latest_verified_commit: d89c955d1faac9dd2162a0c674f7897f2c53f54d
mode: git-clone
patches:
pull/3340/head: file://patches/flax/PR-3340.patch # Add Sharding Annotations to Flax Modules
Expand Down Expand Up @@ -177,8 +177,8 @@ panopticapi:
mode: git-clone
orbax-checkpoint:
url: https://github.com/google/orbax.git
tracking_ref: main
latest_verified_commit: 16c2d409e365576284dbaf190ac002b24c1f927f
tracking_ref: v0.10.2
latest_verified_commit: d6101bad9ec5ddee8ee8b8c10e1d27d6c57f0963
mode: pip-vcs
pathwaysutils:
url: https://github.com/google/pathways-utils.git
Expand Down

0 comments on commit 92cc80b

Please sign in to comment.