Merge pull request #286948 from GaetanLepage/jax

python311Packages.{jax,jaxlib,jaxlib-bin}: 0.4.23 -> 0.4.24
This commit is contained in:
Samuel Ainsworth 2024-02-14 00:32:41 -05:00 committed by GitHub
commit 34be29dfc2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 43 additions and 28 deletions

View file

@ -28,7 +28,7 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.23";
version = "0.4.24";
pyproject = true;
disabled = pythonOlder "3.9";
@ -38,7 +38,7 @@ buildPythonPackage rec {
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-PDa3yVH/sszGbWkVkJ+19FdOr3oqdYk+OdbeUTMTDuU=";
hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
};
nativeBuildInputs = [
@ -89,6 +89,9 @@ buildPythonPackage rec {
"testKde3"
"testKde5"
"testKde6"
# Invokes python manually in a subprocess, which does not have the correct dependencies
# ImportError: This version of jax requires jaxlib version >= 0.4.19.
"test_no_log_spam"
] ++ lib.optionals usingMKL [
# See
# * https://github.com/google/jax/issues/9705

View file

@ -35,7 +35,7 @@
let
inherit (cudaPackagesGoogle) cudatoolkit cudnn cudaVersion;
version = "0.4.23";
version = "0.4.24";
inherit (python) pythonVersion;
@ -56,65 +56,65 @@ let
"3.9-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp39";
hash = "sha256-maN9RzK6/hYIuPRd8n8n5qa/HyPgAf6UD+mlqzZ1/Xc=";
hash = "sha256-6P5ArMoLZiUkHUoQ/mJccbNj5/7el/op+Qo6cGQ33xE=";
};
"3.9-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp39";
hash = "sha256-gdb07c12HCfK5VXT2C+9lYKSiIpPgD8sNmd4eG2M6M4=";
hash = "sha256-23JQZRwMLtt7sK/JlCBqqRyfTVIAVJFN2sL+nAkQgvU=";
};
"3.9-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp39";
hash = "sha256-TdU4wEoqEhsDq18MuLEpmKqpU51+xUYp/reZqEDJK1U=";
hash = "sha256-OgMedn9GHGs5THZf3pkP3Aw/jJ0vL5qK1b+Lzf634Ik=";
};
"3.10-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp310";
hash = "sha256-cnX75aSJxoPFUCYD1V5QgyPNovS9lSGqg4PGdPsKsvM=";
hash = "sha256-/VwUIIa7mTs/wLz0ArsEfNrz2pGriVVT5GX9XRFRxfY=";
};
"3.10-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp310";
hash = "sha256-H9sbeR4+4XytREYLP0LJphqGkQqHcinTC9NlT0Rj1aA=";
hash = "sha256-LgICOyDGts840SQQJh+yOMobMASb62llvJjpGvhzrSw=";
};
"3.10-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp310";
hash = "sha256-43VuBgGvdjauWPQtJK9w5GBI/++JvV4FwwO4maIXfDY=";
hash = "sha256-vhyULw+zBpz1UEi2tqgBMQEzY9a6YBgEIg6A4PPh3bQ=";
};
"3.11-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp311";
hash = "sha256-mEdm0wmyHKg4RlA7q9/e1OOu+BfGcPKBCSvLwXfFhJI=";
hash = "sha256-VJO/VVwBFkOEtq4y/sLVgAV8Cung01JULiuT6W96E/8=";
};
"3.11-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp311";
hash = "sha256-1kb/m8DODrtXO2drIfpttCLC72oNVszADbSDspllQVs=";
hash = "sha256-VtuwXxurpSp1KI8ty1bizs5cdy8GEBN2MgS227sOCmE=";
};
"3.11-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp311";
hash = "sha256-jhLX4ps+EtU1sku722v51mz2SSamo4/dkdRWX3zFcRE=";
hash = "sha256-4Dj5dEGKb9hpg3HlVogNO1Gc9UibJhy1eym2mjivxAQ=";
};
"3.12-x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
dist = "cp312";
hash = "sha256-oimiuQopgN1oKhbDc7SsRJPnA6JiEI9UieikWR2qpVk=";
hash = "sha256-TlrGVtb3NTLmhnILWPLJR+jISCZ5SUV4wxNFpSfkCBo=";
};
"3.12-aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
dist = "cp312";
hash = "sha256-J4zaKcx0c0Bgk7w/n6klqDlgY9IqTNINfg6g033LUDk=";
hash = "sha256-FIwK5CGykQjteuWzLZnbtAggIxLQeGV96bXlZGEytN0=";
};
"3.12-x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
dist = "cp312";
hash = "sha256-UFEE/mBitEOVUoijhUfphyy24QfWPZ+FQPsQ0cjY79A=";
hash = "sha256-9/jw/wr6oUD9pOadVAaMRL086iVMUXwVgnUMcG1UNvE=";
};
};
@ -130,31 +130,35 @@ let
gpuSrcs = {
"cuda12.2-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-our2mSwHPdjVoDAZP+9aNUkJ+vxv1Tq7G5UqA9HvhNI=";
hash = "sha256-xdJKLPtx+CIza2CrWKM3M0cZJzyNFVTTTsvlgh38bfM=";
};
"cuda12.2-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-jkIABnJZnn7A6n9VGs/MldzdDiKwWh0fEvl7Vqn85Kg=";
hash = "sha256-QCjrOczD2mp+CDwVXBc0/4rJnAizeV62AK0Dpx9X6TE=";
};
"cuda12.2-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-dMUcRnHjl8NyUeO3P1x7CNgF0iAHFKIzUtHh+/CNkow=";
hash = "sha256-Ipy3vk1yUplpNzECAFt63aOIhgEWgXG7hkoeTIk9bQQ=";
};
"cuda12.2-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-${version}+cuda12.cudnn89-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-kXJ6bUwX+QybqYPV9Kpwv+lhdoGEFRr4+1T0vfXoWRo=";
hash = "sha256-LSnZHaUga/8Z65iKXWBnZDk4yUpNykFTu3vukCchO6Q=";
};
"cuda11.8-3.9" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp39-cp39-manylinux2014_x86_64.whl";
hash = "sha256-m2Y5p12gF3OaADu+aGw5RjcKFrj9RB8xzNWnKNpSz60=";
hash = "sha256-UmyugL0VjlXkiD7fuDPWgW8XUpr/QaP5ggp6swoZTzU=";
};
"cuda11.8-3.10" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-aQ7iX3o0kQ4liPexv7dkBVWVTUpaty83L083MybGkf0=";
hash = "sha256-luKULEiV1t/sO6eckDxddJTiOFa0dtJeDlrvp+WYmHk=";
};
"cuda11.8-3.11" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp311-cp311-manylinux2014_x86_64.whl";
hash = "sha256-uIEyjEmv0HBaiYVl5PuICTI9XnH4zAfQ1l9tjALRcP4=";
hash = "sha256-4+uJ8Ij6mFGEmjFEgi3fLnSLZs+v18BRoOt7mZuqydw=";
};
"cuda11.8-3.12" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp312-cp312-manylinux2014_x86_64.whl";
hash = "sha256-bUDFb94Ar/65SzzR9RLIs/SL/HdjaPT1Su5whmjkS00=";
};
};

