@@ -18,7 +18,6 @@ package v1alpha2
1818
1919import (
2020 "context"
21- "encoding/json"
2221 "fmt"
2322 "sort"
2423 "strings"
@@ -33,15 +32,23 @@ import (
3332var llamastacklog = logf .Log .WithName ("llamastackdistribution-webhook" )
3433
3534// LlamaStackDistributionValidator validates LlamaStackDistribution resources.
36- type LlamaStackDistributionValidator struct {}
35+ type LlamaStackDistributionValidator struct {
36+ // EmbeddedDistributionNames is the list of known distribution names from
37+ // the embedded distribution registry. Injected at setup time to avoid
38+ // import cycles with pkg/config.
39+ EmbeddedDistributionNames []string
40+ }
3741
3842var _ admission.CustomValidator = & LlamaStackDistributionValidator {}
3943
4044// SetupWebhookWithManager registers the validating webhook.
41- func SetupWebhookWithManager (mgr ctrl.Manager ) error {
45+ // embeddedDistNames should be the result of config.EmbeddedDistributionNames().
46+ func SetupWebhookWithManager (mgr ctrl.Manager , embeddedDistNames []string ) error {
4247 return ctrl .NewWebhookManagedBy (mgr ).
4348 For (& LlamaStackDistribution {}).
44- WithValidator (& LlamaStackDistributionValidator {}).
49+ WithValidator (& LlamaStackDistributionValidator {
50+ EmbeddedDistributionNames : embeddedDistNames ,
51+ }).
4552 Complete ()
4653}
4754
@@ -55,7 +62,7 @@ func (v *LlamaStackDistributionValidator) ValidateCreate(_ context.Context, obj
5562 return nil , fmt .Errorf ("failed to validate: expected *LlamaStackDistribution, got %T" , obj )
5663 }
5764 llamastacklog .Info ("validating create" , "name" , r .Name )
58- return validate (r )
65+ return v . validate (r )
5966}
6067
6168// ValidateUpdate implements admission.CustomValidator.
@@ -65,65 +72,63 @@ func (v *LlamaStackDistributionValidator) ValidateUpdate(_ context.Context, _, n
6572 return nil , fmt .Errorf ("failed to validate: expected *LlamaStackDistribution, got %T" , newObj )
6673 }
6774 llamastacklog .Info ("validating update" , "name" , r .Name )
68- return validate (r )
75+ return v . validate (r )
6976}
7077
7178// ValidateDelete implements admission.CustomValidator.
7279func (v * LlamaStackDistributionValidator ) ValidateDelete (_ context.Context , _ runtime.Object ) (admission.Warnings , error ) {
7380 return nil , nil
7481}
7582
76- func validate (r * LlamaStackDistribution ) (admission.Warnings , error ) {
77- var allErrs field.ErrorList
78- var warnings admission.Warnings
79-
80- if r .Spec .Providers != nil {
81- if errs := validateProviderIDUniqueness (r .Spec .Providers ); len (errs ) > 0 {
82- allErrs = append (allErrs , errs ... )
83- }
83+ func (v * LlamaStackDistributionValidator ) validate (r * LlamaStackDistribution ) (admission.Warnings , error ) {
84+ allErrs := v .collectValidationErrors (r )
85+ if len (allErrs ) > 0 {
86+ return nil , allErrs .ToAggregate ()
8487 }
88+ return nil , nil
89+ }
8590
86- if r .Spec .Resources != nil && r .Spec .Providers != nil {
87- if errs := validateProviderReferences (r .Spec .Resources , r .Spec .Providers ); len (errs ) > 0 {
88- allErrs = append (allErrs , errs ... )
89- }
91+ func (v * LlamaStackDistributionValidator ) collectValidationErrors (r * LlamaStackDistribution ) field.ErrorList {
92+ var allErrs field.ErrorList
93+
94+ if r .Spec .Distribution .Name != "" {
95+ allErrs = append (allErrs , validateDistributionName (r .Spec .Distribution .Name , v .EmbeddedDistributionNames )... )
9096 }
9197
92- if len (r .Spec .Disabled ) > 0 && r .Spec .Providers != nil {
93- if warns := checkDisabledConflicts (r .Spec .Disabled , r .Spec .Providers ); len (warns ) > 0 {
94- warnings = append (warnings , warns ... )
95- }
98+ if r .Spec .Providers != nil {
99+ allErrs = append (allErrs , validateProviderIDUniqueness (r .Spec .Providers )... )
96100 }
97101
98- if len ( allErrs ) > 0 {
99- return warnings , allErrs . ToAggregate ( )
102+ if r . Spec . Resources != nil && r . Spec . Providers != nil {
103+ allErrs = append ( allErrs , validateProviderReferences ( r . Spec . Resources , r . Spec . Providers ) ... )
100104 }
101105
102- return warnings , nil
106+ return allErrs
103107}
104108
109+ // validateProviderIDUniqueness ensures provider IDs are unique across all API types.
110+ // Per-slice uniqueness is handled by CEL; this validates cross-slice uniqueness.
105111func validateProviderIDUniqueness (spec * ProvidersSpec ) field.ErrorList {
106112 var errs field.ErrorList
107113 seenIDs := make (map [string ]string )
108114
109115 fields := []struct {
110- name string
111- raw [] byte
116+ name string
117+ configs [] ProviderConfig
112118 }{
113- {"inference" , jsonRawBytes ( spec .Inference ) },
114- {"safety" , jsonRawBytes ( spec .Safety ) },
115- {"vectorIo" , jsonRawBytes ( spec .VectorIo ) },
116- {"toolRuntime" , jsonRawBytes ( spec .ToolRuntime ) },
117- {"telemetry" , jsonRawBytes ( spec .Telemetry ) },
119+ {"inference" , spec .Inference },
120+ {"safety" , spec .Safety },
121+ {"vectorIo" , spec .VectorIo },
122+ {"toolRuntime" , spec .ToolRuntime },
123+ {"telemetry" , spec .Telemetry },
118124 }
119125
120126 for _ , f := range fields {
121- if len (f .raw ) == 0 {
122- continue
123- }
124-
125- ids := extractProviderIDs (f .raw )
126- for _ , id := range ids {
127+ for _ , pc := range f .configs {
128+ id := pc .ID
129+ if id == "" {
130+ id = deriveProviderID (pc .Provider )
131+ }
127132 if existingAPI , exists := seenIDs [id ]; exists {
128133 errs = append (errs , field .Invalid (
129134 field .NewPath ("spec" , "providers" , f .name ),
@@ -138,16 +143,13 @@ func validateProviderIDUniqueness(spec *ProvidersSpec) field.ErrorList {
138143 return errs
139144}
140145
146+ // validateProviderReferences ensures model provider references point to configured providers.
141147func validateProviderReferences (resources * ResourcesSpec , providers * ProvidersSpec ) field.ErrorList {
142148 var errs field.ErrorList
143149
144150 providerIDs := collectAllProviderIDs (providers )
145151
146- for i , raw := range resources .Models {
147- var mc ModelConfig
148- if err := json .Unmarshal (raw .Raw , & mc ); err != nil {
149- continue
150- }
152+ for i , mc := range resources .Models {
151153 if mc .Provider != "" {
152154 if _ , ok := providerIDs [mc .Provider ]; ! ok {
153155 errs = append (errs , field .Invalid (
@@ -162,70 +164,30 @@ func validateProviderReferences(resources *ResourcesSpec, providers *ProvidersSp
162164 return errs
163165}
164166
165- func checkDisabledConflicts (disabled []string , providers * ProvidersSpec ) admission.Warnings {
166- var warnings admission.Warnings
167-
168- apiFieldMap := map [string ][]byte {
169- "inference" : jsonRawBytes (providers .Inference ),
170- "safety" : jsonRawBytes (providers .Safety ),
171- "vector_io" : jsonRawBytes (providers .VectorIo ),
172- "tool_runtime" : jsonRawBytes (providers .ToolRuntime ),
173- "telemetry" : jsonRawBytes (providers .Telemetry ),
174- }
175-
176- for _ , api := range disabled {
177- if raw , ok := apiFieldMap [api ]; ok && len (raw ) > 0 {
178- warnings = append (warnings , fmt .Sprintf (
179- "API %q is disabled but has providers configured; disabled takes precedence and provider config will be ignored" ,
180- api ,
181- ))
182- }
183- }
184-
185- return warnings
186- }
187-
188- func jsonRawBytes (raw interface { MarshalJSON () ([]byte , error ) }) []byte {
189- if raw == nil {
167+ // validateDistributionName validates that distribution.name is in the embedded
168+ // distribution registry.
169+ func validateDistributionName (name string , knownNames []string ) field.ErrorList {
170+ if len (knownNames ) == 0 {
190171 return nil
191172 }
192- b , err := json .Marshal (raw )
193- if err != nil || string (b ) == "null" {
194- return nil
195- }
196- return b
197- }
198173
199- func extractProviderIDs (raw []byte ) []string {
200- var single struct {
201- ID string `json:"id"`
202- Provider string `json:"provider"`
203- }
204- if err := json .Unmarshal (raw , & single ); err == nil && single .Provider != "" {
205- id := single .ID
206- if id == "" {
207- id = deriveProviderID (single .Provider )
174+ for _ , n := range knownNames {
175+ if n == name {
176+ return nil
208177 }
209- return []string {id }
210178 }
211179
212- var list []struct {
213- ID string `json:"id"`
214- Provider string `json:"provider"`
215- }
216- if err := json .Unmarshal (raw , & list ); err == nil {
217- var ids []string
218- for _ , p := range list {
219- id := p .ID
220- if id == "" {
221- id = deriveProviderID (p .Provider )
222- }
223- ids = append (ids , id )
224- }
225- return ids
226- }
180+ sorted := make ([]string , len (knownNames ))
181+ copy (sorted , knownNames )
182+ sort .Strings (sorted )
227183
228- return nil
184+ var errs field.ErrorList
185+ errs = append (errs , field .Invalid (
186+ field .NewPath ("spec" , "distribution" , "name" ),
187+ name ,
188+ fmt .Sprintf ("unknown distribution %q; available distributions: %s" , name , strings .Join (sorted , ", " )),
189+ ))
190+ return errs
229191}
230192
231193// deriveProviderID strips a "remote::" or similar prefix from a provider type
@@ -239,14 +201,18 @@ func deriveProviderID(providerType string) string {
239201
240202func collectAllProviderIDs (spec * ProvidersSpec ) map [string ]bool {
241203 ids := make (map [string ]bool )
242- for _ , raw := range [][]byte {
243- jsonRawBytes ( spec .Inference ) ,
244- jsonRawBytes ( spec .Safety ) ,
245- jsonRawBytes ( spec .VectorIo ) ,
246- jsonRawBytes ( spec .ToolRuntime ) ,
247- jsonRawBytes ( spec .Telemetry ) ,
204+ for _ , slice := range [][]ProviderConfig {
205+ spec .Inference ,
206+ spec .Safety ,
207+ spec .VectorIo ,
208+ spec .ToolRuntime ,
209+ spec .Telemetry ,
248210 } {
249- for _ , id := range extractProviderIDs (raw ) {
211+ for _ , pc := range slice {
212+ id := pc .ID
213+ if id == "" {
214+ id = deriveProviderID (pc .Provider )
215+ }
250216 ids [id ] = true
251217 }
252218 }
0 commit comments