Skip to content

Commit 7933944

Browse files
Parity Agentnikhilsk
authored andcommitted
feat(docker): no support for --gpus flag in
1 parent 6abfecb commit 7933944

File tree

5 files changed

+277
-12
lines changed

5 files changed

+277
-12
lines changed

cmd/amd-ctk/runtime/runtime.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
package runtime
1818

1919
import (
20+
"fmt"
21+
"os"
22+
2023
"github.com/ROCm/container-toolkit/cmd/amd-ctk/runtime/configure"
2124
"github.com/urfave/cli/v2"
2225
)
@@ -31,7 +34,24 @@ func AddNewCommand() *cli.Command {
3134

3235
runtimeCmd.Subcommands = []*cli.Command{
3336
configure.AddNewCommand(),
37+
addConfigureHookCommand(),
3438
}
3539

3640
return &runtimeCmd
3741
}
42+
43+
func addConfigureHookCommand() *cli.Command {
44+
return &cli.Command{
45+
Name: "configure-hook",
46+
Usage: "Install amd-container-runtime-hook as OCI hook",
47+
Action: func(c *cli.Context) error {
48+
hookPath := "/usr/bin/amd-container-runtime-hook"
49+
if _, err := os.Stat(hookPath); os.IsNotExist(err) {
50+
return fmt.Errorf("hook binary not found at %s", hookPath)
51+
}
52+
fmt.Printf("AMD Container Runtime Hook is available at: %s\n", hookPath)
53+
fmt.Println("Add this hook to your runtime configuration to enable --gpus flag support")
54+
return nil
55+
},
56+
}
57+
}

