Skip to content

Commit

Permalink
[ROCm] fix the flax and praxis installation in jax ci
Browse files Browse the repository at this point in the history
  • Loading branch information
wangye805 committed Dec 20, 2024
1 parent fe9e149 commit e1f2ecc
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ DIR=`dirname $0`

. $DIR/_utils.sh

install_flax() {
pip list | awk '/jax/ { print $1"=="$2}' > reqs
pip install flax -r reqs
}

install_praxis() {
git clone https://github.com/google/praxis.git && cd praxis || return $?
git checkout $_praxis_commit || return $?
Expand All @@ -23,7 +28,13 @@ install_praxis() {
}

install_prerequisites() {
_praxis_commit="899b56ebe9128a0"
install_flax; rc=$?
if [ $rc -ne 0 ]; then
script_error "Failed to install flax"
exit $rc
fi

_praxis_commit="3f4cbb4bcda366db"
pip show jaxlib | grep Version | grep -q 0.4.23
if [ $? -eq 0 ]; then
echo "JAX lib 0.4.23 is detected"
Expand All @@ -47,8 +58,6 @@ install_prerequisites() {
test $rc -eq 0 || exit $rc
fi

pip install 'ml-dtypes>=0.2.0' 'typing_extensions>=4.11.0'
rc=$?
if [ $rc -ne 0 ]; then
script_error "Failed to install test prerequisites"
exit $rc
Expand Down

0 comments on commit e1f2ecc

Please sign in to comment.