Skip to content

Commit 59a70e8

Browse files
Parity Agentnikhilsk
authored andcommitted
feat(docker): no support for --gpus flag in
1 parent 6abfecb commit 59a70e8

File tree

4 files changed

+381
-12
lines changed

4 files changed

+381
-12
lines changed
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/**
2+
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package main
18+
19+
import (
20+
"github.com/ROCm/container-toolkit/internal/amdgpu"
21+
gpuTracker "github.com/ROCm/container-toolkit/internal/gpu-tracker"
22+
"github.com/opencontainers/runtime-spec/specs-go"
23+
)
24+
25+
// ociHookHandler wraps OCI functionality for hook-specific operations
26+
type ociHookHandler struct {
27+
spec *specs.Spec
28+
bundlePath string
29+
containerId string
30+
31+
reserveGPUs func(string, string) ([]int, error)
32+
}
33+
34+
func newOCIHandlerFromSpec(spec *specs.Spec, bundlePath, containerId string) (*ociHookHandler, error) {
35+
gpuTracker, err := gpuTracker.New()
36+
if err != nil {
37+
return nil, err
38+
}
39+
40+
return &ociHookHandler{
41+
spec: spec,
42+
bundlePath: bundlePath,
43+
containerId: containerId,
44+
reserveGPUs: gpuTracker.ReserveGPUs,
45+
}, nil
46+
}
47+
48+
func (h *ociHookHandler) InjectGPUs() error {
49+
// Extract AMD_VISIBLE_DEVICES from spec environment
50+
amdDevices, err := h.getRequestedDevices()
51+
if err != nil {
52+
return err
53+
}
54+
55+
if len(amdDevices) == 0 {
56+
return nil // No devices requested
57+
}
58+
59+
// Get all GPUs
60+
gpus, err := amdgpu.GetAMDGPUs()
61+
if err != nil {
62+
return err
63+
}
64+
65+
// Add each requested GPU
66+
for _, idx := range amdDevices {
67+
if idx >= len(gpus) {
68+
continue // Skip invalid indices
69+
}
70+
for _, drmDev := range gpus[idx].DrmDevices {
71+
gpu, err := amdgpu.GetAMDGPU(drmDev)
72+
if err != nil {
73+
return err
74+
}
75+
if err := h.addGPUDevice(gpu); err != nil {
76+
return err
77+
}
78+
}
79+
}
80+
81+
// Add /dev/kfd
82+
kfd, err := amdgpu.GetAMDGPU("/dev/kfd")
83+
if err != nil {
84+
return err
85+
}
86+
return h.addGPUDevice(kfd)
87+
}
88+
89+
func (h *ociHookHandler) getRequestedDevices() ([]int, error) {
90+
if h.spec.Process == nil {
91+
return []int{}, nil
92+
}
93+
94+
for _, env := range h.spec.Process.Env {
95+
if len(env) > 20 && env[:20] == "AMD_VISIBLE_DEVICES=" {
96+
gpuList := env[20:]
97+
return h.reserveGPUs(gpuList, h.containerId)
98+
}
99+
if len(env) > 16 && env[:16] == "DOCKER_RESOURCE_" {
100+
gpuList := env[16:]
101+
idx := 16
102+
for idx < len(env) && env[idx] != '=' {
103+
idx++
104+
}
105+
if idx < len(env) {
106+
gpuList = env[idx+1:]
107+
return h.reserveGPUs(gpuList, h.containerId)
108+
}
109+
}
110+
}
111+
return []int{}, nil
112+
}
113+
114+
func (h *ociHookHandler) addGPUDevice(gpu amdgpu.AMDGPU) error {
115+
// Add to devices list
116+
dev := specs.LinuxDevice{
117+
Path: gpu.Path,
118+
Type: gpu.DevType,
119+
Major: gpu.Major,
120+
Minor: gpu.Minor,
121+
FileMode: &gpu.FileMode,
122+
GID: &gpu.Gid,
123+
UID: &gpu.Uid,
124+
}
125+
126+
if h.spec.Linux == nil {
127+
h.spec.Linux = &specs.Linux{}
128+
}
129+
h.spec.Linux.Devices = append(h.spec.Linux.Devices, dev)
130+
131+
// Add to cgroup resources
132+
rdev := specs.LinuxDeviceCgroup{
133+
Allow: gpu.Allow,
134+
Type: gpu.DevType,
135+
Major: &gpu.Major,
136+
Minor: &gpu.Minor,
137+
Access: gpu.Access,
138+
}
139+
140+
if h.spec.Linux.Resources == nil {
141+
h.spec.Linux.Resources = &specs.LinuxResources{}
142+
}
143+
h.spec.Linux.Resources.Devices = append(h.spec.Linux.Resources.Devices, rdev)
144+
145+
return nil
146+
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
/**
2+
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package main
18+
19+
import (
20+
"encoding/json"
21+
"flag"
22+
"fmt"
23+
"os"
24+
"path/filepath"
25+
26+
gpuTracker "github.com/ROCm/container-toolkit/internal/gpu-tracker"
27+
"github.com/ROCm/container-toolkit/internal/logger"
28+
"github.com/opencontainers/runtime-spec/specs-go"
29+
)
30+
31+
var (
32+
debugFlag = flag.Bool("debug", false, "enable debug logging")
33+
versionFlag = flag.Bool("version", false, "print version and exit")
34+
)
35+
36+
// HookState represents the container state passed to the hook via stdin
37+
type HookState struct {
38+
Pid int `json:"pid,omitempty"`
39+
Bundle string `json:"bundle"`
40+
}
41+
42+
func main() {
43+
flag.Parse()
44+
45+
if *versionFlag {
46+
fmt.Printf("amd-container-runtime-hook version 1.0.0\n")
47+
return
48+
}
49+
50+
logger.Init(*debugFlag)
51+
logger.Log.Printf("AMD Container Runtime Hook started")
52+
53+
// OCI hooks receive container state on stdin
54+
var state HookState
55+
decoder := json.NewDecoder(os.Stdin)
56+
if err := decoder.Decode(&state); err != nil {
57+
logger.Log.Printf("Failed to decode container state: %v", err)
58+
os.Exit(1)
59+
}
60+
61+
logger.Log.Printf("Processing container in bundle: %s", state.Bundle)
62+
63+
// Load OCI spec
64+
specPath := filepath.Join(state.Bundle, "config.json")
65+
spec, err := loadSpec(specPath)
66+
if err != nil {
67+
logger.Log.Printf("Failed to load spec from %s: %v", specPath, err)
68+
os.Exit(1)
69+
}
70+
71+
// Check if GPU devices are requested
72+
if !needsGPUs(spec) {
73+
logger.Log.Printf("No AMD GPU devices requested, exiting")
74+
return
75+
}
76+
77+
// Create OCI handler with spec already loaded
78+
containerId := filepath.Base(state.Bundle)
79+
ociHandler, err := newOCIHandlerFromSpec(spec, state.Bundle, containerId)
80+
if err != nil {
81+
logger.Log.Printf("Failed to create OCI handler: %v", err)
82+
releaseGPUs(containerId)
83+
os.Exit(1)
84+
}
85+
86+
// Add GPU devices (without cleanup hook - Docker manages hook lifecycle)
87+
if err := ociHandler.InjectGPUs(); err != nil {
88+
logger.Log.Printf("Failed to inject GPUs: %v", err)
89+
releaseGPUs(containerId)
90+
os.Exit(1)
91+
}
92+
93+
// Write modified spec back
94+
if err := writeSpec(specPath, spec); err != nil {
95+
logger.Log.Printf("Failed to write spec to %s: %v", specPath, err)
96+
releaseGPUs(containerId)
97+
os.Exit(1)
98+
}
99+
100+
logger.Log.Printf("Successfully injected AMD GPU devices")
101+
}
102+
103+
func loadSpec(path string) (*specs.Spec, error) {
104+
file, err := os.Open(path)
105+
if err != nil {
106+
return nil, err
107+
}
108+
defer file.Close()
109+
110+
var spec specs.Spec
111+
if err := json.NewDecoder(file).Decode(&spec); err != nil {
112+
return nil, err
113+
}
114+
return &spec, nil
115+
}
116+
117+
func writeSpec(path string, spec *specs.Spec) error {
118+
file, err := os.Create(path)
119+
if err != nil {
120+
return err
121+
}
122+
defer file.Close()
123+
124+
encoder := json.NewEncoder(file)
125+
encoder.SetIndent("", " ")
126+
return encoder.Encode(spec)
127+
}
128+
129+
func needsGPUs(spec *specs.Spec) bool {
130+
if spec.Process == nil {
131+
return false
132+
}
133+
for _, env := range spec.Process.Env {
134+
if len(env) >= 20 && env[:20] == "AMD_VISIBLE_DEVICES=" {
135+
return true
136+
}
137+
if len(env) >= 16 && env[:16] == "DOCKER_RESOURCE_" {
138+
return true
139+
}
140+
}
141+
return false
142+
}
143+
144+
func releaseGPUs(containerId string) {
145+
tracker, err := gpuTracker.New()
146+
if err != nil {
147+
logger.Log.Printf("Failed to create GPU tracker for cleanup: %v", err)
148+
return
149+
}
150+
tracker.ReleaseGPUs(containerId)
151+
}

cmd/amd-ctk/runtime/runtime.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
package runtime
1818

1919
import (
20+
"fmt"
21+
"os"
22+
2023
"github.com/ROCm/container-toolkit/cmd/amd-ctk/runtime/configure"
2124
"github.com/urfave/cli/v2"
2225
)
@@ -31,7 +34,69 @@ func AddNewCommand() *cli.Command {
3134

3235
runtimeCmd.Subcommands = []*cli.Command{
3336
configure.AddNewCommand(),
37+
configureHookCommand(),
3438
}
3539

3640
return &runtimeCmd
3741
}
42+
43+
func configureHookCommand() *cli.Command {
44+
return &cli.Command{
45+
Name: "configure-hooks",
46+
Usage: "Configure OCI prestart hooks for --gpus flag support",
47+
Flags: []cli.Flag{
48+
&cli.StringFlag{
49+
Name: "hook-path",
50+
Usage: "Path to amd-container-runtime-hook binary",
51+
Value: "/usr/bin/amd-container-runtime-hook",
52+
},
53+
&cli.StringFlag{
54+
Name: "config-path",
55+
Usage: "Path to Docker daemon.json",
56+
Value: "/etc/docker/daemon.json",
57+
},
58+
&cli.BoolFlag{
59+
Name: "remove",
60+
Usage: "Remove hook configuration",
61+
},
62+
},
63+
Action: func(c *cli.Context) error {
64+
hookPath := c.String("hook-path")
65+
configPath := c.String("config-path")
66+
remove := c.Bool("remove")
67+
68+
if remove {
69+
fmt.Println("Removing hook configuration not yet implemented")
70+
return nil
71+
}
72+
73+
// Verify hook binary exists
74+
if _, err := os.Stat(hookPath); os.IsNotExist(err) {
75+
return fmt.Errorf("hook binary not found at %s", hookPath)
76+
}
77+
78+
fmt.Printf("Hook configuration:\n")
79+
fmt.Printf(" Hook binary: %s\n", hookPath)
80+
fmt.Printf(" Config file: %s\n", configPath)
81+
fmt.Printf("\nTo enable --gpus flag support, add this to %s:\n\n", configPath)
82+
fmt.Printf(`{
83+
"runtimes": {
84+
"amd": {
85+
"path": "amd-container-runtime",
86+
"runtimeArgs": []
87+
}
88+
},
89+
"hooks": {
90+
"prestart": [
91+
{
92+
"path": "%s"
93+
}
94+
]
95+
}
96+
}
97+
`, hookPath)
98+
fmt.Printf("\nThen restart Docker: sudo systemctl restart docker\n")
99+
return nil
100+
},
101+
}
102+
}

0 commit comments

Comments
 (0)