Merge pull request #26580 from twhitehead/theano-cleanup
Theano cleanup
This commit is contained in:
commit
52e7817aa2
@ -1,8 +1,8 @@
|
||||
{ stdenv
|
||||
, runCommandCC
|
||||
, lib
|
||||
, fetchPypi
|
||||
, gcc
|
||||
, writeScriptBin
|
||||
, buildPythonPackage
|
||||
, isPyPy
|
||||
, pythonOlder
|
||||
@ -24,19 +24,24 @@ assert cudaSupport -> nvidia_x11 != null
|
||||
&& cudnn != null;
|
||||
|
||||
let
|
||||
extraFlags =
|
||||
lib.optionals cudaSupport [ "-I ${cudatoolkit}/include" "-L ${cudatoolkit}/lib" ]
|
||||
++ lib.optionals cudnnSupport [ "-I ${cudnn}/include" "-L ${cudnn}/lib" ]
|
||||
++ lib.optionals cudaSupport [ "-I ${libgpuarray}/include" "-L ${libgpuarray}/lib" ];
|
||||
wrapped = command: buildTop: buildInputs:
|
||||
runCommandCC "${command}-wrapped" { inherit buildInputs; } ''
|
||||
type -P '${command}' || { echo '${command}: not found'; exit 1; }
|
||||
cat > "$out" <<EOF
|
||||
#!$(type -P bash)
|
||||
$(declare -xp | sed -e '/^[^=]\+="\('"''${NIX_STORE//\//\\/}"'\|[^\/]\)/!d')
|
||||
declare -x NIX_BUILD_TOP="${buildTop}"
|
||||
$(type -P '${command}') "\$@"
|
||||
EOF
|
||||
chmod +x "$out"
|
||||
'';
|
||||
|
||||
gcc_ = writeScriptBin "g++" ''
|
||||
#!${stdenv.shell}
|
||||
export NIX_CC_WRAPPER_${stdenv.cc.infixSalt}_TARGET_HOST=1
|
||||
export NIX_CFLAGS_COMPILE="$NIX_CFLAGS_COMPILE ${toString extraFlags}"
|
||||
exec ${gcc}/bin/g++ "$@"
|
||||
'';
|
||||
# Theano spews warnings and disabled flags if the compiler isn't named g++
|
||||
cxx_compiler = wrapped "g++" "\\$HOME/.theano"
|
||||
( stdenv.lib.optional cudaSupport libgpuarray_
|
||||
++ stdenv.lib.optional cudnnSupport cudnn );
|
||||
|
||||
libgpuarray_ = libgpuarray.override { inherit cudaSupport; };
|
||||
libgpuarray_ = libgpuarray.override { inherit cudaSupport cudatoolkit; };
|
||||
|
||||
in buildPythonPackage rec {
|
||||
pname = "Theano";
|
||||
@ -50,12 +55,15 @@ in buildPythonPackage rec {
|
||||
};
|
||||
|
||||
postPatch = ''
|
||||
sed -i 's,g++,${gcc_}/bin/g++,g' theano/configdefaults.py
|
||||
'' + lib.optionalString cudnnSupport ''
|
||||
sed -i \
|
||||
-e "s,ctypes.util.find_library('cudnn'),'${cudnn}/lib/libcudnn.so',g" \
|
||||
-e "s/= _dnn_check_compile()/= (True, None)/g" \
|
||||
theano/gpuarray/dnn.py
|
||||
substituteInPlace theano/configdefaults.py \
|
||||
--replace 'StrParam(param, is_valid=warn_cxx)' 'StrParam('\'''${cxx_compiler}'\''', is_valid=warn_cxx)' \
|
||||
--replace 'rc == 0 and config.cxx != ""' 'config.cxx != ""'
|
||||
'' + stdenv.lib.optionalString cudaSupport ''
|
||||
substituteInPlace theano/configdefaults.py \
|
||||
--replace 'StrParam(get_cuda_root)' 'StrParam('\'''${cudatoolkit}'\''')'
|
||||
'' + stdenv.lib.optionalString cudnnSupport ''
|
||||
substituteInPlace theano/configdefaults.py \
|
||||
--replace 'StrParam(default_dnn_base_path)' 'StrParam('\'''${cudnn}'\''')'
|
||||
'';
|
||||
|
||||
preCheck = ''
|
||||
|
Loading…
Reference in New Issue
Block a user