From e7ec8fbd021ef96e42bd57679306537e8b6b6a0a Mon Sep 17 00:00:00 2001 From: Antonin Stefanutti Date: Fri, 31 Jan 2025 17:34:45 +0100 Subject: [PATCH] KEP-2170: Add validation to Torch numProcPerNode field Signed-off-by: Antonin Stefanutti --- api.v2/openapi-spec/swagger.json | 4 ++-- hack/swagger/go.mod | 6 +++--- hack/swagger/go.sum | 3 +++ .../crds/kubeflow.org_clustertrainingruntimes.yaml | 7 ++++++- .../v2/base/crds/kubeflow.org_trainingruntimes.yaml | 7 ++++++- manifests/v2/base/crds/kubeflow.org_trainjobs.yaml | 5 ++++- .../kubeflow.org/v2alpha1/trainingruntime_types.go | 5 +++-- pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go | 3 ++- .../kubeflow.org/v2alpha1/zz_generated.deepcopy.go | 5 +++-- pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go | 10 ++++------ .../kubeflow.org/v2alpha1/torchmlpolicysource.go | 8 ++++++-- .../kubeflow.org/v2alpha1/trainer.go | 5 +++-- pkg/runtime.v2/core/trainingruntime_test.go | 7 ++++--- pkg/runtime.v2/framework/plugins/torch/torch.go | 7 ++++--- pkg/util.v2/testing/wrapper.go | 5 +++-- .../controller.v2/trainjob_controller_test.go | 3 ++- 16 files changed, 58 insertions(+), 32 deletions(-) diff --git a/api.v2/openapi-spec/swagger.json b/api.v2/openapi-spec/swagger.json index 077820f620..e82feee139 100644 --- a/api.v2/openapi-spec/swagger.json +++ b/api.v2/openapi-spec/swagger.json @@ -517,7 +517,7 @@ }, "numProcPerNode": { "description": "Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`.", - "type": "string" + "$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString" } } }, @@ -716,7 +716,7 @@ }, "numProcPerNode": { "description": "Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set.", - "type": "string" + "$ref": "#/definitions/k8s.io.apimachinery.pkg.util.intstr.IntOrString" }, "resourcesPerNode": { "description": "Compute resources for each training node.", diff --git a/hack/swagger/go.mod b/hack/swagger/go.mod index 2871baf675..54aa315901 100644 --- a/hack/swagger/go.mod +++ b/hack/swagger/go.mod @@ -28,9 +28,9 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/x448/float16 v0.8.4 // indirect - golang.org/x/net v0.30.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/text v0.19.0 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect google.golang.org/protobuf v1.35.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/hack/swagger/go.sum b/hack/swagger/go.sum index 1ac64d041c..8f9527be7c 100644 --- a/hack/swagger/go.sum +++ b/hack/swagger/go.sum @@ -80,6 +80,7 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -89,10 +90,12 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/manifests/v2/base/crds/kubeflow.org_clustertrainingruntimes.yaml b/manifests/v2/base/crds/kubeflow.org_clustertrainingruntimes.yaml index 4d281801c1..f62650aa87 100644 --- a/manifests/v2/base/crds/kubeflow.org_clustertrainingruntimes.yaml +++ b/manifests/v2/base/crds/kubeflow.org_clustertrainingruntimes.yaml @@ -583,12 +583,17 @@ spec: type: integer type: object numProcPerNode: + anyOf: + - type: integer + - type: string description: |- Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`. - type: string + x-kubernetes-int-or-string: true + x-kubernetes-validations: + - rule: self > 0 || self in ['auto', 'cpu', 'gpu'] type: object type: object podGroupPolicy: diff --git a/manifests/v2/base/crds/kubeflow.org_trainingruntimes.yaml b/manifests/v2/base/crds/kubeflow.org_trainingruntimes.yaml index 0ae165315c..021ea0f60e 100644 --- a/manifests/v2/base/crds/kubeflow.org_trainingruntimes.yaml +++ b/manifests/v2/base/crds/kubeflow.org_trainingruntimes.yaml @@ -583,12 +583,17 @@ spec: type: integer type: object numProcPerNode: + anyOf: + - type: integer + - type: string description: |- Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`. - type: string + x-kubernetes-int-or-string: true + x-kubernetes-validations: + - rule: self > 0 || self in ['auto', 'cpu', 'gpu'] type: object type: object podGroupPolicy: diff --git a/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml b/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml index 037e04191d..c22cc27896 100644 --- a/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml +++ b/manifests/v2/base/crds/kubeflow.org_trainjobs.yaml @@ -3138,11 +3138,14 @@ spec: format: int32 type: integer numProcPerNode: + anyOf: + - type: integer + - type: string description: |- Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set. - type: string + x-kubernetes-int-or-string: true resourcesPerNode: description: Compute resources for each training node. properties: diff --git a/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go b/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go index d2ddceb140..4ac928895f 100644 --- a/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go +++ b/pkg/apis/kubeflow.org/v2alpha1/trainingruntime_types.go @@ -19,6 +19,7 @@ package v2alpha1 import ( autoscalingv2 "k8s.io/api/autoscaling/v2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" ) @@ -171,9 +172,9 @@ type TorchMLPolicySource struct { // Number of processes per node. // This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. // Supported values: `auto`, `cpu`, `gpu`, or int value. - // TODO (andreyvelich): Add kubebuilder validation. // Defaults to `auto`. - NumProcPerNode *string `json:"numProcPerNode,omitempty"` + // +kubebuilder:validation:XValidation:rule="self > 0 || self in ['auto', 'cpu', 'gpu']" + NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"` // Elastic policy for the PyTorch training. ElasticPolicy *TorchElasticPolicy `json:"elasticPolicy,omitempty"` diff --git a/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go b/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go index 04f995c1fa..c1e2971d21 100644 --- a/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go +++ b/pkg/apis/kubeflow.org/v2alpha1/trainjob_types.go @@ -19,6 +19,7 @@ package v2alpha1 import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" ) const ( @@ -194,7 +195,7 @@ type Trainer struct { // Number of processes/workers/slots on every training node. // For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. // For the MPI runtime only int value can be set. - NumProcPerNode *string `json:"numProcPerNode,omitempty"` + NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"` } // DatasetConfig represents the desired dataset configuration. diff --git a/pkg/apis/kubeflow.org/v2alpha1/zz_generated.deepcopy.go b/pkg/apis/kubeflow.org/v2alpha1/zz_generated.deepcopy.go index cb4b3a5eeb..9c408b132e 100644 --- a/pkg/apis/kubeflow.org/v2alpha1/zz_generated.deepcopy.go +++ b/pkg/apis/kubeflow.org/v2alpha1/zz_generated.deepcopy.go @@ -24,6 +24,7 @@ import ( v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" runtime "k8s.io/apimachinery/pkg/runtime" + intstr "k8s.io/apimachinery/pkg/util/intstr" ) // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. @@ -576,7 +577,7 @@ func (in *TorchMLPolicySource) DeepCopyInto(out *TorchMLPolicySource) { *out = *in if in.NumProcPerNode != nil { in, out := &in.NumProcPerNode, &out.NumProcPerNode - *out = new(string) + *out = new(intstr.IntOrString) **out = **in } if in.ElasticPolicy != nil { @@ -786,7 +787,7 @@ func (in *Trainer) DeepCopyInto(out *Trainer) { } if in.NumProcPerNode != nil { in, out := &in.NumProcPerNode, &out.NumProcPerNode - *out = new(string) + *out = new(intstr.IntOrString) **out = **in } return diff --git a/pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go b/pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go index 5248d0ef02..da42681fad 100644 --- a/pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go +++ b/pkg/apis/kubeflow.org/v2alpha1/zz_generated.openapi.go @@ -974,8 +974,7 @@ func schema_pkg_apis_kubefloworg_v2alpha1_TorchMLPolicySource(ref common.Referen "numProcPerNode": { SchemaProps: spec.SchemaProps{ Description: "Number of processes per node. This value is inserted into the `--nproc-per-node` argument of the `torchrun` CLI. Supported values: `auto`, `cpu`, `gpu`, or int value. Defaults to `auto`.", - Type: []string{"string"}, - Format: "", + Ref: ref("k8s.io/apimachinery/pkg/util/intstr.IntOrString"), }, }, "elasticPolicy": { @@ -988,7 +987,7 @@ func schema_pkg_apis_kubefloworg_v2alpha1_TorchMLPolicySource(ref common.Referen }, }, Dependencies: []string{ - "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1.TorchElasticPolicy"}, + "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1.TorchElasticPolicy", "k8s.io/apimachinery/pkg/util/intstr.IntOrString"}, } } @@ -1352,15 +1351,14 @@ func schema_pkg_apis_kubefloworg_v2alpha1_Trainer(ref common.ReferenceCallback) "numProcPerNode": { SchemaProps: spec.SchemaProps{ Description: "Number of processes/workers/slots on every training node. For the Torch runtime: `auto`, `cpu`, `gpu`, or int value can be set. For the MPI runtime only int value can be set.", - Type: []string{"string"}, - Format: "", + Ref: ref("k8s.io/apimachinery/pkg/util/intstr.IntOrString"), }, }, }, }, }, Dependencies: []string{ - "k8s.io/api/core/v1.EnvVar", "k8s.io/api/core/v1.ResourceRequirements"}, + "k8s.io/api/core/v1.EnvVar", "k8s.io/api/core/v1.ResourceRequirements", "k8s.io/apimachinery/pkg/util/intstr.IntOrString"}, } } diff --git a/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/torchmlpolicysource.go b/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/torchmlpolicysource.go index 401234ac01..8a7c4883a2 100644 --- a/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/torchmlpolicysource.go +++ b/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/torchmlpolicysource.go @@ -16,10 +16,14 @@ package v2alpha1 +import ( + intstr "k8s.io/apimachinery/pkg/util/intstr" +) + // TorchMLPolicySourceApplyConfiguration represents a declarative configuration of the TorchMLPolicySource type for use // with apply. type TorchMLPolicySourceApplyConfiguration struct { - NumProcPerNode *string `json:"numProcPerNode,omitempty"` + NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"` ElasticPolicy *TorchElasticPolicyApplyConfiguration `json:"elasticPolicy,omitempty"` } @@ -32,7 +36,7 @@ func TorchMLPolicySource() *TorchMLPolicySourceApplyConfiguration { // WithNumProcPerNode sets the NumProcPerNode field in the declarative configuration to the given value // and returns the receiver, so that objects can be built by chaining "With" function invocations. // If called multiple times, the NumProcPerNode field is set to the value of the last call. -func (b *TorchMLPolicySourceApplyConfiguration) WithNumProcPerNode(value string) *TorchMLPolicySourceApplyConfiguration { +func (b *TorchMLPolicySourceApplyConfiguration) WithNumProcPerNode(value intstr.IntOrString) *TorchMLPolicySourceApplyConfiguration { b.NumProcPerNode = &value return b } diff --git a/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/trainer.go b/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/trainer.go index 49d991a440..f8d19c1275 100644 --- a/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/trainer.go +++ b/pkg/client/applyconfiguration/kubeflow.org/v2alpha1/trainer.go @@ -18,6 +18,7 @@ package v2alpha1 import ( v1 "k8s.io/api/core/v1" + intstr "k8s.io/apimachinery/pkg/util/intstr" ) // TrainerApplyConfiguration represents a declarative configuration of the Trainer type for use @@ -29,7 +30,7 @@ type TrainerApplyConfiguration struct { Env []v1.EnvVar `json:"env,omitempty"` NumNodes *int32 `json:"numNodes,omitempty"` ResourcesPerNode *v1.ResourceRequirements `json:"resourcesPerNode,omitempty"` - NumProcPerNode *string `json:"numProcPerNode,omitempty"` + NumProcPerNode *intstr.IntOrString `json:"numProcPerNode,omitempty"` } // TrainerApplyConfiguration constructs a declarative configuration of the Trainer type for use with @@ -95,7 +96,7 @@ func (b *TrainerApplyConfiguration) WithResourcesPerNode(value v1.ResourceRequir // WithNumProcPerNode sets the NumProcPerNode field in the declarative configuration to the given value // and returns the receiver, so that objects can be built by chaining "With" function invocations. // If called multiple times, the NumProcPerNode field is set to the value of the last call. -func (b *TrainerApplyConfiguration) WithNumProcPerNode(value string) *TrainerApplyConfiguration { +func (b *TrainerApplyConfiguration) WithNumProcPerNode(value intstr.IntOrString) *TrainerApplyConfiguration { b.NumProcPerNode = &value return b } diff --git a/pkg/runtime.v2/core/trainingruntime_test.go b/pkg/runtime.v2/core/trainingruntime_test.go index 2a11716cd3..cbc54efa06 100644 --- a/pkg/runtime.v2/core/trainingruntime_test.go +++ b/pkg/runtime.v2/core/trainingruntime_test.go @@ -19,6 +19,7 @@ package core import ( "context" "fmt" + "k8s.io/apimachinery/pkg/util/intstr" "testing" "github.com/google/go-cmp/cmp" @@ -263,7 +264,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { "succeeded to build JobSet with Torch values from the TrainJob": { trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec( testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec). - TorchPolicy(100, "auto"). + TorchPolicy(100, intstr.FromString("auto")). ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). Obj(), ).Obj(), @@ -273,7 +274,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Trainer( testingutil.MakeTrainJobTrainerWrapper(). NumNodes(30). - NumProcPerNode("3"). + NumProcPerNode(intstr.FromInt32(3)). Obj(), ). Obj(), @@ -317,7 +318,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { "succeeded to build JobSet with Torch values from the Runtime and envs.": { trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec( testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec). - TorchPolicy(100, "auto"). + TorchPolicy(100, intstr.FromString("auto")). ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). ContainerTrainerEnv( []corev1.EnvVar{ diff --git a/pkg/runtime.v2/framework/plugins/torch/torch.go b/pkg/runtime.v2/framework/plugins/torch/torch.go index 4e9c40585f..41a9983ab6 100644 --- a/pkg/runtime.v2/framework/plugins/torch/torch.go +++ b/pkg/runtime.v2/framework/plugins/torch/torch.go @@ -19,6 +19,7 @@ package torch import ( "context" "fmt" + "k8s.io/apimachinery/pkg/util/intstr" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/sets" @@ -61,9 +62,9 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJo } info.Trainer.NumNodes = numNodes - numProcPerNode := info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode + numProcPerNode := ptr.Deref(info.RuntimePolicy.MLPolicy.Torch.NumProcPerNode, intstr.FromString("auto")) if trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumProcPerNode != nil { - numProcPerNode = trainJob.Spec.Trainer.NumProcPerNode + numProcPerNode = ptr.Deref(trainJob.Spec.Trainer.NumProcPerNode, intstr.FromString("auto")) } // Update envs for Info object. @@ -78,7 +79,7 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info, trainJob *kubeflowv2.TrainJo }, { Name: constants.TorchEnvNumProcPerNode, - Value: ptr.Deref(numProcPerNode, "auto"), + Value: numProcPerNode.String(), }, { Name: constants.TorchEnvNodeRank, diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index fdbb3dd6c7..ee8547c5d9 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -22,6 +22,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1" @@ -392,7 +393,7 @@ func (t *TrainJobTrainerWrapper) NumNodes(numNodes int32) *TrainJobTrainerWrappe return t } -func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode string) *TrainJobTrainerWrapper { +func (t *TrainJobTrainerWrapper) NumProcPerNode(numProcPerNode intstr.IntOrString) *TrainJobTrainerWrapper { t.Trainer.NumProcPerNode = &numProcPerNode return t } @@ -689,7 +690,7 @@ func (s *TrainingRuntimeSpecWrapper) NumNodes(numNodes int32) *TrainingRuntimeSp return s } -func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode string) *TrainingRuntimeSpecWrapper { +func (s *TrainingRuntimeSpecWrapper) TorchPolicy(numNodes int32, numProcPerNode intstr.IntOrString) *TrainingRuntimeSpecWrapper { s.MLPolicy = &kubeflowv2.MLPolicy{ NumNodes: &numNodes, MLPolicySource: kubeflowv2.MLPolicySource{ diff --git a/test/integration/controller.v2/trainjob_controller_test.go b/test/integration/controller.v2/trainjob_controller_test.go index 487114d38e..d4eebee363 100644 --- a/test/integration/controller.v2/trainjob_controller_test.go +++ b/test/integration/controller.v2/trainjob_controller_test.go @@ -26,6 +26,7 @@ import ( "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" @@ -278,7 +279,7 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() { trainingRuntime = testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha"). RuntimeSpec( testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "alpha").Spec). - TorchPolicy(100, "auto"). + TorchPolicy(100, intstr.FromString("auto")). ContainerTrainer("test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). Obj()). Obj()