diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index 9970783aa3bc..07f4f5efedbd 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -21,7 +21,7 @@ let in buildPythonPackage rec { pname = "jax"; - version = "0.3.16"; + version = "0.3.23"; format = "setuptools"; disabled = pythonOlder "3.7"; @@ -30,7 +30,7 @@ buildPythonPackage rec { owner = "google"; repo = pname; rev = "jax-v${version}"; - hash = "sha256-4idh7boqBXSO9vEHxEcrzXjBIrKmmXiCf6cXh7En1/I="; + hash = "sha256-ruXOwpBwpi1G8jgH9nhbWbs14JupwWkjh+Wzrj8HVU4="; }; # jaxlib is _not_ included in propagatedBuildInputs because there are @@ -92,9 +92,8 @@ buildPythonPackage rec { "tests/sparse_test.py" ]; - pythonImportsCheck = [ - "jax" - ]; + # As of 0.3.22, `import jax` does not work without jaxlib being installed. + pythonImportsCheck = [ ]; meta = with lib; { description = "Differentiate, compile, and transform Numpy code";