Skip to content

Commit 75caeeb

Browse files
author
Akshay Chitneni
committed
Adding v2 trainjob validation webhook
Signed-off-by: Akshay Chitneni <[email protected]>
1 parent f64bdf2 commit 75caeeb

19 files changed

+457
-88
lines changed

pkg/constants/constants.go

+20-17
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55

66
batchv1 "k8s.io/api/batch/v1"
7+
"k8s.io/apimachinery/pkg/util/sets"
78
)
89

910
const (
@@ -61,23 +62,6 @@ const (
6162
// {"type": "Suspended", "status": "True", "reason": "Resumed"} condition.
6263
TrainJobResumedMessage = "TrainJob is resumed"
6364

64-
// Distributed envs for torchrun.
65-
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45
66-
// TorchEnvNumNodes is the env name for the number of training nodes.
67-
TorchEnvNumNodes string = "PET_NNODES"
68-
69-
// TorchEnvNumProcPerNode is the env name for the number of procs per node (e.g. number of GPUs per Pod).
70-
TorchEnvNumProcPerNode string = "PET_NPROC_PER_NODE"
71-
72-
// TorchEnvNodeRank is the env name for the node RANK
73-
TorchEnvNodeRank string = "PET_NODE_RANK"
74-
75-
// TorchEnvMasterAddr is the env name for the master node address.
76-
TorchEnvMasterAddr string = "PET_MASTER_ADDR"
77-
78-
// TorchEnvMasterPort is the env name for the master node port.
79-
TorchEnvMasterPort string = "PET_MASTER_PORT"
80-
8165
// JobLauncher is the Job name for the launcher.
8266
JobLauncher string = "launcher"
8367

@@ -131,9 +115,28 @@ const (
131115

132116
// OpenMPIEnvDefaultSlots is the OpenMPI default number of slots env key.
133117
OpenMPIEnvDefaultSlots string = "OMPI_MCA_orte_set_default_slots"
118+
// Distributed envs for torchrun.
119+
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45
120+
// TorchEnvNumNodes is the env name for the number of training nodes.
// The value must stay "PET_NNODES" (the name this constant carried before it
// was moved within the const group).
121+
TorchEnvNumNodes string = "PET_NNODES"
122+
123+
// TorchEnvNumProcPerNode is the env name for the number of procs per node (e.g. number of GPUs per Pod).
124+
TorchEnvNumProcPerNode string = "PET_NPROC_PER_NODE"
125+
126+
// TorchEnvNodeRank is the env name for the node RANK
127+
TorchEnvNodeRank string = "PET_NODE_RANK"
128+
129+
// TorchEnvMasterAddr is the env name for the master node address.
130+
TorchEnvMasterAddr string = "PET_MASTER_ADDR"
131+
132+
// TorchEnvMasterPort is the env name for the master node port.
133+
TorchEnvMasterPort string = "PET_MASTER_PORT"
134134
)
135135

136136
var (
137137
// JobCompletionIndexFieldPath is the field path for the Job completion index annotation.
138138
JobCompletionIndexFieldPath string = fmt.Sprintf("metadata.annotations['%s']", batchv1.JobCompletionIndexAnnotation)
139+
140+
// TorchRunReservedEnvNames is the set of torchrun env names declared above
// (number of nodes, procs per node, node rank, master address, master port).
141+
TorchRunReservedEnvNames = sets.New(TorchEnvNumNodes, TorchEnvNumProcPerNode, TorchEnvNodeRank, TorchEnvMasterAddr, TorchEnvMasterPort)
139142
)

pkg/controller/trainjob_controller.go

+2-11
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ import (
4141
jobruntimes "github.com/kubeflow/trainer/pkg/runtime"
4242
)
4343

44-
var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")
45-
4644
type objsOpState int
4745

4846
const (
@@ -85,10 +83,10 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
8583
return ctrl.Result{}, nil
8684
}
8785

88-
runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
86+
runtimeRefGK := jobruntimes.RuntimeRefToRuntimeRegistryKey(trainJob.Spec.RuntimeRef)
8987
runtime, ok := r.runtimes[runtimeRefGK]
9088
if !ok {
91-
return ctrl.Result{}, fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
89+
return ctrl.Result{}, fmt.Errorf("unsupported runtime: %s", runtimeRefGK)
9290
}
9391
opState, err := r.reconcileObjects(ctx, runtime, &trainJob)
9492

@@ -215,13 +213,6 @@ func isTrainJobFinished(trainJob *trainer.TrainJob) bool {
215213
meta.IsStatusConditionTrue(trainJob.Status.Conditions, trainer.TrainJobFailed)
216214
}
217215

218-
func runtimeRefToGroupKind(runtimeRef trainer.RuntimeRef) schema.GroupKind {
219-
return schema.GroupKind{
220-
Group: ptr.Deref(runtimeRef.APIGroup, ""),
221-
Kind: ptr.Deref(runtimeRef.Kind, ""),
222-
}
223-
}
224-
225216
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, options controller.Options) error {
226217
b := ctrl.NewControllerManagedBy(mgr).
227218
WithOptions(options).

pkg/runtime/core/clustertrainingruntime.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"k8s.io/apimachinery/pkg/util/validation/field"
2727
"sigs.k8s.io/controller-runtime/pkg/client"
2828
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
29+
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
2930

3031
trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
3132
"github.com/kubeflow/trainer/pkg/runtime"
@@ -69,14 +70,18 @@ func (r *ClusterTrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBu
6970
}
7071

7172
func (r *ClusterTrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
73+
clusterTrainingRuntime := &trainer.ClusterTrainingRuntime{}
7274
if err := r.client.Get(ctx, client.ObjectKey{
73-
Namespace: old.Namespace,
74-
Name: old.Spec.RuntimeRef.Name,
75+
Name: new.Spec.RuntimeRef.Name,
7576
}, &trainer.ClusterTrainingRuntime{}); err != nil {
7677
return nil, field.ErrorList{
77-
field.Invalid(field.NewPath("spec", "RuntimeRef"), old.Spec.RuntimeRef,
78+
field.Invalid(field.NewPath("spec", "RuntimeRef"), new.Spec.RuntimeRef,
7879
fmt.Sprintf("%v: specified clusterTrainingRuntime must be created before the TrainJob is created", err)),
7980
}
8081
}
81-
return r.framework.RunCustomValidationPlugins(old, new)
82+
info, _ := r.runtimeInfo(ctx, new, clusterTrainingRuntime.Spec.Template, clusterTrainingRuntime.Spec.MLPolicy, clusterTrainingRuntime.Spec.PodGroupPolicy)
83+
jobSetTemplate := jobsetv1alpha2.JobSet{
84+
Spec: clusterTrainingRuntime.Spec.Template.Spec,
85+
}
86+
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
8287
}

pkg/runtime/core/trainingruntime.go

+37-22
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ import (
2121
"errors"
2222
"fmt"
2323

24-
"github.com/kubeflow/trainer/pkg/apply"
25-
"github.com/kubeflow/trainer/pkg/constants"
2624
corev1 "k8s.io/api/core/v1"
2725
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2826
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -35,6 +33,8 @@ import (
3533
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"
3634

3735
trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
36+
"github.com/kubeflow/trainer/pkg/apply"
37+
"github.com/kubeflow/trainer/pkg/constants"
3838
"github.com/kubeflow/trainer/pkg/runtime"
3939
fwkcore "github.com/kubeflow/trainer/pkg/runtime/framework/core"
4040
fwkplugins "github.com/kubeflow/trainer/pkg/runtime/framework/plugins"
@@ -89,6 +89,29 @@ func (r *TrainingRuntime) NewObjects(ctx context.Context, trainJob *trainer.Trai
8989
// buildObjects builds the runtime info for the TrainJob via runtimeInfo and
// then runs the framework plugins in order: ML policy enforcement, pod group
// policy enforcement, pod network, and finally the component builders, whose
// result is returned. The first error from any step is returned as-is.
func (r *TrainingRuntime) buildObjects(
9090
ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy,
9191
) ([]any, error) {
92+
93+
info, err := r.runtimeInfo(ctx, trainJob, jobSetTemplateSpec, mlPolicy, podGroupPolicy)
94+
if err != nil {
95+
return nil, err
96+
}
97+
if err = r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
98+
return nil, err
99+
}
100+
101+
if err = r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
102+
return nil, err
103+
}
104+
105+
if err = r.framework.RunPodNetworkPlugins(info, trainJob); err != nil {
106+
return nil, err
107+
}
108+
109+
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
110+
}
111+
112+
func (r *TrainingRuntime) runtimeInfo(
113+
ctx context.Context, trainJob *trainer.TrainJob, jobSetTemplateSpec trainer.JobSetTemplateSpec, mlPolicy *trainer.MLPolicy, podGroupPolicy *trainer.PodGroupPolicy) (*runtime.Info, error) {
114+
92115
propagationLabels := jobSetTemplateSpec.Labels
93116
if propagationLabels == nil && trainJob.Spec.Labels != nil {
94117
propagationLabels = make(map[string]string, len(trainJob.Spec.Labels))
@@ -140,21 +163,7 @@ func (r *TrainingRuntime) buildObjects(
140163
)
141164
}
142165

143-
info := runtime.NewInfo(opts...)
144-
145-
if err = r.framework.RunEnforceMLPolicyPlugins(info, trainJob); err != nil {
146-
return nil, err
147-
}
148-
149-
if err = r.framework.RunEnforcePodGroupPolicyPlugins(info, trainJob); err != nil {
150-
return nil, err
151-
}
152-
153-
if err = r.framework.RunPodNetworkPlugins(info, trainJob); err != nil {
154-
return nil, err
155-
}
156-
157-
return r.framework.RunComponentBuilderPlugins(ctx, info, trainJob)
166+
return runtime.NewInfo(opts...), nil
158167
}
159168

160169
func syncPodSets(info *runtime.Info) {
@@ -198,14 +207,20 @@ func (r *TrainingRuntime) EventHandlerRegistrars() []runtime.ReconcilerBuilder {
198207
}
199208

200209
// ValidateObjects validates the new TrainJob against the TrainingRuntime named
// by new.Spec.RuntimeRef in the TrainJob's namespace. If the runtime cannot be
// fetched, a single field.Invalid error on spec.runtimeRef is returned.
// Otherwise the fetched runtime's template and policies are converted into
// runtime info and a JobSet template, which are passed with old and new to the
// custom validation plugins.
func (r *TrainingRuntime) ValidateObjects(ctx context.Context, old, new *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
210+
trainingRuntime := &trainer.TrainingRuntime{}
201211
if err := r.client.Get(ctx, client.ObjectKey{
202-
Namespace: old.Namespace,
203-
Name: old.Spec.RuntimeRef.Name,
204-
}, &trainer.TrainingRuntime{}); err != nil {
212+
Namespace: new.Namespace,
213+
Name: new.Spec.RuntimeRef.Name,
214+
}, trainingRuntime); err != nil {
205215
return nil, field.ErrorList{
206-
field.Invalid(field.NewPath("spec", "runtimeRef"), old.Spec.RuntimeRef,
216+
field.Invalid(field.NewPath("spec", "runtimeRef"), new.Spec.RuntimeRef,
207217
fmt.Sprintf("%v: specified trainingRuntime must be created before the TrainJob is created", err)),
208218
}
209219
}
210-
return r.framework.RunCustomValidationPlugins(old, new)
220+
info, _ := r.runtimeInfo(ctx, new, trainingRuntime.Spec.Template, trainingRuntime.Spec.MLPolicy, trainingRuntime.Spec.PodGroupPolicy) // ignoring the error here as the runtime configured should be valid
221+
222+
// Only the template's Spec is copied into the JobSet handed to the plugins.
jobSetTemplate := jobsetv1alpha2.JobSet{
223+
Spec: trainingRuntime.Spec.Template.Spec,
224+
}
225+
return r.framework.RunCustomValidationPlugins(jobSetTemplate.DeepCopy(), info, old, new)
210226
}

pkg/runtime/framework/core/framework.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(info *runtime.Info, trainJob
101101
return nil
102102
}
103103

104-
func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
104+
func (f *Framework) RunCustomValidationPlugins(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
105105
var aggregatedWarnings admission.Warnings
106106
var aggregatedErrors field.ErrorList
107107
for _, plugin := range f.customValidationPlugins {
108-
warnings, errs := plugin.Validate(oldObj, newObj)
108+
warnings, errs := plugin.Validate(runtimeJobTemplate, info, oldObj, newObj)
109109
if len(warnings) != 0 {
110110
aggregatedWarnings = append(aggregatedWarnings, warnings...)
111111
}

pkg/runtime/framework/core/framework_test.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ func TestNew(t *testing.T) {
8787
customValidationPlugins: []framework.CustomValidationPlugin{
8888
&mpi.MPI{},
8989
&torch.Torch{},
90+
&jobset.JobSet{},
9091
},
9192
watchExtensionPlugins: []framework.WatchExtensionPlugin{
9293
&coscheduling.CoScheduling{},
@@ -379,7 +380,9 @@ func TestRunCustomValidationPlugins(t *testing.T) {
379380
if err != nil {
380381
t.Fatal(err)
381382
}
382-
warnings, errs := fwk.RunCustomValidationPlugins(tc.oldObj, tc.newObj)
383+
runtimeInfo := runtime.NewInfo()
384+
jobSetTemplate := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test")
385+
warnings, errs := fwk.RunCustomValidationPlugins(jobSetTemplate, runtimeInfo, tc.oldObj, tc.newObj)
383386
if diff := cmp.Diff(tc.wantWarnings, warnings, cmpopts.SortSlices(func(a, b string) bool { return a < b })); len(diff) != 0 {
384387
t.Errorf("Unexpected warninigs (-want,+got):\n%s", diff)
385388
}

pkg/runtime/framework/interface.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2323
"k8s.io/apimachinery/pkg/util/validation/field"
24+
"sigs.k8s.io/controller-runtime/pkg/client"
2425
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
2526

2627
trainer "github.com/kubeflow/trainer/pkg/apis/trainer/v1alpha1"
@@ -33,7 +34,7 @@ type Plugin interface {
3334

3435
// CustomValidationPlugin is a Plugin that validates a TrainJob.
type CustomValidationPlugin interface {
3536
Plugin
36-
Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
37+
// Validate receives the runtime's job template, the runtime info, and the
// old and new TrainJob, and returns any admission warnings and field errors.
Validate(runtimeJobTemplate client.Object, info *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList)
3738
}
3839

3940
type WatchExtensionPlugin interface {

pkg/runtime/framework/plugins/jobset/jobset.go

+43
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,15 @@ import (
2727
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2828
apiruntime "k8s.io/apimachinery/pkg/runtime"
2929
"k8s.io/apimachinery/pkg/runtime/schema"
30+
"k8s.io/apimachinery/pkg/util/sets"
31+
"k8s.io/apimachinery/pkg/util/validation/field"
3032
metav1ac "k8s.io/client-go/applyconfigurations/meta/v1"
3133
"k8s.io/utils/ptr"
3234
ctrl "sigs.k8s.io/controller-runtime"
3335
"sigs.k8s.io/controller-runtime/pkg/builder"
3436
"sigs.k8s.io/controller-runtime/pkg/cache"
3537
"sigs.k8s.io/controller-runtime/pkg/client"
38+
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
3639
jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2"
3740
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"
3841

@@ -53,6 +56,7 @@ var _ framework.WatchExtensionPlugin = (*JobSet)(nil)
5356
var _ framework.PodNetworkPlugin = (*JobSet)(nil)
5457
var _ framework.ComponentBuilderPlugin = (*JobSet)(nil)
5558
var _ framework.TerminalConditionPlugin = (*JobSet)(nil)
59+
var _ framework.CustomValidationPlugin = (*JobSet)(nil)
5660

5761
const Name = constants.JobSetKind
5862

@@ -71,6 +75,45 @@ func (j *JobSet) Name() string {
7175
return Name
7276
}
7377

78+
// Validate checks that the runtime's JobSet template contains the initializer
// job and containers that the new TrainJob's configuration requires.
// If runtimeJobTemplate is not a *jobsetv1alpha2.JobSet, nothing is validated
// and nil, nil is returned. runtimeInfo and oldObj are not consulted, and no
// warnings are ever returned. All errors are reported against spec.runtimeRef.
func (j *JobSet) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
79+
80+
var allErrs field.ErrorList
81+
specPath := field.NewPath("spec")
82+
runtimeRefPath := specPath.Child("runtimeRef")
83+
84+
jobSet, ok := runtimeJobTemplate.(*jobsetv1alpha2.JobSet)
85+
if !ok {
86+
return nil, nil
87+
}
88+
89+
// Index the container names of every replicated job by the job's name.
rJobContainerNames := make(map[string]sets.Set[string])
90+
for _, rJob := range jobSet.Spec.ReplicatedJobs {
91+
rJobContainerNames[rJob.Name] = sets.New[string]()
92+
for _, c := range rJob.Template.Spec.Template.Spec.Containers {
93+
rJobContainerNames[rJob.Name].Insert(c.Name)
94+
}
95+
}
96+
97+
// An input modelConfig requires the initializer job with a model-initializer container.
if newObj.Spec.ModelConfig != nil && newObj.Spec.ModelConfig.Input != nil {
98+
if containerSet, ok := rJobContainerNames[constants.JobInitializer]; !ok {
99+
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("must have %s job when trainJob is configured with input modelConfig", constants.JobInitializer)))
100+
} else if !containerSet.Has(constants.ContainerModelInitializer) {
101+
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("must have container with name - %s in the %s job", constants.ContainerModelInitializer, constants.JobInitializer)))
102+
}
103+
}
104+
105+
// A datasetConfig requires the initializer job with a dataset-initializer container.
if newObj.Spec.DatasetConfig != nil {
106+
if containerSet, ok := rJobContainerNames[constants.JobInitializer]; !ok {
107+
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("must have %s job when trainJob is configured with input datasetConfig", constants.JobInitializer)))
108+
} else {
109+
if !containerSet.Has(constants.ContainerDatasetInitializer) {
110+
allErrs = append(allErrs, field.Invalid(runtimeRefPath, newObj.Spec.RuntimeRef, fmt.Sprintf("must have container with name - %s in the %s job", constants.ContainerDatasetInitializer, constants.JobInitializer)))
111+
}
112+
}
113+
}
114+
return nil, allErrs
115+
}
116+
74117
func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {
75118
if _, err := j.restMapper.RESTMapping(
76119
schema.GroupKind{Group: jobsetv1alpha2.GroupVersion.Group, Kind: constants.JobSetKind},

pkg/runtime/framework/plugins/mpi/mpi.go

+16-3
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131

3232
corev1 "k8s.io/api/core/v1"
3333
apiruntime "k8s.io/apimachinery/pkg/runtime"
34+
"k8s.io/apimachinery/pkg/util/intstr"
3435
"k8s.io/apimachinery/pkg/util/validation/field"
3536
corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
3637
metav1ac "k8s.io/client-go/applyconfigurations/meta/v1"
@@ -75,13 +76,25 @@ func (m *MPI) Name() string {
7576
return Name
7677
}
7778

78-
// TODO: Need to implement validations for MPI Policy.
7979
// TODO (andreyvelich): Add validation to check that TrainJob doesn't have MPI envs.
8080
// TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other.
8181
// Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940
8282

83-
func (m *MPI) Validate(oldObj, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
84-
return nil, nil
83+
// Validate checks the new TrainJob against the MPI policy of the runtime.
// Validation is skipped (no errors) when runtimeInfo is nil or its runtime
// policy has no MLPolicy or no MPI section. Otherwise, if the TrainJob sets
// spec.trainer.numProcPerNode, the value must be of integer type; a non-int
// value yields a field.Invalid error. runtimeJobTemplate and oldJobObj are
// not consulted, and no warnings are ever returned.
func (m *MPI) Validate(runtimeJobTemplate client.Object, runtimeInfo *runtime.Info, oldJobObj, newJobObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
84+
var allErrs field.ErrorList
85+
if runtimeInfo == nil || runtimeInfo.RuntimePolicy.MLPolicy == nil || runtimeInfo.RuntimePolicy.MLPolicy.MPI == nil {
86+
return nil, allErrs
87+
}
88+
89+
specPath := field.NewPath("spec")
90+
if newJobObj.Spec.Trainer != nil && newJobObj.Spec.Trainer.NumProcPerNode != nil {
91+
numProcPerNodePath := specPath.Child("trainer").Child("numProcPerNode")
92+
numProcPerNode := *newJobObj.Spec.Trainer.NumProcPerNode
93+
if numProcPerNode.Type != intstr.Int {
94+
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, newJobObj.Spec.Trainer.NumProcPerNode, "must have an int value"))
95+
}
96+
}
97+
return nil, allErrs
8598
}
}
8699

87100
func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error {

0 commit comments

Comments
 (0)