
Commit 3dfcfb5

Copy Torch from hf-nix
We have decided to make kernel-builder standalone to make it easier to maintain Torch versions as part of the same repo. This copies over Torch from hf-nix. Other bits are still needed, but let's start somewhere.
1 parent 7a24559 commit 3dfcfb5

20 files changed: +2998 -0 lines

overlay.nix

Lines changed: 20 additions & 0 deletions
@@ -50,6 +50,26 @@ final: prev: {
         });
 
         pyclibrary = python-self.callPackage ./pkgs/python-modules/pyclibrary { };
+
+        mkTorch = callPackage ./pkgs/python-modules/torch/binary { };
+
+        torch-bin_2_8 = mkTorch {
+          version = "2.8";
+          xpuPackages = final.xpuPackages_2025_1;
+        };
+
+        torch-bin_2_9 = mkTorch {
+          version = "2.9";
+          xpuPackages = final.xpuPackages_2025_2;
+        };
+
+        torch_2_8 = callPackage ./pkgs/python-modules/torch/source/2_8 {
+          xpuPackages = final.xpuPackages_2025_1;
+        };
+
+        torch_2_9 = callPackage ./pkgs/python-modules/torch/source/2_9 {
+          xpuPackages = final.xpuPackages_2025_2;
+        };
       }
     )
   ];
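
The hunk places the new attributes next to pyclibrary inside what looks like a Python package set extension, so they should become reachable through the usual Python package sets once the overlay is applied. A minimal sketch of a consumer, assuming a plain nixpkgs import; the pin and overlay wiring below are illustrative and not part of this commit:

let
  # Hypothetical consumer: apply the overlay to a nixpkgs instance.
  pkgs = import <nixpkgs> {
    config.cudaSupport = true;
    overlays = [ (import ./overlay.nix) ];
  };
in
{
  # Prebuilt, wheel-based Torch 2.9 from the binary builder ...
  torchWheel = pkgs.python3Packages.torch-bin_2_9;
  # ... and the corresponding from-source build.
  torchSource = pkgs.python3Packages.torch_2_9;
}
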
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+{
+  "2.8" = {
+    # https://github.com/pytorch/pytorch/blob/release/2.8/.ci/manywheel/build_cuda.sh
+    capsPerCudaVersion = {
+      "12.9" = [
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+        "10.0"
+        "12.0"
+      ];
+      "12.8" = [
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+        "10.0"
+        "12.0"
+      ];
+      "12.6" = [
+        "5.0"
+        "6.0"
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+      ];
+      # Not a supported upstream configuration, but keep it around for
+      # builds that fail on newer CUDA versions.
+      "12.4" = [
+        "5.0"
+        "6.0"
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+      ];
+    };
+    # https://github.com/pytorch/pytorch/blob/ba56102387ef21a3b04b357e5b183d48f0afefc7/.ci/docker/manywheel/build.sh#L82
+    supportedTorchRocmArchs = [
+      "gfx900"
+      "gfx906"
+      "gfx908"
+      "gfx90a"
+      "gfx942"
+      "gfx1030"
+      "gfx1100"
+      "gfx1101"
+      "gfx1102"
+      "gfx1200"
+      "gfx1201"
+    ];
+  };
+
+  "2.9" = {
+    # https://github.com/pytorch/pytorch/blob/release/2.9/.ci/manywheel/build_cuda.sh
+    capsPerCudaVersion = {
+      "13.0" = [
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+        "10.0"
+        "12.0"
+      ];
+      # NOTE: 12.9 does not seem to be in RC builds, check if needed for final release.
+      # https://download.pytorch.org/whl/test/torch/
+      "12.9" = [
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+        "10.0"
+        "12.0"
+      ];
+      "12.8" = [
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+        "10.0"
+        "12.0"
+      ];
+      "12.6" = [
+        "5.0"
+        "6.0"
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+      ];
+      # Not a supported upstream configuration, but keep it around for
+      # builds that fail on newer CUDA versions.
+      "12.4" = [
+        "5.0"
+        "6.0"
+        "7.0"
+        "7.5"
+        "8.0"
+        "8.6"
+        "9.0"
+      ];
+    };
+
+    supportedTorchRocmArchs = [
+      # https://github.com/pytorch/pytorch/blob/21fec65781bebe867faf209f89bb687ffd236ca4/.ci/docker/manywheel/build.sh#L92
+      "gfx900"
+      "gfx906"
+      "gfx908"
+      "gfx90a"
+      "gfx942"
+      "gfx1030"
+      "gfx1100"
+      "gfx1101"
+      "gfx1102"
+      "gfx1200"
+      "gfx1201"
+    ];
+  };
+}
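
This attrset records, per Torch release, the CUDA compute capabilities the upstream wheels are built for (keyed by CUDA version) and the ROCm GPU architectures they support. A rough sketch of how a source build might turn these lists into the usual PyTorch build variables; the file path is hypothetical, since the diff header above does not show the file name:

let
  # Hypothetical file name; the actual path is not visible in this diff.
  versionInfo = import ./torch-version-info.nix;
  torch29 = versionInfo."2.9";
in
{
  # Capabilities targeted by Torch 2.9 on CUDA 12.8, joined the way
  # TORCH_CUDA_ARCH_LIST expects: "7.0;7.5;8.0;8.6;9.0;10.0;12.0"
  TORCH_CUDA_ARCH_LIST = builtins.concatStringsSep ";" torch29.capsPerCudaVersion."12.8";

  # ROCm architectures are listed once per Torch release.
  PYTORCH_ROCM_ARCH = builtins.concatStringsSep ";" torch29.supportedTorchRocmArchs;
}
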
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+{
+  config,
+  lib,
+  stdenv,
+
+  cudaSupport ? config.cudaSupport,
+  rocmSupport ? config.rocmSupport,
+  xpuSupport ? (config.xpuSupport or false),
+
+  callPackage,
+  cudaPackages,
+  rocmPackages,
+}:
+
+{
+  xpuPackages,
+  version,
+}:
+
+let
+  system = stdenv.hostPlatform.system;
+  flattenVersion = version: lib.replaceStrings [ "." ] [ "" ] (lib.versions.pad 2 version);
+  framework =
+    if cudaSupport then
+      "cu${flattenVersion cudaPackages.cudaMajorMinorVersion}"
+    else if rocmSupport then
+      "rocm${flattenVersion (lib.versions.majorMinor rocmPackages.rocm.version)}"
+    else if xpuSupport then
+      "xpu"
+    else
+      "cpu";
+  torchVersions = builtins.fromJSON (builtins.readFile ./torch-versions-hash.json);
+  torchBySystem = torchVersions.${version} or (throw "Unsupported torch version: ${version}");
+  torchByFramework =
+    torchBySystem.${system} or (throw "Unsupported system: ${system} for torch version: ${version}");
+  urlHash =
+    torchByFramework.${framework}
+    or (throw "Unsupported framework: ${framework} for torch version: ${version} on system: ${system}");
+in
+callPackage ./generic.nix {
+  inherit xpuPackages;
+  inherit (urlHash) url hash version;
+}
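
The builder derives a framework string from the enabled backend: CUDA 12.8 flattens to "cu128", ROCm 6.4 would flatten to "rocm64" (major.minor, dots stripped), XPU maps to "xpu", and everything else falls back to "cpu". The lookup torchVersions.${version}.${system}.${framework} therefore expects torch-versions-hash.json to be keyed by Torch version, then Nix system, then framework string, with each leaf providing url, hash, and a version that is inherited into generic.nix (presumably the full patch-level release). A sketch of that shape, written as the Nix value builtins.fromJSON would produce; the wheel URL and hash are placeholders, not real artifacts:

{
  "2.9" = {
    "x86_64-linux" = {
      # cudaPackages.cudaMajorMinorVersion = "12.8" flattens to framework "cu128".
      "cu128" = {
        # Leaf version is inherited into generic.nix; assumed to be the full release.
        version = "2.9.0";
        # Placeholder URL and hash; real entries would point at actual wheels.
        url = "https://download.pytorch.org/whl/cu128/torch-2.9.0-placeholder.whl";
        hash = "sha256-0000000000000000000000000000000000000000000=";
      };
      # "cpu", "rocm*", and "xpu" entries follow the same { version, url, hash } shape.
    };
  };
}
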