View file

@ -53,7 +53,7 @@ let
inherit (cudaPackagesGoogle) backendStdenv cudatoolkit cudaFlags cudnn nccl;
pname = "jaxlib";
version = "0.4.23";
version = "0.4.24";
meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@ -151,7 +151,7 @@ let
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-PDa3yVH/sszGbWkVkJ+19FdOr3oqdYk+OdbeUTMTDuU=";
hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
};
nativeBuildInputs = [
@ -195,7 +195,12 @@ let
'';
bazelRunTarget = "//jaxlib/tools:build_wheel";
runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];
runTargetFlags = [
"--output_path=$out"
"--cpu=${arch}"
# This has no impact whatsoever...
"--jaxlib_git_hash='12345678'"
];
removeRulesCC = false;
@ -263,10 +268,10 @@ let
];
sha256 = (if cudaSupport then {
x86_64-linux = "sha256-q2wRaoCGnISEdtF6jDMk9Wccy/wTmLusVBI7dDATwi4=";
x86_64-linux = "sha256-c0avcURLAYNiLASjIeu5phXX3ze5TR812SW5SCG/iwk=";
} else {
x86_64-linux = "sha256-0cDJ27HCi3J5xeT6TkTtfUzF/yESBYmEVG1r14kPdRs=";
aarch64-linux = "sha256-WbaN8VYjeW0mDthmtoSTttqd4K/Z8dP5+VkTo10pLtU=";
x86_64-linux = "sha256-1hrQ9ehFy3vBJxKNUzi/T0l+eZxo26Th7i5VRd/9U+0=";
aarch64-linux = "sha256-3QVYJOj1lNHgYVV9rOzVdfhq5q6GDwpcWCjKNrSZ4aU=";
}).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
};

View file

@ -62,6 +62,9 @@ buildPythonPackage rec {
"test_zero_inflated_logits_probs_agree"
# NameError: unbound axis name: _provenance
"test_model_transformation"
# Using deprecated (removed in jax==0.4.24) jax.core.safe_map
# https://github.com/pyro-ppl/numpyro/issues/1733
"test_beta_bernoulli"
];
# TODO: remove when tensorflow-probability gets fixed.