python3Packages.pytorch: fixup builds with MKL
- Pass `blas.provider` into `buildInputs`, so that CMake can find the actual `mkl` for inspection of its cmake files and headers. - Add `USE_MKL` correctly when the blas provider is `mkl`. - Use the MKLDNN and MKLDNN_CBLAS flags by default, since `mkldnn` is FOSS and always available.. - Remove a patch for MKL 2019, since we've moved to 2020. - Add a pythonImportsCheck for "torch" as a basic sanity-check - Removed some unused variables at the top of the file
This commit is contained in:
parent
cc245fdcd5
commit
fb5bb25c10
@ -1,11 +1,10 @@
|
||||
{ stdenv, fetchurl, fetchgit, fetchpatch, buildPythonPackage, python, pythonOlder,
|
||||
{ stdenv, lib, fetchFromGitHub, fetchpatch, buildPythonPackage, python,
|
||||
cudaSupport ? false, cudatoolkit ? null, cudnn ? null, nccl ? null, magma ? null,
|
||||
mklDnnSupport ? true, useSystemNccl ? true,
|
||||
openMPISupport ? false, openmpi ? null,
|
||||
buildBinaries ? false,
|
||||
buildDocs ? false,
|
||||
cudaArchList ? null,
|
||||
fetchFromGitHub, lib, numpy, pyyaml, cffi, click, typing, cmake, hypothesis, numactl, psutil,
|
||||
numpy, pyyaml, cffi, click, typing, cmake, dnnl, hypothesis, numactl, psutil,
|
||||
linkFarm, symlinkJoin,
|
||||
|
||||
# virtual pkg that consistently instantiates blas across nixpkgs
|
||||
@ -152,7 +151,15 @@ in buildPythonPackage rec {
|
||||
|
||||
BUILD_NAMEDTENSOR = true;
|
||||
BUILD_DOCS = buildDocs;
|
||||
|
||||
USE_MKL = blas.implementation == "mkl";
|
||||
|
||||
# Unlike MKL, MKLDNN is FOSS, so we enable support for it by default. Note
|
||||
# that this was renamed to dnnl and then renamed again to oneDNN upstream, but
|
||||
# pytorch still calls it by the old name mkldnn.
|
||||
USE_MKLDNN = mklDnnSupport;
|
||||
USE_MKLDNN_CBLAS = mklDnnSupport;
|
||||
|
||||
preBuild = ''
|
||||
export MAX_JOBS=$NIX_BUILD_CORES
|
||||
${python.interpreter} setup.py build --cmake-only
|
||||
@ -174,7 +181,6 @@ in buildPythonPackage rec {
|
||||
done
|
||||
'';
|
||||
|
||||
|
||||
# Override the (weirdly) wrong version set by default. See
|
||||
# https://github.com/NixOS/nixpkgs/pull/52437#issuecomment-449718038
|
||||
# https://github.com/pytorch/pytorch/blob/v1.0.0/setup.py#L267
|
||||
@ -199,7 +205,7 @@ in buildPythonPackage rec {
|
||||
ninja
|
||||
] ++ lib.optionals cudaSupport [ cudatoolkit_joined ];
|
||||
|
||||
buildInputs = [ blas ]
|
||||
buildInputs = [ blas blas.provider dnnl ]
|
||||
++ lib.optionals cudaSupport [ cudnn magma nccl ]
|
||||
++ lib.optionals stdenv.isLinux [ numactl ];
|
||||
|
||||
@ -214,10 +220,13 @@ in buildPythonPackage rec {
|
||||
|
||||
checkInputs = [ hypothesis ninja psutil ];
|
||||
|
||||
doCheck = false; # tests take a long time for channel release, so doCheck should be overridden only when developing
|
||||
# Tests take a long time and may be flaky, so just sanity-check imports
|
||||
doCheck = false;
|
||||
pythonImportsCheck = [
|
||||
"torch"
|
||||
];
|
||||
|
||||
checkPhase = with lib.versions; with lib.strings; concatStringsSep " " [
|
||||
# MKL 2019.5-only workaround. See: https://github.com/NixOS/nixpkgs/issues/75611
|
||||
(optionalString (blas.implementation == "mkl" && majorMinor blas.version == "2019.5") "KMP_INIT_AT_FORK=FALSE ")
|
||||
cudaStubEnv
|
||||
"${python.interpreter} test/run_test.py"
|
||||
"--exclude"
|
||||
|
Loading…
Reference in New Issue
Block a user