Merge pull request #219778 from samuela/samuela/jax
Update JAX and fix aarch64-darwin build
This commit is contained in:
commit
8faef6de41
@ -5,6 +5,7 @@
|
||||
, etils
|
||||
, fetchFromGitHub
|
||||
, jaxlib
|
||||
, jaxlib-bin
|
||||
, lapack
|
||||
, matplotlib
|
||||
, numpy
|
||||
@ -13,15 +14,20 @@
|
||||
, pytest-xdist
|
||||
, pythonOlder
|
||||
, scipy
|
||||
, stdenv
|
||||
, typing-extensions
|
||||
}:
|
||||
|
||||
let
|
||||
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
|
||||
# jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
|
||||
# fine. jaxlib is only used in the checkPhase, so switching backends does not
|
||||
# impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
|
||||
jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
|
||||
in
|
||||
buildPythonPackage rec {
|
||||
pname = "jax";
|
||||
version = "0.4.1";
|
||||
version = "0.4.5";
|
||||
format = "setuptools";
|
||||
|
||||
disabled = pythonOlder "3.7";
|
||||
@ -29,14 +35,14 @@ buildPythonPackage rec {
|
||||
src = fetchFromGitHub {
|
||||
owner = "google";
|
||||
repo = pname;
|
||||
rev = "refs/tags/jaxlib-v${version}";
|
||||
hash = "sha256-ajLI0iD0YZRK3/uKSbhlIZGc98MdW174vA34vhoy7Iw=";
|
||||
# google/jax contains tags for jax and jaxlib. Only use jax tags!
|
||||
rev = "refs/tags/${pname}-v${version}";
|
||||
hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
|
||||
};
|
||||
|
||||
# jaxlib is _not_ included in propagatedBuildInputs because there are
|
||||
# different versions of jaxlib depending on the desired target hardware. The
|
||||
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
|
||||
# CPU wheel is packaged.
|
||||
# JAX project ships separate wheels for CPU, GPU, and TPU.
|
||||
propagatedBuildInputs = [
|
||||
absl-py
|
||||
etils
|
||||
@ -47,7 +53,7 @@ buildPythonPackage rec {
|
||||
] ++ etils.optional-dependencies.epath;
|
||||
|
||||
nativeCheckInputs = [
|
||||
jaxlib
|
||||
jaxlib'
|
||||
matplotlib
|
||||
pytestCheckHook
|
||||
pytest-xdist
|
||||
@ -83,6 +89,11 @@ buildPythonPackage rec {
|
||||
"test_custom_linear_solve_cholesky"
|
||||
"test_custom_root_with_aux"
|
||||
"testEigvalsGrad_shape"
|
||||
] ++ lib.optionals (stdenv.isAarch64 && stdenv.isDarwin) [
|
||||
# See https://github.com/google/jax/issues/14793.
|
||||
"test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
|
||||
"testQdwhWithRandomMatrix3"
|
||||
"testScanGrad_jit_scan"
|
||||
];
|
||||
|
||||
# See https://github.com/google/jax/issues/11722. This is a temporary fix in
|
||||
|
@ -39,7 +39,7 @@ assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
|
||||
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
|
||||
|
||||
let
|
||||
version = "0.3.22";
|
||||
version = "0.4.4";
|
||||
|
||||
pythonVersion = python.pythonVersion;
|
||||
|
||||
@ -50,21 +50,21 @@ let
|
||||
cpuSrcs = {
|
||||
"x86_64-linux" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-w2wo0jk+1BdEkNwfSZRQbebdI4Ac8Kgn0MB0cIMcWU4=";
|
||||
hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
|
||||
};
|
||||
"aarch64-darwin" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
|
||||
hash = "sha256-7Ir55ZhBkccqfoa56WVBF8QwFAC2ws4KFHDkfVw6zm0=";
|
||||
hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
|
||||
};
|
||||
"x86_64-darwin" = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
|
||||
hash = "sha256-bOoQI+T+YsTUNA+cDu6wwYTcq9fyyzCpK9qrdCrNVoA=";
|
||||
hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
|
||||
};
|
||||
};
|
||||
|
||||
gpuSrc = fetchurl {
|
||||
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
|
||||
hash = "sha256-rabU62p4fF7Tu/6t8LNYZdf6YO06jGry/JtyFZeamCs=";
|
||||
hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
|
||||
};
|
||||
in
|
||||
buildPythonPackage rec {
|
||||
@ -77,7 +77,13 @@ buildPythonPackage rec {
|
||||
# python version.
|
||||
disabled = !(pythonVersion == "3.10");
|
||||
|
||||
src = if !cudaSupport then cpuSrcs."${stdenv.hostPlatform.system}" else gpuSrc;
|
||||
# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
|
||||
src =
|
||||
if !cudaSupport then
|
||||
(
|
||||
cpuSrcs."${stdenv.hostPlatform.system}"
|
||||
or (throw "jaxlib-bin is not supported on ${stdenv.hostPlatform.system}")
|
||||
) else gpuSrc;
|
||||
|
||||
# Prebuilt wheels are dynamically linked against things that nix can't find.
|
||||
# Run `autoPatchelfHook` to automagically fix them.
|
||||
|
@ -52,7 +52,7 @@ let
|
||||
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
|
||||
|
||||
pname = "jaxlib";
|
||||
version = "0.3.22";
|
||||
version = "0.4.4";
|
||||
|
||||
meta = with lib; {
|
||||
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
|
||||
@ -137,8 +137,9 @@ let
|
||||
src = fetchFromGitHub {
|
||||
owner = "google";
|
||||
repo = "jax";
|
||||
rev = "${pname}-v${version}";
|
||||
hash = "sha256-bnczJ8ma/UMKhA5MUQ6H4az+Tj+By14ZTG6lQQwptQs=";
|
||||
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
|
||||
rev = "refs/tags/${pname}-v${version}";
|
||||
hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
|
||||
};
|
||||
|
||||
nativeBuildInputs = [
|
||||
@ -242,9 +243,9 @@ let
|
||||
|
||||
sha256 =
|
||||
if cudaSupport then
|
||||
"sha256-4yu4y4SwSQoeaOz9yojhvCRGSC6jp61ycVDIKyIK/l8="
|
||||
"sha256-cgsiloW77p4+TKRrYequZ/UwKwfO2jsHKtZ+aA30H7E="
|
||||
else
|
||||
"sha256-CyRfPfJc600M7VzR3/SQX/EAyeaXRJwDQWot5h2XnFU=";
|
||||
"sha256-D7WYG3YUaWq+4APYx8WpA191VVtoHG0fth3uEHXOeos=";
|
||||
};
|
||||
|
||||
buildAttrs = {
|
||||
|
Loading…
Reference in New Issue
Block a user