From cfdcce705291ebb4659dfaf8684f2cf950ecdebb Mon Sep 17 00:00:00 2001 From: Sam Stites Date: Sun, 20 Oct 2019 17:05:35 -0400 Subject: [PATCH] openmpi: enable cuda support --- .../development/libraries/openmpi/default.nix | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/pkgs/development/libraries/openmpi/default.nix b/pkgs/development/libraries/openmpi/default.nix index 2a293e0b8562..5996dda1e74e 100644 --- a/pkgs/development/libraries/openmpi/default.nix +++ b/pkgs/development/libraries/openmpi/default.nix @@ -1,5 +1,8 @@ { stdenv, fetchurl, fetchpatch, gfortran, perl, libnl -, rdma-core, zlib, numactl, libevent, hwloc, pkgsTargetTarget +, rdma-core, zlib, numactl, libevent, hwloc, pkgsTargetTarget, symlinkJoin + +# Enable CUDA support +, cudaSupport ? false, cudatoolkit ? null # Enable the Sun Grid Engine bindings , enableSGE ? false @@ -8,9 +11,15 @@ , enablePrefix ? false }: +assert !cudaSupport || cudatoolkit != null; + let version = "4.0.2"; + cudatoolkit_joined = symlinkJoin { + name = "${cudatoolkit.name}-unsplit"; + paths = [ cudatoolkit.out cudatoolkit.lib ]; + }; in stdenv.mkDerivation rec { pname = "openmpi"; inherit version; @@ -33,15 +42,20 @@ in stdenv.mkDerivation rec { buildInputs = with stdenv; [ gfortran zlib ] ++ lib.optionals isLinux [ libnl numactl ] + ++ lib.optionals cudaSupport [ cudatoolkit ] ++ [ libevent hwloc ] ++ lib.optional (isLinux || isFreeBSD) rdma-core; nativeBuildInputs = [ perl ]; - configureFlags = with stdenv; [ "--disable-mca-dso" ] + configureFlags = with stdenv; lib.optional (!cudaSupport) "--disable-mca-dso" ++ lib.optional isLinux "--with-libnl=${libnl.dev}" ++ lib.optional enableSGE "--with-sge" ++ lib.optional enablePrefix "--enable-mpirun-prefix-by-default" + # TODO: add UCX support, which is recommended to use with cuda for the most robust OpenMPI build + # https://github.com/openucx/ucx + # https://www.open-mpi.org/faq/?category=buildcuda + ++ lib.optionals cudaSupport [ "--with-cuda=${cudatoolkit_joined}" "--enable-dlopen" ] ; enableParallelBuilding = true; @@ -69,6 +83,10 @@ in stdenv.mkDerivation rec { doCheck = true; + passthru = { + inherit cudaSupport cudatoolkit; + }; + meta = with stdenv.lib; { homepage = https://www.open-mpi.org/; description = "Open source MPI-3 implementation";