Merge pull request #285037 from breakds/PR/breakds/fix_jaxlib_bin_and_cuda

jaxlib-bin: use correct cuda releases
This commit is contained in:
Samuel Ainsworth 2024-02-09 09:02:04 -05:00 committed by GitHub
commit e8a69d497b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,7 +33,7 @@
}:
let
inherit (cudaPackagesGoogle) cudatoolkit cudnn;
inherit (cudaPackagesGoogle) cudatoolkit cudnn cudaVersion;
version = "0.4.23";
@ -118,26 +118,44 @@ let
};
};
# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
# Note that the prebuilt jaxlib binary requires specific version of CUDA to
# work. The cuda12 jaxlib binaries only works with CUDA 12.2, and cuda11
# jaxlib binaries only works with CUDA 11.8. This is why we need to find a
# binary that matches the provided cudaVersion.
gpuSrcVersionString = "cuda${cudaVersion}-${pythonVersion}";
# Find new releases at https://storage.googleapis.com/jax-releases
# When upgrading, you can get these hashes from prefetch.sh. See
# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
gpuSrcs = {
"3.9" = fetchurl {
"cuda12.2-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI=";
};
"3.10" = fetchurl {
"cuda12.2-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg=";
};
"3.11" = fetchurl {
"cuda12.2-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow=";
};
"3.12" = fetchurl {
"cuda12.2-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo=";
};
"cuda11.8-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60=";
};
"cuda11.8-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
hash = "osha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0=";
};
"cuda11.8-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4=";
};
};
in
@ -154,7 +172,7 @@ buildPythonPackage {
(
cpuSrcs."${pythonVersion}-${stdenv.hostPlatform.system}"
or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
) else gpuSrcs."${pythonVersion}";
) else gpuSrcs."${gpuSrcVersionString}";
# Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them.
@ -212,6 +230,7 @@ buildPythonPackage {
broken =
!(cudaSupport -> (cudaPackagesGoogle ? cudatoolkit) && lib.versionAtLeast cudatoolkit.version "11.1")
|| !(cudaSupport -> (cudaPackagesGoogle ? cudnn) && lib.versionAtLeast cudnn.version "8.2")
|| !(cudaSupport -> stdenv.isLinux);
|| !(cudaSupport -> stdenv.isLinux)
|| !(cudaSupport -> (gpuSrcs ? "cuda${cudaVersion}-${pythonVersion}"));
};
}