Skip to content

Commit 958b6f3

Browse files
committed
sdd: Re-implement with updated spec
This commit address the changes in the spec and implements the updated config Signed-off-by: Vaishnavi Hire <vhire@redhat.com> Assisted-by: Claude Opus 4.6
1 parent 42a5991 commit 958b6f3

25 files changed

Lines changed: 2170 additions & 864 deletions

api/v1alpha1/llamastackdistribution_conversion.go

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323

2424
v1alpha2 "github.com/llamastack/llama-stack-k8s-operator/api/v1alpha2"
2525
corev1 "k8s.io/api/core/v1"
26-
apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
2726
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2827
"sigs.k8s.io/controller-runtime/pkg/conversion"
2928
)
@@ -289,10 +288,9 @@ func convertToNetworkAccess(src *LlamaStackDistribution, n *v1alpha2.NetworkingS
289288
}
290289
hasContent := false
291290
if src.Spec.Network.ExposeRoute {
292-
if raw, err := json.Marshal(true); err == nil {
293-
n.Expose = &apiextensionsv1.JSON{Raw: raw}
294-
hasContent = true
295-
}
291+
enabled := true
292+
n.Expose = &v1alpha2.ExposeConfig{Enabled: &enabled}
293+
hasContent = true
296294
}
297295
if src.Spec.Network.AllowedFrom != nil {
298296
n.AllowedFrom = &v1alpha2.AllowedFromSpec{
@@ -469,8 +467,7 @@ func convertFromNetworking(src *v1alpha2.LlamaStackDistribution, dst *LlamaStack
469467
convertFromTLS(src, dst, n)
470468

471469
// Expose → ExposeRoute
472-
expose, _ := parseExpose(n.Expose)
473-
if expose {
470+
if n.Expose != nil && n.Expose.Enabled != nil && *n.Expose.Enabled {
474471
if dst.Spec.Network == nil {
475472
dst.Spec.Network = &NetworkSpec{}
476473
}
@@ -489,27 +486,6 @@ func convertFromNetworking(src *v1alpha2.LlamaStackDistribution, dst *LlamaStack
489486
}
490487
}
491488

492-
func parseExpose(raw *apiextensionsv1.JSON) (bool, string) {
493-
if raw == nil || len(raw.Raw) == 0 {
494-
return false, ""
495-
}
496-
var boolVal bool
497-
if err := json.Unmarshal(raw.Raw, &boolVal); err == nil {
498-
return boolVal, ""
499-
}
500-
var obj struct {
501-
Enabled *bool `json:"enabled,omitempty"`
502-
Hostname string `json:"hostname,omitempty"`
503-
}
504-
if err := json.Unmarshal(raw.Raw, &obj); err == nil {
505-
if obj.Enabled != nil {
506-
return *obj.Enabled, obj.Hostname
507-
}
508-
return true, obj.Hostname
509-
}
510-
return false, ""
511-
}
512-
513489
func convertFromStatus(src *v1alpha2.LlamaStackDistribution, dst *LlamaStackDistribution) {
514490
dst.Status.Phase = DistributionPhase(src.Status.Phase)
515491
dst.Status.Version = VersionInfo{

api/v1alpha1/llamastackdistribution_conversion_test.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ limitations under the License.
1717
package v1alpha1
1818

1919
import (
20-
"encoding/json"
2120
"testing"
2221

2322
v1alpha2 "github.com/llamastack/llama-stack-k8s-operator/api/v1alpha2"
@@ -179,10 +178,8 @@ func TestConvertToV1Alpha2WithNetwork(t *testing.T) {
179178

180179
require.NotNil(t, dst.Spec.Networking)
181180
require.NotNil(t, dst.Spec.Networking.Expose)
182-
183-
var exposed bool
184-
require.NoError(t, json.Unmarshal(dst.Spec.Networking.Expose.Raw, &exposed))
185-
assert.True(t, exposed)
181+
require.NotNil(t, dst.Spec.Networking.Expose.Enabled)
182+
assert.True(t, *dst.Spec.Networking.Expose.Enabled)
186183

187184
require.NotNil(t, dst.Spec.Networking.AllowedFrom)
188185
assert.Equal(t, []string{"ns1", "ns2"}, dst.Spec.Networking.AllowedFrom.Namespaces)

api/v1alpha2/llamastackdistribution_types.go

Lines changed: 149 additions & 34 deletions
Large diffs are not rendered by default.

api/v1alpha2/llamastackdistribution_webhook.go

Lines changed: 73 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package v1alpha2
1818

1919
import (
2020
"context"
21-
"encoding/json"
2221
"fmt"
2322
"sort"
2423
"strings"
@@ -33,15 +32,23 @@ import (
3332
var llamastacklog = logf.Log.WithName("llamastackdistribution-webhook")
3433

3534
// LlamaStackDistributionValidator validates LlamaStackDistribution resources.
36-
type LlamaStackDistributionValidator struct{}
35+
type LlamaStackDistributionValidator struct {
36+
// EmbeddedDistributionNames is the list of known distribution names from
37+
// the embedded distribution registry. Injected at setup time to avoid
38+
// import cycles with pkg/config.
39+
EmbeddedDistributionNames []string
40+
}
3741

3842
var _ admission.CustomValidator = &LlamaStackDistributionValidator{}
3943

4044
// SetupWebhookWithManager registers the validating webhook.
41-
func SetupWebhookWithManager(mgr ctrl.Manager) error {
45+
// embeddedDistNames should be the result of config.EmbeddedDistributionNames().
46+
func SetupWebhookWithManager(mgr ctrl.Manager, embeddedDistNames []string) error {
4247
return ctrl.NewWebhookManagedBy(mgr).
4348
For(&LlamaStackDistribution{}).
44-
WithValidator(&LlamaStackDistributionValidator{}).
49+
WithValidator(&LlamaStackDistributionValidator{
50+
EmbeddedDistributionNames: embeddedDistNames,
51+
}).
4552
Complete()
4653
}
4754

@@ -55,7 +62,7 @@ func (v *LlamaStackDistributionValidator) ValidateCreate(_ context.Context, obj
5562
return nil, fmt.Errorf("failed to validate: expected *LlamaStackDistribution, got %T", obj)
5663
}
5764
llamastacklog.Info("validating create", "name", r.Name)
58-
return validate(r)
65+
return v.validate(r)
5966
}
6067

6168
// ValidateUpdate implements admission.CustomValidator.
@@ -65,65 +72,63 @@ func (v *LlamaStackDistributionValidator) ValidateUpdate(_ context.Context, _, n
6572
return nil, fmt.Errorf("failed to validate: expected *LlamaStackDistribution, got %T", newObj)
6673
}
6774
llamastacklog.Info("validating update", "name", r.Name)
68-
return validate(r)
75+
return v.validate(r)
6976
}
7077

7178
// ValidateDelete implements admission.CustomValidator.
7279
func (v *LlamaStackDistributionValidator) ValidateDelete(_ context.Context, _ runtime.Object) (admission.Warnings, error) {
7380
return nil, nil
7481
}
7582

76-
func validate(r *LlamaStackDistribution) (admission.Warnings, error) {
77-
var allErrs field.ErrorList
78-
var warnings admission.Warnings
79-
80-
if r.Spec.Providers != nil {
81-
if errs := validateProviderIDUniqueness(r.Spec.Providers); len(errs) > 0 {
82-
allErrs = append(allErrs, errs...)
83-
}
83+
func (v *LlamaStackDistributionValidator) validate(r *LlamaStackDistribution) (admission.Warnings, error) {
84+
allErrs := v.collectValidationErrors(r)
85+
if len(allErrs) > 0 {
86+
return nil, allErrs.ToAggregate()
8487
}
88+
return nil, nil
89+
}
8590

86-
if r.Spec.Resources != nil && r.Spec.Providers != nil {
87-
if errs := validateProviderReferences(r.Spec.Resources, r.Spec.Providers); len(errs) > 0 {
88-
allErrs = append(allErrs, errs...)
89-
}
91+
func (v *LlamaStackDistributionValidator) collectValidationErrors(r *LlamaStackDistribution) field.ErrorList {
92+
var allErrs field.ErrorList
93+
94+
if r.Spec.Distribution.Name != "" {
95+
allErrs = append(allErrs, validateDistributionName(r.Spec.Distribution.Name, v.EmbeddedDistributionNames)...)
9096
}
9197

92-
if len(r.Spec.Disabled) > 0 && r.Spec.Providers != nil {
93-
if warns := checkDisabledConflicts(r.Spec.Disabled, r.Spec.Providers); len(warns) > 0 {
94-
warnings = append(warnings, warns...)
95-
}
98+
if r.Spec.Providers != nil {
99+
allErrs = append(allErrs, validateProviderIDUniqueness(r.Spec.Providers)...)
96100
}
97101

98-
if len(allErrs) > 0 {
99-
return warnings, allErrs.ToAggregate()
102+
if r.Spec.Resources != nil && r.Spec.Providers != nil {
103+
allErrs = append(allErrs, validateProviderReferences(r.Spec.Resources, r.Spec.Providers)...)
100104
}
101105

102-
return warnings, nil
106+
return allErrs
103107
}
104108

109+
// validateProviderIDUniqueness ensures provider IDs are unique across all API types.
110+
// Per-slice uniqueness is handled by CEL; this validates cross-slice uniqueness.
105111
func validateProviderIDUniqueness(spec *ProvidersSpec) field.ErrorList {
106112
var errs field.ErrorList
107113
seenIDs := make(map[string]string)
108114

109115
fields := []struct {
110-
name string
111-
raw []byte
116+
name string
117+
configs []ProviderConfig
112118
}{
113-
{"inference", jsonRawBytes(spec.Inference)},
114-
{"safety", jsonRawBytes(spec.Safety)},
115-
{"vectorIo", jsonRawBytes(spec.VectorIo)},
116-
{"toolRuntime", jsonRawBytes(spec.ToolRuntime)},
117-
{"telemetry", jsonRawBytes(spec.Telemetry)},
119+
{"inference", spec.Inference},
120+
{"safety", spec.Safety},
121+
{"vectorIo", spec.VectorIo},
122+
{"toolRuntime", spec.ToolRuntime},
123+
{"telemetry", spec.Telemetry},
118124
}
119125

120126
for _, f := range fields {
121-
if len(f.raw) == 0 {
122-
continue
123-
}
124-
125-
ids := extractProviderIDs(f.raw)
126-
for _, id := range ids {
127+
for _, pc := range f.configs {
128+
id := pc.ID
129+
if id == "" {
130+
id = deriveProviderID(pc.Provider)
131+
}
127132
if existingAPI, exists := seenIDs[id]; exists {
128133
errs = append(errs, field.Invalid(
129134
field.NewPath("spec", "providers", f.name),
@@ -138,16 +143,13 @@ func validateProviderIDUniqueness(spec *ProvidersSpec) field.ErrorList {
138143
return errs
139144
}
140145

146+
// validateProviderReferences ensures model provider references point to configured providers.
141147
func validateProviderReferences(resources *ResourcesSpec, providers *ProvidersSpec) field.ErrorList {
142148
var errs field.ErrorList
143149

144150
providerIDs := collectAllProviderIDs(providers)
145151

146-
for i, raw := range resources.Models {
147-
var mc ModelConfig
148-
if err := json.Unmarshal(raw.Raw, &mc); err != nil {
149-
continue
150-
}
152+
for i, mc := range resources.Models {
151153
if mc.Provider != "" {
152154
if _, ok := providerIDs[mc.Provider]; !ok {
153155
errs = append(errs, field.Invalid(
@@ -162,70 +164,30 @@ func validateProviderReferences(resources *ResourcesSpec, providers *ProvidersSp
162164
return errs
163165
}
164166

165-
func checkDisabledConflicts(disabled []string, providers *ProvidersSpec) admission.Warnings {
166-
var warnings admission.Warnings
167-
168-
apiFieldMap := map[string][]byte{
169-
"inference": jsonRawBytes(providers.Inference),
170-
"safety": jsonRawBytes(providers.Safety),
171-
"vector_io": jsonRawBytes(providers.VectorIo),
172-
"tool_runtime": jsonRawBytes(providers.ToolRuntime),
173-
"telemetry": jsonRawBytes(providers.Telemetry),
174-
}
175-
176-
for _, api := range disabled {
177-
if raw, ok := apiFieldMap[api]; ok && len(raw) > 0 {
178-
warnings = append(warnings, fmt.Sprintf(
179-
"API %q is disabled but has providers configured; disabled takes precedence and provider config will be ignored",
180-
api,
181-
))
182-
}
183-
}
184-
185-
return warnings
186-
}
187-
188-
func jsonRawBytes(raw interface{ MarshalJSON() ([]byte, error) }) []byte {
189-
if raw == nil {
167+
// validateDistributionName validates that distribution.name is in the embedded
168+
// distribution registry.
169+
func validateDistributionName(name string, knownNames []string) field.ErrorList {
170+
if len(knownNames) == 0 {
190171
return nil
191172
}
192-
b, err := json.Marshal(raw)
193-
if err != nil || string(b) == "null" {
194-
return nil
195-
}
196-
return b
197-
}
198173

199-
func extractProviderIDs(raw []byte) []string {
200-
var single struct {
201-
ID string `json:"id"`
202-
Provider string `json:"provider"`
203-
}
204-
if err := json.Unmarshal(raw, &single); err == nil && single.Provider != "" {
205-
id := single.ID
206-
if id == "" {
207-
id = deriveProviderID(single.Provider)
174+
for _, n := range knownNames {
175+
if n == name {
176+
return nil
208177
}
209-
return []string{id}
210178
}
211179

212-
var list []struct {
213-
ID string `json:"id"`
214-
Provider string `json:"provider"`
215-
}
216-
if err := json.Unmarshal(raw, &list); err == nil {
217-
var ids []string
218-
for _, p := range list {
219-
id := p.ID
220-
if id == "" {
221-
id = deriveProviderID(p.Provider)
222-
}
223-
ids = append(ids, id)
224-
}
225-
return ids
226-
}
180+
sorted := make([]string, len(knownNames))
181+
copy(sorted, knownNames)
182+
sort.Strings(sorted)
227183

228-
return nil
184+
var errs field.ErrorList
185+
errs = append(errs, field.Invalid(
186+
field.NewPath("spec", "distribution", "name"),
187+
name,
188+
fmt.Sprintf("unknown distribution %q; available distributions: %s", name, strings.Join(sorted, ", ")),
189+
))
190+
return errs
229191
}
230192

231193
// deriveProviderID strips a "remote::" or similar prefix from a provider type
@@ -239,14 +201,18 @@ func deriveProviderID(providerType string) string {
239201

240202
func collectAllProviderIDs(spec *ProvidersSpec) map[string]bool {
241203
ids := make(map[string]bool)
242-
for _, raw := range [][]byte{
243-
jsonRawBytes(spec.Inference),
244-
jsonRawBytes(spec.Safety),
245-
jsonRawBytes(spec.VectorIo),
246-
jsonRawBytes(spec.ToolRuntime),
247-
jsonRawBytes(spec.Telemetry),
204+
for _, slice := range [][]ProviderConfig{
205+
spec.Inference,
206+
spec.Safety,
207+
spec.VectorIo,
208+
spec.ToolRuntime,
209+
spec.Telemetry,
248210
} {
249-
for _, id := range extractProviderIDs(raw) {
211+
for _, pc := range slice {
212+
id := pc.ID
213+
if id == "" {
214+
id = deriveProviderID(pc.Provider)
215+
}
250216
ids[id] = true
251217
}
252218
}

0 commit comments

Comments
 (0)