# 2021-08-22 21:39:04 +01:00
# Nix package for the JAX Python library, built from the upstream GitHub
# release tag so that the test suite is available for the check phase.
{ buildPythonPackage, fetchFromGitHub, lib
# propagatedBuildInputs
, absl-py, numpy, opt-einsum
# checkInputs
, jaxlib, pytestCheckHook
}:

buildPythonPackage rec {
  pname = "jax";
  version = "0.2.19";

  # Fetching from pypi doesn't allow us to run the test suite. See https://discourse.nixos.org/t/pythonremovetestsdir-hook-being-run-before-checkphase/14612/3.
  src = fetchFromGitHub {
    owner = "google";
    repo = pname;
    # Upstream release tags are named "jax-v<version>".
    rev = "jax-v${version}";
    sha256 = "sha256-pVn62G7pydR7ybkf7gSbu0FlEq2c0US6H2GTBAljup4=";
  };

  # jaxlib is _not_ included in propagatedBuildInputs because there are
  # different versions of jaxlib depending on the desired target hardware. The
  # JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
  # CPU wheel is packaged.
  propagatedBuildInputs = [ absl-py numpy opt-einsum ];

  # jaxlib is only needed to run the test suite here (see note above).
  checkInputs = [ jaxlib pytestCheckHook ];

  # NOTE: Don't run the tests in the experimental directory as they require flax
  # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
  # Not a big deal, this is how the JAX docs suggest running the test suite
  # anyhow.
  #
  # "-W ignore::DeprecationWarning" suppresses DeprecationWarnings during the
  # run; "tests/" restricts collection to the main test directory.
  pytestFlagsArray = [ "-W ignore::DeprecationWarning" "tests/" ];

  meta = with lib; {
    description = "Differentiate, compile, and transform Numpy code";
    homepage = "https://github.com/google/jax";
    license = licenses.asl20;
    maintainers = with maintainers; [ samuela ];
  };
}