diff --git a/pkgs/development/python-modules/pytorch/default.nix b/pkgs/development/python-modules/pytorch/default.nix index 0b2525abfeb1..22a87b386e18 100644 --- a/pkgs/development/python-modules/pytorch/default.nix +++ b/pkgs/development/python-modules/pytorch/default.nix @@ -1,5 +1,5 @@ { stdenv, lib, fetchFromGitHub, fetchpatch, buildPythonPackage, python, - cudaSupport ? false, cudatoolkit ? null, cudnn ? null, nccl ? null, magma ? null, + cudaSupport ? false, cudatoolkit, cudnn, nccl, magma, mklDnnSupport ? true, useSystemNccl ? true, MPISupport ? false, mpi, buildDocs ? false, @@ -30,8 +30,6 @@ isPy3k, pythonOlder }: # assert that everything needed for cuda is present and that the correct cuda versions are used -assert !cudaSupport || cudatoolkit != null; -assert cudnn == null || cudatoolkit != null; assert !cudaSupport || (let majorIs = lib.versions.major cudatoolkit.version; in majorIs == "9" || majorIs == "10" || majorIs == "11");