cmd/container-runtime-hook/hook.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/**
2+
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package main
18+
19+
import (
20+
"encoding/json"
21+
"fmt"
22+
"io/ioutil"
23+
"os"
24+
"strings"
25+
26+
gpuTracker "github.com/ROCm/container-toolkit/internal/gpu-tracker"
27+
"github.com/ROCm/container-toolkit/internal/logger"
28+
"github.com/ROCm/container-toolkit/internal/oci"
29+
)
30+
31+
func doPrestart() error {
32+
logger.Log.Println("Running prestart hook")
33+
34+
// Read hook state from stdin (Docker/containerd provides this)
35+
hookState, err := ioutil.ReadAll(os.Stdin)
36+
if err != nil {
37+
return fmt.Errorf("failed to read hook state from stdin: %v", err)
38+
}
39+
40+
// Create OCI interface for hook context
41+
ociInterface, err := oci.NewFromStdin()
42+
if err != nil {
43+
return fmt.Errorf("failed to create OCI interface: %v", err)
44+
}
45+
46+
// Load spec from bundle path in hook state
47+
ociImpl, ok := ociInterface.(*oci.oci_t)
48+
if !ok {
49+
return fmt.Errorf("failed to cast OCI interface to oci_t")
50+
}
51+
52+
if err := ociImpl.LoadSpecFromHookState(hookState); err != nil {
53+
return fmt.Errorf("failed to load spec from hook state: %v", err)
54+
}
55+
56+
// Check if GPU devices are requested
57+
spec := ociInterface.GetSpec()
58+
if spec == nil || spec.Process == nil {
59+
logger.Log.Println("No process spec found, skipping GPU configuration")
60+
return nil
61+
}
62+
63+
hasGPURequest := false
64+
for _, env := range spec.Process.Env {
65+
if strings.HasPrefix(env, "AMD_VISIBLE_DEVICES=") ||
66+
strings.HasPrefix(env, "DOCKER_RESOURCE_") {
67+
hasGPURequest = true
68+
break
69+
}
70+
}
71+
72+
if !hasGPURequest {
73+
logger.Log.Println("No GPU devices requested, skipping configuration")
74+
return nil
75+
}
76+
77+
// Add GPU devices to spec
78+
if err := ociInterface.UpdateSpec(oci.AddGPUDevices); err != nil {
79+
return fmt.Errorf("failed to add GPU devices: %v", err)
80+
}
81+
82+
// Write updated spec back
83+
if err := ociInterface.WriteSpec(); err != nil {
84+
return fmt.Errorf("failed to write updated spec: %v", err)
85+
}
86+
87+
logger.Log.Println("Successfully configured GPU devices")
88+
return nil
89+
}
90+
91+
func doPoststop() error {
92+
logger.Log.Println("Running poststop hook")
93+
94+
// Read hook state to get container ID
95+
hookState, err := ioutil.ReadAll(os.Stdin)
96+
if err != nil {
97+
return fmt.Errorf("failed to read hook state from stdin: %v", err)
98+
}
99+
100+
var state struct {
101+
ID string `json:"id"`
102+
}
103+
if err := json.Unmarshal(hookState, &state); err != nil {
104+
return fmt.Errorf("failed to parse hook state: %v", err)
105+
}
106+
107+
// Release GPUs via tracker
108+
tracker, err := gpuTracker.New()
109+
if err != nil {
110+
return fmt.Errorf("failed to create GPU tracker: %v", err)
111+
}
112+
113+
tracker.ReleaseGPUs(state.ID)
114+
logger.Log.Printf("Released GPUs for container %s", state.ID)
115+
return nil
116+
}

cmd/container-runtime-hook/main.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/**
2+
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package main
18+
19+
import (
20+
"flag"
21+
"fmt"
22+
"os"
23+
24+
"github.com/ROCm/container-toolkit/internal/logger"
25+
)
26+
27+
var (
28+
versionFlag = flag.Bool("version", false, "Display version information")
29+
)
30+
31+
func main() {
32+
flag.Parse()
33+
logger.Init(false)
34+
35+
if *versionFlag {
36+
fmt.Println("AMD Container Runtime Hook version 1.0.0")
37+
return
38+
}
39+
40+
args := flag.Args()
41+
if len(args) == 0 {
42+
fmt.Fprintf(os.Stderr, "Usage: amd-container-runtime-hook <command>\n")
43+
fmt.Fprintf(os.Stderr, "Commands:\n")
44+
fmt.Fprintf(os.Stderr, " prestart - Configure GPU devices before container start\n")
45+
fmt.Fprintf(os.Stderr, " poststop - Release GPU resources after container stop\n")
46+
os.Exit(2)
47+
}
48+
49+
command := args[0]
50+
switch command {
51+
case "prestart":
52+
if err := doPrestart(); err != nil {
53+
logger.Log.Printf("prestart hook failed: %v", err)
54+
os.Exit(1)
55+
}
56+
case "poststop":
57+
if err := doPoststop(); err != nil {
58+
logger.Log.Printf("poststop hook failed: %v", err)
59+
os.Exit(1)
60+
}
61+
case "poststart":
62+
// No-op for compatibility
63+
os.Exit(0)
64+
default:
65+
fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command)
66+
os.Exit(2)
67+
}
68+
}

cmd/container-runtime/main.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,28 @@ func main() {
3131
rt, err := runtime.New(os.Args)
3232
if err != nil {
3333
logger.Log.Printf("Failed to create container runtime, err = %v", err)
34-
gpuTracker, err := gpuTracker.New()
35-
if err != nil {
36-
logger.Log.Printf("Failed to create GPU tracker, err = %v", err)
37-
os.Exit(1)
38-
}
39-
gpuTracker.ReleaseGPUs(os.Args[len(os.Args)-1])
34+
releaseGPUsOnError(os.Args)
4035
os.Exit(1)
4136
}
4237

4338
logger.Log.Printf("Running ROCm container runtime")
4439
err = rt.Run()
4540
if err != nil {
4641
logger.Log.Printf("Failed to run container runtime, err = %v", err)
47-
gpuTracker, err := gpuTracker.New()
48-
if err != nil {
49-
logger.Log.Printf("Failed to create GPU tracker, err = %v", err)
50-
os.Exit(1)
51-
}
52-
gpuTracker.ReleaseGPUs(os.Args[len(os.Args)-1])
42+
releaseGPUsOnError(os.Args)
5343
os.Exit(1)
5444
}
5545
}
46+
47+
func releaseGPUsOnError(args []string) {
48+
if len(args) == 0 {
49+
return
50+
}
51+
containerId := args[len(args)-1]
52+
gpuTracker, err := gpuTracker.New()
53+
if err != nil {
54+
logger.Log.Printf("Failed to create GPU tracker, err = %v", err)
55+
return
56+
}
57+
gpuTracker.ReleaseGPUs(containerId)
58+
}

internal/oci/oci.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ type Interface interface {
5050

5151
// PrintSpec prints the current spec on the console
5252
PrintSpec() error
53+
54+
// GetSpec returns the loaded OCI spec
55+
GetSpec() *specs.Spec
56+
57+
// GetContainerId returns the container ID
58+
GetContainerId() string
5359
}
5460

5561
// GetGPUs is the type for functions that return the lists of all the GPU devices on the system
@@ -362,6 +368,48 @@ func New(argv []string) (Interface, error) {
362368
return oci, nil
363369
}
364370

371+
// NewFromStdin creates OCI interface from hook state read from stdin (for hook usage)
372+
func NewFromStdin() (Interface, error) {
373+
gpuTracker, err := gpuTracker.New()
374+
if err != nil {
375+
return nil, err
376+
}
377+
378+
oci := &oci_t{
379+
hookPath: DEFAULT_HOOK_PATH,
380+
getGPUs: amdgpu.GetAMDGPUs,
381+
getGPU: amdgpu.GetAMDGPU,
382+
getUniqueIdToDeviceIndexMap: amdgpu.GetUniqueIdToDeviceIndexMap,
383+
reserveGPUs: gpuTracker.ReserveGPUs,
384+
}
385+
386+
return oci, nil
387+
}
388+
389+
// LoadSpecFromHookState reads OCI spec from hook state provided on stdin
390+
func (oci *oci_t) LoadSpecFromHookState(hookState []byte) error {
391+
var state struct {
392+
Pid int `json:"pid,omitempty"`
393+
Bundle string `json:"bundle"`
394+
BundlePath string `json:"bundlePath"`
395+
ID string `json:"id"`
396+
}
397+
398+
if err := json.Unmarshal(hookState, &state); err != nil {
399+
return fmt.Errorf("failed to decode hook state: %v", err)
400+
}
401+
402+
oci.containerId = state.ID
403+
bundlePath := state.Bundle
404+
if bundlePath == "" {
405+
bundlePath = state.BundlePath
406+
}
407+
oci.origSpecPath = bundlePath
408+
oci.updatedSpecPath = bundlePath
409+
410+
return oci.getSpec()
411+
}
412+
365413
// HasHelpOption returns true if the arguments passed include the help option
366414
func (oci *oci_t) HasHelpOption() bool {
367415
return oci.hasHelpOption
@@ -419,3 +467,13 @@ func (oci *oci_t) PrintSpec() error {
419467

420468
return nil
421469
}
470+
471+
// GetSpec returns the loaded OCI spec
472+
func (oci *oci_t) GetSpec() *specs.Spec {
473+
return oci.spec
474+
}
475+
476+
// GetContainerId returns the container ID
477+
func (oci *oci_t) GetContainerId() string {
478+
return oci.containerId
479+
}

0 commit comments

Comments
 (0)