diff --git a/cmd/amd-ctk/runtime/runtime.go b/cmd/amd-ctk/runtime/runtime.go index 6ee0294..65e2d1e 100644 --- a/cmd/amd-ctk/runtime/runtime.go +++ b/cmd/amd-ctk/runtime/runtime.go @@ -17,6 +17,9 @@ package runtime import ( + "fmt" + "os" + "github.com/ROCm/container-toolkit/cmd/amd-ctk/runtime/configure" "github.com/urfave/cli/v2" ) @@ -31,7 +34,24 @@ func AddNewCommand() *cli.Command { runtimeCmd.Subcommands = []*cli.Command{ configure.AddNewCommand(), + addConfigureHookCommand(), } return &runtimeCmd } + +func addConfigureHookCommand() *cli.Command { + return &cli.Command{ + Name: "configure-hook", + Usage: "Install amd-container-runtime-hook as OCI hook", + Action: func(c *cli.Context) error { + hookPath := "/usr/bin/amd-container-runtime-hook" + if _, err := os.Stat(hookPath); os.IsNotExist(err) { + return fmt.Errorf("hook binary not found at %s", hookPath) + } + fmt.Printf("AMD Container Runtime Hook is available at: %s\n", hookPath) + fmt.Println("Add this hook to your runtime configuration to enable --gpus flag support") + return nil + }, + } +} diff --git a/cmd/container-runtime-hook/hook.go b/cmd/container-runtime-hook/hook.go new file mode 100644 index 0000000..d50e24a --- /dev/null +++ b/cmd/container-runtime-hook/hook.go @@ -0,0 +1,116 @@ +/** +# Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package main + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "os" + "strings" + + gpuTracker "github.com/ROCm/container-toolkit/internal/gpu-tracker" + "github.com/ROCm/container-toolkit/internal/logger" + "github.com/ROCm/container-toolkit/internal/oci" +) + +func doPrestart() error { + logger.Log.Println("Running prestart hook") + + // Read hook state from stdin (Docker/containerd provides this) + hookState, err := ioutil.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("failed to read hook state from stdin: %v", err) + } + + // Create OCI interface for hook context + ociInterface, err := oci.NewFromStdin() + if err != nil { + return fmt.Errorf("failed to create OCI interface: %v", err) + } + + // Load spec from bundle path in hook state + ociImpl, ok := ociInterface.(*oci.oci_t) + if !ok { + return fmt.Errorf("failed to cast OCI interface to oci_t") + } + + if err := ociImpl.LoadSpecFromHookState(hookState); err != nil { + return fmt.Errorf("failed to load spec from hook state: %v", err) + } + + // Check if GPU devices are requested + spec := ociInterface.GetSpec() + if spec == nil || spec.Process == nil { + logger.Log.Println("No process spec found, skipping GPU configuration") + return nil + } + + hasGPURequest := false + for _, env := range spec.Process.Env { + if strings.HasPrefix(env, "AMD_VISIBLE_DEVICES=") || + strings.HasPrefix(env, "DOCKER_RESOURCE_") { + hasGPURequest = true + break + } + } + + if !hasGPURequest { + logger.Log.Println("No GPU devices requested, skipping configuration") + return nil + } + + // Add GPU devices to spec + if err := ociInterface.UpdateSpec(oci.AddGPUDevices); err != nil { + return fmt.Errorf("failed to add GPU devices: %v", err) + } + + // Write updated spec back + if err := ociInterface.WriteSpec(); err != nil { + return fmt.Errorf("failed to write updated spec: %v", err) + } + + logger.Log.Println("Successfully configured GPU devices") + return nil +} + +func doPoststop() error { + logger.Log.Println("Running poststop hook") + + // Read hook state to get container ID + hookState, err := ioutil.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("failed to read hook state from stdin: %v", err) + } + + var state struct { + ID string `json:"id"` + } + if err := json.Unmarshal(hookState, &state); err != nil { + return fmt.Errorf("failed to parse hook state: %v", err) + } + + // Release GPUs via tracker + tracker, err := gpuTracker.New() + if err != nil { + return fmt.Errorf("failed to create GPU tracker: %v", err) + } + + tracker.ReleaseGPUs(state.ID) + logger.Log.Printf("Released GPUs for container %s", state.ID) + return nil +} diff --git a/cmd/container-runtime-hook/main.go b/cmd/container-runtime-hook/main.go new file mode 100644 index 0000000..10c60dc --- /dev/null +++ b/cmd/container-runtime-hook/main.go @@ -0,0 +1,68 @@ +/** +# Copyright (c) Advanced Micro Devices, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +**/ + +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/ROCm/container-toolkit/internal/logger" +) + +var ( + versionFlag = flag.Bool("version", false, "Display version information") +) + +func main() { + flag.Parse() + logger.Init(false) + + if *versionFlag { + fmt.Println("AMD Container Runtime Hook version 1.0.0") + return + } + + args := flag.Args() + if len(args) == 0 { + fmt.Fprintf(os.Stderr, "Usage: amd-container-runtime-hook \n") + fmt.Fprintf(os.Stderr, "Commands:\n") + fmt.Fprintf(os.Stderr, " prestart - Configure GPU devices before container start\n") + fmt.Fprintf(os.Stderr, " poststop - Release GPU resources after container stop\n") + os.Exit(2) + } + + command := args[0] + switch command { + case "prestart": + if err := doPrestart(); err != nil { + logger.Log.Printf("prestart hook failed: %v", err) + os.Exit(1) + } + case "poststop": + if err := doPoststop(); err != nil { + logger.Log.Printf("poststop hook failed: %v", err) + os.Exit(1) + } + case "poststart": + // No-op for compatibility + os.Exit(0) + default: + fmt.Fprintf(os.Stderr, "Unknown command: %s\n", command) + os.Exit(2) + } +} diff --git a/cmd/container-runtime/main.go b/cmd/container-runtime/main.go index fdcab9b..0a41415 100644 --- a/cmd/container-runtime/main.go +++ b/cmd/container-runtime/main.go @@ -31,12 +31,7 @@ func main() { rt, err := runtime.New(os.Args) if err != nil { logger.Log.Printf("Failed to create container runtime, err = %v", err) - gpuTracker, err := gpuTracker.New() - if err != nil { - logger.Log.Printf("Failed to create GPU tracker, err = %v", err) - os.Exit(1) - } - gpuTracker.ReleaseGPUs(os.Args[len(os.Args)-1]) + releaseGPUsOnError(os.Args) os.Exit(1) } @@ -44,12 +39,20 @@ func main() { err = rt.Run() if err != nil { logger.Log.Printf("Failed to run container runtime, err = %v", err) - gpuTracker, err := gpuTracker.New() - if err != nil { - logger.Log.Printf("Failed to create GPU tracker, err = %v", err) - os.Exit(1) - } - gpuTracker.ReleaseGPUs(os.Args[len(os.Args)-1]) + releaseGPUsOnError(os.Args) os.Exit(1) } } + +func releaseGPUsOnError(args []string) { + if len(args) == 0 { + return + } + containerId := args[len(args)-1] + gpuTracker, err := gpuTracker.New() + if err != nil { + logger.Log.Printf("Failed to create GPU tracker, err = %v", err) + return + } + gpuTracker.ReleaseGPUs(containerId) +} diff --git a/internal/oci/oci.go b/internal/oci/oci.go index 2ce17d4..19ae00b 100644 --- a/internal/oci/oci.go +++ b/internal/oci/oci.go @@ -50,6 +50,12 @@ type Interface interface { // PrintSpec prints the current spec on the console PrintSpec() error + + // GetSpec returns the loaded OCI spec + GetSpec() *specs.Spec + + // GetContainerId returns the container ID + GetContainerId() string } // GetGPUs is the type for functions that return the lists of all the GPU devices on the system @@ -362,6 +368,48 @@ func New(argv []string) (Interface, error) { return oci, nil } +// NewFromStdin creates OCI interface from hook state read from stdin (for hook usage) +func NewFromStdin() (Interface, error) { + gpuTracker, err := gpuTracker.New() + if err != nil { + return nil, err + } + + oci := &oci_t{ + hookPath: DEFAULT_HOOK_PATH, + getGPUs: amdgpu.GetAMDGPUs, + getGPU: amdgpu.GetAMDGPU, + getUniqueIdToDeviceIndexMap: amdgpu.GetUniqueIdToDeviceIndexMap, + reserveGPUs: gpuTracker.ReserveGPUs, + } + + return oci, nil +} + +// LoadSpecFromHookState reads OCI spec from hook state provided on stdin +func (oci *oci_t) LoadSpecFromHookState(hookState []byte) error { + var state struct { + Pid int `json:"pid,omitempty"` + Bundle string `json:"bundle"` + BundlePath string `json:"bundlePath"` + ID string `json:"id"` + } + + if err := json.Unmarshal(hookState, &state); err != nil { + return fmt.Errorf("failed to decode hook state: %v", err) + } + + oci.containerId = state.ID + bundlePath := state.Bundle + if bundlePath == "" { + bundlePath = state.BundlePath + } + oci.origSpecPath = bundlePath + oci.updatedSpecPath = bundlePath + + return oci.getSpec() +} + // HasHelpOption returns true if the arguments passed include the help option func (oci *oci_t) HasHelpOption() bool { return oci.hasHelpOption @@ -419,3 +467,13 @@ func (oci *oci_t) PrintSpec() error { return nil } + +// GetSpec returns the loaded OCI spec +func (oci *oci_t) GetSpec() *specs.Spec { + return oci.spec +} + +// GetContainerId returns the container ID +func (oci *oci_t) GetContainerId() string { + return oci.containerId +}