Merge pull request #285037 from breakds/PR/breakds/fix_jaxlib_bin_and_cuda
jaxlib-bin: use correct cuda releases
This commit is contained in:
commit
e8a69d497b
@ -33,7 +33,7 @@
|
|||||||
}:
|
}:
|
||||||
|
|
||||||
let
|
let
|
||||||
inherit (cudaPackagesGoogle) cudatoolkit cudnn;
|
inherit (cudaPackagesGoogle) cudatoolkit cudnn cudaVersion;
|
||||||
|
|
||||||
version = "0.4.23";
|
version = "0.4.23";
|
||||||
|
|
||||||
@ -118,26 +118,44 @@ let
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
|
# Note that the prebuilt jaxlib binary requires specific version of CUDA to
|
||||||
|
# work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11
|
||||||
|
# jaxlib binaries only works with CUDA 11.8. This is why we need to find a
|
||||||
|
# binary that matches the provided cudaVersion.
|
||||||
|
gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}";
|
||||||
|
|
||||||
|
# Find new releases at https://storage.googleapis.com/jax-releases
|
||||||
# When upgrading, you can get these hashes from prefetch.sh. See
|
# When upgrading, you can get these hashes from prefetch.sh. See
|
||||||
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
|
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
|
||||||
gpuSrcs = {
|
gpuSrcs = {
|
||||||
"3.9" = fetchurl {
|
"cuda12.2-3.9" = fetchurl {
|
||||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
|
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
|
||||||
hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI=";
|
hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI=";
|
||||||
};
|
};
|
||||||
"3.10" = fetchurl {
|
"cuda12.2-3.10" = fetchurl {
|
||||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
|
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
|
||||||
hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg=";
|
hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg=";
|
||||||
};
|
};
|
||||||
"3.11" = fetchurl {
|
"cuda12.2-3.11" = fetchurl {
|
||||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
|
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
|
||||||
hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow=";
|
hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow=";
|
||||||
};
|
};
|
||||||
"3.12" = fetchurl {
|
"cuda12.2-3.12" = fetchurl {
|
||||||
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
|
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
|
||||||
hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo=";
|
hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo=";
|
||||||
};
|
};
|
||||||
|
"cuda11.8-3.9" = fetchurl {
|
||||||
|
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
|
||||||
|
hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60=";
|
||||||
|
};
|
||||||
|
"cuda11.8-3.10" = fetchurl {
|
||||||
|
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
|
||||||
|
hash = "osha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0=";
|
||||||
|
};
|
||||||
|
"cuda11.8-3.11" = fetchurl {
|
||||||
|
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
|
||||||
|
hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4=";
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
in
|
in
|
||||||
@ -154,7 +172,7 @@ buildPythonPackage {
|
|||||||
(
|
(
|
||||||
cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
|
cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
|
||||||
or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
|
or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
|
||||||
) else gpuSrcs."${pythonVersion}";
|
) else gpuSrcs."${gpuSrcVersionString}";
|
||||||
|
|
||||||
# Prebuilt wheels are dynamically linked against things that nix can't find.
|
# Prebuilt wheels are dynamically linked against things that nix can't find.
|
||||||
# Run `autoPatchelfHook` to automagically fix them.
|
# Run `autoPatchelfHook` to automagically fix them.
|
||||||
@ -212,6 +230,7 @@ buildPythonPackage {
|
|||||||
broken =
|
broken =
|
||||||
!(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1")
|
!(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1")
|
||||||
|| !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2")
|
|| !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2")
|
||||||
|| !(cudaSupport -> stdenv.isLinux);
|
|| !(cudaSupport -> stdenv.isLinux)
|
||||||
|
|| !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"));
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user