Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions cmd/amd-ctk/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package runtime

import (
"fmt"
"os"

"github.com/ROCm/container-toolkit/cmd/amd-ctk/runtime/configure"
"github.com/urfave/cli/v2"
)
Expand All @@ -31,7 +34,24 @@ func AddNewCommand() *cli.Command {

runtimeCmd.Subcommands = []*cli.Command{
configure.AddNewCommand(),
addConfigureHookCommand(),
}

return &runtimeCmd
}

// addConfigureHookCommand builds the "configure-hook" subcommand.
//
// The action does not modify any runtime configuration. It checks that the
// hook binary exists at its fixed install path and, if so, prints where it is
// and a reminder to add it to the runtime configuration.
//
// The action returns an error when the binary is missing, and also when its
// presence cannot be determined (any other failure from os.Stat, such as a
// permission error).
//
// NOTE(review): the Usage text says "Install", but nothing is installed by
// this action — confirm the intended wording.
func addConfigureHookCommand() *cli.Command {
	return &cli.Command{
		Name:  "configure-hook",
		Usage: "Install amd-container-runtime-hook as OCI hook",
		Action: func(c *cli.Context) error {
			hookPath := "/usr/bin/amd-container-runtime-hook"
			if _, err := os.Stat(hookPath); err != nil {
				if os.IsNotExist(err) {
					return fmt.Errorf("hook binary not found at %s", hookPath)
				}
				// Any other stat failure means the binary cannot be
				// vouched for, so it must not be reported as available.
				return fmt.Errorf("checking hook binary at %s: %w", hookPath, err)
			}
			fmt.Printf("AMD Container Runtime Hook is available at: %s\n", hookPath)
			fmt.Println("Add this hook to your runtime configuration to enable --gpus flag support")
			return nil
		},
	}
}
116 changes: 116 additions & 0 deletions cmd/container-runtime-hook/hook.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/**
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package main

import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"strings"

gpuTracker "github.com/ROCm/container-toolkit/internal/gpu-tracker"
"github.com/ROCm/container-toolkit/internal/logger"
"github.com/ROCm/container-toolkit/internal/oci"
)

func doPrestart() error {
logger.Log.Println("Running prestart hook")

// Read hook state from stdin (Docker/containerd provides this)
hookState, err := ioutil.ReadAll(os.Stdin)
if err != nil {
return fmt.Errorf("failed to read hook state from stdin: %v", err)
}

// Create OCI interface for hook context
ociInterface, err := oci.NewFromStdin()
if err != nil {
return fmt.Errorf("failed to create OCI interface: %v", err)
}

// Load spec from bundle path in hook state
ociImpl, ok := ociInterface.(*oci.oci_t)

Check failure on line 47 in cmd/container-runtime-hook/hook.go

View workflow job for this annotation

GitHub Actions / test

oci_t not exported by package oci
if !ok {
return fmt.Errorf("failed to cast OCI interface to oci_t")
}

if err := ociImpl.LoadSpecFromHookState(hookState); err != nil {
return fmt.Errorf("failed to load spec from hook state: %v", err)
}

// Check if GPU devices are requested
spec := ociInterface.GetSpec()
if spec == nil || spec.Process == nil {
logger.Log.Println("No process spec found, skipping GPU configuration")
return nil
}

hasGPURequest := false
for _, env := range spec.Process.Env {
if strings.HasPrefix(env, "AMD_VISIBLE_DEVICES=") ||
strings.HasPrefix(env, "DOCKER_RESOURCE_") {
hasGPURequest = true
break
}
}

if !hasGPURequest {
logger.Log.Println("No GPU devices requested, skipping configuration")
return nil
}

// Add GPU devices to spec
if err := ociInterface.UpdateSpec(oci.AddGPUDevices); err != nil {
return fmt.Errorf("failed to add GPU devices: %v", err)
}

// Write updated spec back
if err := ociInterface.WriteSpec(); err != nil {
return fmt.Errorf("failed to write updated spec: %v", err)
}

logger.Log.Println("Successfully configured GPU devices")
return nil
}

// doPoststop implements the "poststop" hook command.
//
// It reads the hook state JSON from stdin, extracts the "id" field, creates a
// GPU tracker and calls ReleaseGPUs with that ID. An error is returned when
// stdin cannot be read, the JSON cannot be parsed, or the tracker cannot be
// created.
func doPoststop() error {
	logger.Log.Println("Running poststop hook")

	// Only the container ID is needed from the hook state.
	type hookStateID struct {
		ID string `json:"id"`
	}

	raw, err := ioutil.ReadAll(os.Stdin)
	if err != nil {
		return fmt.Errorf("failed to read hook state from stdin: %v", err)
	}

	var parsed hookStateID
	err = json.Unmarshal(raw, &parsed)
	if err != nil {
		return fmt.Errorf("failed to parse hook state: %v", err)
	}

	// Release GPUs via tracker
	gt, err := gpuTracker.New()
	if err != nil {
		return fmt.Errorf("failed to create GPU tracker: %v", err)
	}

	containerID := parsed.ID
	gt.ReleaseGPUs(containerID)
	logger.Log.Printf("Released GPUs for container %s", containerID)
	return nil
}
68 changes: 68 additions & 0 deletions cmd/container-runtime-hook/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/**
# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/

package main

import (
"flag"
"fmt"
"os"

"github.com/ROCm/container-toolkit/internal/logger"
)

var (
// versionFlag is bound to the -version command-line flag (default false).
// When it is set, main prints the version string and returns without
// running any hook command.
versionFlag = flag.Bool("version", false, "Display version information")
)

// main is the entry point of the amd-container-runtime-hook binary.
//
// After parsing flags and initialising the logger it handles, in order:
//   - the -version flag: prints the version and returns;
//   - no positional argument: prints usage to stderr and exits with status 2;
//   - "prestart" / "poststop": runs doPrestart / doPoststop, logging the
//     error and exiting with status 1 on failure;
//   - "poststart": exits with status 0 without doing anything;
//   - anything else: reports the unknown command on stderr, exit status 2.
func main() {
	flag.Parse()
	logger.Init(false)

	if *versionFlag {
		fmt.Println("AMD Container Runtime Hook version 1.0.0")
		return
	}

	positional := flag.Args()
	if len(positional) == 0 {
		usage := []string{
			"Usage: amd-container-runtime-hook <command>\n",
			"Commands:\n",
			" prestart - Configure GPU devices before container start\n",
			" poststop - Release GPU resources after container stop\n",
		}
		for _, line := range usage {
			fmt.Fprint(os.Stderr, line)
		}
		os.Exit(2)
	}

	switch name := positional[0]; name {
	case "poststart":
		// No-op for compatibility
		os.Exit(0)
	case "prestart":
		err := doPrestart()
		if err == nil {
			return
		}
		logger.Log.Printf("prestart hook failed: %v", err)
		os.Exit(1)
	case "poststop":
		err := doPoststop()
		if err == nil {
			return
		}
		logger.Log.Printf("poststop hook failed: %v", err)
		os.Exit(1)
	default:
		fmt.Fprintf(os.Stderr, "Unknown command: %s\n", name)
		os.Exit(2)
	}
}
27 changes: 15 additions & 12 deletions cmd/container-runtime/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,28 @@ func main() {
rt, err := runtime.New(os.Args)
if err != nil {
logger.Log.Printf("Failed to create container runtime, err = %v", err)
gpuTracker, err := gpuTracker.New()
if err != nil {
logger.Log.Printf("Failed to create GPU tracker, err = %v", err)
os.Exit(1)
}
gpuTracker.ReleaseGPUs(os.Args[len(os.Args)-1])
releaseGPUsOnError(os.Args)
os.Exit(1)
}

logger.Log.Printf("Running ROCm container runtime")
err = rt.Run()
if err != nil {
logger.Log.Printf("Failed to run container runtime, err = %v", err)
gpuTracker, err := gpuTracker.New()
if err != nil {
logger.Log.Printf("Failed to create GPU tracker, err = %v", err)
os.Exit(1)
}
gpuTracker.ReleaseGPUs(os.Args[len(os.Args)-1])
releaseGPUsOnError(os.Args)
os.Exit(1)
}
}

// releaseGPUsOnError asks the GPU tracker to release the GPUs associated with
// the container this invocation was acting on. main calls it on its failure
// paths just before exiting.
//
// args is the process argument vector (callers pass os.Args). The container
// ID is taken to be its last element. Because args[0] is the program name, a
// vector with fewer than two elements carries no container ID and the
// function returns without doing anything.
//
// Failure to create the tracker is logged and otherwise ignored: this is
// best-effort cleanup on a path that is already exiting with an error.
func releaseGPUsOnError(args []string) {
	if len(args) < 2 {
		return
	}
	containerID := args[len(args)-1]

	// Use a local name that does not shadow the gpuTracker package.
	tracker, err := gpuTracker.New()
	if err != nil {
		logger.Log.Printf("Failed to create GPU tracker, err = %v", err)
		return
	}
	tracker.ReleaseGPUs(containerID)
}
58 changes: 58 additions & 0 deletions internal/oci/oci.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ type Interface interface {

// PrintSpec prints the current spec on the console
PrintSpec() error

// GetSpec returns the loaded OCI spec
GetSpec() *specs.Spec

// GetContainerId returns the container ID
GetContainerId() string
}

// GetGPUs is the type for functions that return the lists of all the GPU devices on the system
Expand Down Expand Up @@ -362,6 +368,48 @@ func New(argv []string) (Interface, error) {
return oci, nil
}

// NewFromStdin returns an Interface for hook usage, wired with the default
// hook path, the amdgpu device lookup functions and a GPU tracker's
// ReserveGPUs function. It returns an error only when the GPU tracker cannot
// be created.
//
// NOTE(review): despite the name, nothing is read from stdin here, and no
// spec or container ID is set. The hook state is applied separately through
// LoadSpecFromHookState.
func NewFromStdin() (Interface, error) {
	tracker, err := gpuTracker.New()
	if err != nil {
		return nil, err
	}

	impl := new(oci_t)
	impl.hookPath = DEFAULT_HOOK_PATH
	impl.getGPUs = amdgpu.GetAMDGPUs
	impl.getGPU = amdgpu.GetAMDGPU
	impl.getUniqueIdToDeviceIndexMap = amdgpu.GetUniqueIdToDeviceIndexMap
	impl.reserveGPUs = tracker.ReserveGPUs

	return impl, nil
}

// LoadSpecFromHookState parses the hook state JSON in hookState, records the
// container ID and bundle path on the receiver, and then loads the spec by
// calling getSpec.
//
// The bundle path comes from the "bundle" field, falling back to
// "bundlePath" when "bundle" is empty. Both origSpecPath and updatedSpecPath
// are set to that value.
//
// It returns an error if the JSON cannot be decoded, or whatever error
// getSpec returns.
//
// NOTE(review): an empty bundle path is passed through unchecked, and the
// path is used as-is for the spec location — confirm getSpec expects a bundle
// directory rather than a config file path.
func (oci *oci_t) LoadSpecFromHookState(hookState []byte) error {
	type hookStateFields struct {
		Pid        int    `json:"pid,omitempty"`
		Bundle     string `json:"bundle"`
		BundlePath string `json:"bundlePath"`
		ID         string `json:"id"`
	}

	var parsed hookStateFields
	if err := json.Unmarshal(hookState, &parsed); err != nil {
		return fmt.Errorf("failed to decode hook state: %v", err)
	}

	// Prefer "bundle"; use "bundlePath" only when "bundle" is absent or empty.
	specLocation := parsed.BundlePath
	if parsed.Bundle != "" {
		specLocation = parsed.Bundle
	}

	oci.containerId = parsed.ID
	oci.origSpecPath, oci.updatedSpecPath = specLocation, specLocation

	return oci.getSpec()
}

// HasHelpOption returns true if the arguments passed include the help option
func (oci *oci_t) HasHelpOption() bool {
return oci.hasHelpOption
Expand Down Expand Up @@ -419,3 +467,13 @@ func (oci *oci_t) PrintSpec() error {

return nil
}

// GetSpec exposes the spec pointer currently stored on the receiver. The
// pointer itself is returned, not a copy of the spec.
func (oci *oci_t) GetSpec() *specs.Spec {
	current := oci.spec
	return current
}

// GetContainerId reports the container ID currently stored on the receiver.
func (oci *oci_t) GetContainerId() string {
	id := oci.containerId
	return id
}
Loading