Skip to content

Commit

Permalink
KEP-2170: Add validation to Torch numProcPerNode field
Browse files Browse the repository at this point in the history
Signed-off-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
astefanutti committed Jan 31, 2025
1 parent 7ed7368 commit e7ec8fb
Show file tree
Hide file tree
Showing 16 changed files with 58 additions and 32 deletions.
4 changes: 2 additions & 2 deletions api.v2/openapi-spec/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@
},
"numProcPerNode": {
"description": "Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`.",
"type": "string"
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
}
}
},
Expand Down Expand Up @@ -716,7 +716,7 @@
},
"numProcPerNode": {
"description": "Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set.",
"type": "string"
"$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString"
},
"resourcesPerNode": {
"description": "Compute resources for each training node.",
Expand Down
6 changes: 3 additions & 3 deletions hack/swagger/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/x448/float16 v0.8.4 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.35.1 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
Expand Down
3 changes: 3 additions & 0 deletions hack/swagger/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
Expand All @@ -89,10 +90,12 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
Expand Down
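7 changes: 6 additions & 1 deletion manifests/v2/base/crds/kubeflow.org_clustertrainingruntimes.yaml
Original file line number Diff line number Diff line change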
Expand Up @@ -583,12 +583,17 @@ spec:
type: integer
type: object
numProcPerNode:
anyOf:
- type: integer
- type: string
description: |-
Number of processes per node.
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
Supported values: `auto`, `cpu`, `gpu`, or int value.
Defaults to `auto`.
type: string
x-kubernetes-int-or-string: true
x-kubernetes-validations:
- rule: self > 0 || self in ['auto', 'cpu', 'gpu']
type: object
type: object
podGroupPolicy:
Expand Down
7 changes: 6 additions & 1 deletion manifests/v2/base/crds/kubeflow.org_trainingruntimes.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -583,12 +583,17 @@ spec:
type: integer
type: object
numProcPerNode:
anyOf:
- type: integer
- type: string
description: |-
Number of processes per node.
This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
Supported values: `auto`, `cpu`, `gpu`, or int value.
Defaults to `auto`.
type: string
x-kubernetes-int-or-string: true
x-kubernetes-validations:
- rule: self > 0 || self in ['auto', 'cpu', 'gpu']
type: object
type: object
podGroupPolicy:
Expand Down
5 changes: 4 additions & 1 deletion manifests/v2/base/crds/kubeflow.org_trainjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3138,11 +3138,14 @@ spec:
format: int32
type: integer
numProcPerNode:
anyOf:
- type: integer
- type: string
description: |-
Number of processes/workers/slots on every training node.
For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
For the MPI runtime only int value can be set.
type: string
x-kubernetes-int-or-string: true
resourcesPerNode:
description: Compute resources for each training node.
properties:
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package v2alpha1
import (
autoscalingv2 "k8s.io/api/autoscaling/v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
)

Expand Down Expand Up @@ -171,9 +172,9 @@ type TorchMLPolicySource struct {
// Number of processes per node.
// This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI.
// Supported values: `auto`, `cpu`, `gpu`, or a positive int value.
// TODO (andreyvelich): Add kubebuilder validation.
// Defaults to `auto`.
NumProcPerNode *string `json:"numProcPerNode,omitempty"`
// +kubebuilder:validation:XValidation:rule="self > 0 || self in ['auto', 'cpu', 'gpu']"
NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`

// Elastic policy for the PyTorch training.
ElasticPolicy *TorchElasticPolicy `json:"elasticPolicy,omitempty"`
Expand Down
3 changes: 2 additions & 1 deletion pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package v2alpha1
import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/intstr"
)

const (
Expand Down Expand Up @@ -194,7 +195,7 @@ type Trainer struct {
// Number of processes/workers/slots on every training node.
// For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set.
// For the MPI runtime only int value can be set.
NumProcPerNode *string `json:"numProcPerNode,omitempty"`
NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"`
}

// DatasetConfig represents the desired dataset configuration.
Expand Down
5 changes: 3 additions & 2 deletions pkg/apis/kubeflow.org/v2alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions pkg/runtime.v2/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package core
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/intstr"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -263,7 +264,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
"succeeded to build JobSet with Torch values from the TrainJob": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
TorchPolicy(100, "auto").
TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj(),
).Obj(),
Expand All @@ -273,7 +274,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
NumNodes(30).
NumProcPerNode("3").
NumProcPerNode(intstr.FromInt32(3)).
Obj(),
).
Obj(),
Expand Down Expand Up @@ -317,7 +318,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
"succeeded to build JobSet with Torch values from the Runtime and envs.": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
TorchPolicy(100, "auto").
TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
ContainerTrainerEnv(
[]corev1.EnvVar{
Expand Down
7 changes: 4 additions & 3 deletions pkg/runtime.v2/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package torch
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/util/intstr"

corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/sets"
Expand Down Expand Up @@ -61,9 +62,9 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJo
}
info.Trainer.NumNodes = numNodes

numProcPerNode := info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode
numProcPerNode := ptr.Deref(info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode, intstr.FromString("auto"))
if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil {
numProcPerNode = trainJob.Spec.Trainer.NumProcPerNode
numProcPerNode = ptr.Deref(trainJob.Spec.Trainer.NumProcPerNode, intstr.FromString("auto"))
}

// Update envs for Info object.
Expand All @@ -78,7 +79,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJo
},
{
Name: constants.TorchEnvNumProcPerNode,
Value: ptr.Deref(numProcPerNode, "auto"),
Value: numProcPerNode.String(),
},
{
Name: constants.TorchEnvNodeRank,
Expand Down
5 changes: 3 additions & 2 deletions pkg/util.v2/testing/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
Expand Down Expand Up @@ -392,7 +393,7 @@ func (t *TrainJobTrainerWrapper) NumNodes(numNodes int32) *TrainJobTrainerWrappe
return t
}

func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode string) *TrainJobTrainerWrapper {
func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode intstr.IntOrString) *TrainJobTrainerWrapper {
t.Trainer.NumProcPerNode = &numProcPerNode
return t
}
Expand Down Expand Up @@ -689,7 +690,7 @@ func (s *TrainingRuntimeSpecWrapper) NumNodes(numNodes int32) *TrainingRuntimeSp
return s
}

func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode string) *TrainingRuntimeSpecWrapper {
func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode intstr.IntOrString) *TrainingRuntimeSpecWrapper {
s.MLPolicy = &kubeflowv2.MLPolicy{
NumNodes: &numNodes,
MLPolicySource: kubeflowv2.MLPolicySource{
Expand Down
3 changes: 2 additions & 1 deletion test/integration/controller.v2/trainjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/intstr"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
Expand Down Expand Up @@ -278,7 +279,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() {
trainingRuntime = testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha").
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "alpha").Spec).
TorchPolicy(100, "auto").
TorchPolicy(100, intstr.FromString("auto")).
ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj()).
Obj()
Expand Down

0 comments on commit e7ec8fb

Please sign in to comment.