1- final : prev : {
2- cmakeNvccThreadsHook = prev . callPackage ./pkgs/cmake-nvcc-threads-hook { } ;
3-
4- # Local packages
1+ final : prev :
2+ let
3+ # For XPU we use MKL from the joined oneAPI toolkit.
4+ useMKL = final . stdenv . isx86_64 && ! ( final . config . xpuSupport or false ) ;
5+ in
6+ {
7+ # Use MKL for BLAS/LAPACK on x86_64.
8+ blas = if useMKL then prev . blas . override { blasProvider = prev . mkl ; } else prev . blas ;
9+ lapack = if useMKL then prev . lapack . override { lapackProvider = prev . mkl ; } else prev . blas ;
510
611 build2cmake = prev . callPackage ./pkgs/build2cmake { } ;
712
13+ cmakeNvccThreadsHook = prev . callPackage ./pkgs/cmake-nvcc-threads-hook { } ;
14+
815 get-kernel-check = prev . callPackage ./pkgs/get-kernel-check { } ;
916
1017 kernel-abi-check = prev . callPackage ./pkgs/kernel-abi-check { } ;
1118
1219 kernel-layout-check = prev . callPackage ./pkgs/kernel-layout-check { } ;
1320
21+ # Used by ROCm.
22+ libffi_3_2 = final . libffi_3_3 . overrideAttrs (
23+ finalAttrs : _ : {
24+ version = "3.2.1" ;
25+ src = final . fetchurl {
26+ url = with finalAttrs ; "https://gcc.gnu.org/pub/${ pname } /${ pname } -${ version } .tar.gz" ;
27+ hash = "sha256-0G67jh2aItGeONY/24OVQlPzm+3F1GIyoFZFaFciyjc=" ;
28+ } ;
29+ }
30+ ) ;
31+
32+ magma = ( prev . callPackage ./pkgs/magma { } ) . magma ;
33+
34+ magma-hip =
35+ ( prev . callPackage ./pkgs/magma {
36+ cudaSupport = false ;
37+ rocmSupport = true ;
38+ } ) . magma ;
39+
40+ nvtx = final . callPackage ./pkgs/nvtx { } ;
41+
42+ metal-cpp = final . callPackage ./pkgs/metal-cpp { } ;
43+
1444 rewrite-nix-paths-macho = prev . callPackage ./pkgs/rewrite-nix-paths-macho { } ;
1545
1646 remove-bytecode-hook = prev . callPackage ./pkgs/remove-bytecode-hook { } ;
1747
1848 stdenvGlibc_2_27 = prev . callPackage ./pkgs/stdenv-glibc-2_27 { } ;
1949
50+ ucx = prev . ucx . overrideAttrs (
51+ _ : prevAttrs : {
52+ buildInputs = prevAttrs . buildInputs ++ [ final . cudaPackages . cuda_nvcc ] ;
53+ }
54+ ) ;
55+
2056 # Python packages
2157 pythonPackagesExtensions = prev . pythonPackagesExtensions ++ [
2258 (
@@ -53,6 +89,14 @@ final: prev: {
5389
5490 mkTorch = callPackage ./pkgs/python-modules/torch/binary { } ;
5591
92+ scipy = python-super . scipy . overrideAttrs (
93+ _ : prevAttrs : {
94+ # Three tests have a slight deviance.
95+ doCheck = false ;
96+ doInstallCheck = false ;
97+ }
98+ ) ;
99+
56100 torch-bin_2_8 = mkTorch {
57101 version = "2.8" ;
58102 xpuPackages = final . xpuPackages_2025_1 ;
@@ -70,7 +114,60 @@ final: prev: {
70114 torch_2_9 = callPackage ./pkgs/python-modules/torch/source/2_9 {
71115 xpuPackages = final . xpuPackages_2025_2 ;
72116 } ;
117+
118+ triton-xpu_2_8 = callPackage ./pkgs/python-modules/triton-xpu {
119+ torchVersion = "2.8" ;
120+ xpuPackages = final . xpuPackages_2025_1 ;
121+ } ;
122+
123+ triton-xpu_2_9 = callPackage ./pkgs/python-modules/triton-xpu {
124+ torchVersion = "2.9" ;
125+ xpuPackages = final . xpuPackages_2025_2 ;
126+ } ;
73127 }
74128 )
129+ ( import ./pkgs/python-modules/hooks )
75130 ] ;
131+
132+ xpuPackages = final . xpuPackages_2025_1 ;
76133}
134+ // ( import ./pkgs/cutlass { pkgs = final ; } )
135+ // (
136+ let
137+ flattenVersion = prev . lib . strings . replaceStrings [ "." ] [ "_" ] ;
138+ readPackageMetadata = path : ( builtins . fromJSON ( builtins . readFile path ) ) ;
139+ versions = [
140+ "6.3.4"
141+ "6.4.2"
142+ "7.0.1"
143+ ] ;
144+ newRocmPackages = final . callPackage ./pkgs/rocm-packages { } ;
145+ in
146+ builtins . listToAttrs (
147+ map ( version : {
148+ name = "rocmPackages_${ flattenVersion ( prev . lib . versions . majorMinor version ) } " ;
149+ value = newRocmPackages {
150+ packageMetadata = readPackageMetadata ./pkgs/rocm-packages/rocm-${ version } -metadata.json ;
151+ } ;
152+ } ) versions
153+ )
154+ )
155+ // (
156+ let
157+ flattenVersion = prev . lib . strings . replaceStrings [ "." ] [ "_" ] ;
158+ readPackageMetadata = path : ( builtins . fromJSON ( builtins . readFile path ) ) ;
159+ xpuVersions = [
160+ "2025.1.3"
161+ "2025.2.1"
162+ ] ;
163+ newXpuPackages = final . callPackage ./pkgs/xpu-packages { } ;
164+ in
165+ builtins . listToAttrs (
166+ map ( version : {
167+ name = "xpuPackages_${ flattenVersion ( prev . lib . versions . majorMinor version ) } " ;
168+ value = newXpuPackages {
169+ packageMetadata = readPackageMetadata ./pkgs/xpu-packages/intel-deep-learning-${ version } .json ;
170+ } ;
171+ } ) xpuVersions
172+ )
173+ )
0 commit comments