Skip to content

Commit 9c530ec

Browse files
committed
Synchronize to gateway when custom model updates
1 parent 358459f commit 9c530ec

File tree

4 files changed

+162
-40
lines changed

4 files changed

+162
-40
lines changed

ai-provider/model-runtime/model-providers/bedrock/bedrock.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ provider_credential_schema:
8787
placeholder:
8888
en_US: A model you have access to (e.g. amazon.titan-text-lite-v1) for validation.
8989
zh_Hans: 为了进行验证,请输入一个您可用的模型名称 (例如:amazon.titan-text-lite-v1)
90+
model_config:
91+
access_configuration_status: true
92+
access_configuration_demo: "{}"
9093
address: https://bedrock-runtime.amazonaws.com
9194
sort: 4
9295
recommend: true

ai-provider/model-runtime/model-providers/groq/groq.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ provider_credential_schema:
3232
en_US: Enter your API Key
3333
- variable: base_url
3434
label:
35-
en_US: https://router.huggingface.co/hf-inference/v1
35+
en_US: https://api.groq.com/openai/v1
3636
type: text-input
3737
required: false
3838
placeholder:

gateway/profession.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ var dynamicResourceMap = map[string]Worker{
7070
Profession: ProfessionAIResource,
7171
Driver: "ai-key",
7272
},
73+
"ai-model": {
74+
Profession: ProfessionAIResource,
75+
Driver: "ai-model",
76+
},
7377
}
7478

7579
type Worker struct {

module/ai-model/iml.go

Lines changed: 154 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
package ai_model
22

33
import (
4+
"context"
5+
"errors"
46
"fmt"
7+
"slices"
8+
"time"
9+
10+
"gorm.io/gorm"
11+
12+
"github.com/APIParkLab/APIPark/service/cluster"
13+
14+
"github.com/APIParkLab/APIPark/gateway"
15+
516
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
617
model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
718
"github.com/APIParkLab/APIPark/service/ai"
819
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
920
ai_model "github.com/APIParkLab/APIPark/service/ai-model"
1021
"github.com/gin-gonic/gin"
1122
"github.com/google/uuid"
12-
"slices"
1323

24+
"github.com/eolinker/eosc/log"
1425
"github.com/eolinker/go-common/store"
1526
)
1627

@@ -22,6 +33,7 @@ type imlProviderModelModule struct {
2233
providerService ai.IProviderService `autowired:""`
2334
aiApiService ai_api.IAPIService `autowired:""`
2435
providerModelService ai_model.IProviderModelService `autowired:""`
36+
clusterService cluster.IClusterService `autowired:""`
2537
transaction store.ITransaction `autowired:""`
2638
}
2739

@@ -50,55 +62,89 @@ func (i *imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider
5062
return fmt.Errorf("ai provider not found")
5163
}
5264
// check provider exist
53-
providerInfo, err := i.providerService.Get(ctx, provider)
65+
_, err := i.providerService.Get(ctx, provider)
5466
if err != nil {
5567
return err
5668
}
57-
if providerInfo == nil {
58-
return fmt.Errorf("provider not found")
59-
}
6069
modelInfo, _ := i.providerModelService.Get(ctx, input.Id)
6170
if modelInfo == nil || modelInfo.Provider != provider {
6271
return fmt.Errorf("model not found")
6372
}
64-
if err := i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
65-
AccessConfiguration: &input.AccessConfiguration,
66-
ModelParameters: &input.ModelParameters,
67-
}); err != nil {
68-
return err
69-
}
70-
// update provider model
71-
iModel, _ := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
72-
p.SetModel(input.Id, iModel)
73+
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
74+
if err = i.providerModelService.Save(ctx, input.Id, &ai_model.Model{
75+
AccessConfiguration: &input.AccessConfiguration,
76+
ModelParameters: &input.ModelParameters,
77+
}); err != nil {
78+
return err
79+
}
80+
81+
// update provider model
82+
iModel, err := model_runtime.NewCustomizeModel(input.Id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
83+
if err != nil {
84+
return err
85+
}
86+
// 判断是否需要发布model
87+
if p.GetModelConfig().AccessConfigurationStatus {
88+
if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
89+
newModel(provider, input.Name, input.AccessConfiguration),
90+
}, true); err != nil {
91+
return err
92+
}
93+
}
94+
95+
p.SetModel(input.Id, iModel)
96+
return nil
97+
})
7398

7499
return nil
75100
}
76101

77102
func (i *imlProviderModelModule) DeleteProviderModel(ctx *gin.Context, provider string, id string) error {
78103
p, has := model_runtime.GetProvider(provider)
104+
if !has {
105+
return fmt.Errorf("ai provider not found")
106+
}
79107
// check provider exist
80-
providerInfo, err := i.providerService.Get(ctx, provider)
108+
_, err := i.providerService.Get(ctx, provider)
81109
if err != nil {
110+
if !errors.Is(err, gorm.ErrRecordNotFound) {
111+
return fmt.Errorf("provider not found")
112+
}
82113
return err
83114
}
84-
if providerInfo == nil || !has {
85-
return fmt.Errorf("provider not found")
86-
}
87115
modelInfo, _ := i.providerModelService.Get(ctx, id)
88116
if modelInfo == nil || modelInfo.Provider != provider {
89117
return fmt.Errorf("model not found")
90118
}
91-
// check model in use
92-
countMapByModel, _ := i.aiApiService.CountMapByModel(ctx, "", map[string]interface{}{"model": id})
93-
if countValue, has := countMapByModel[id]; has && countValue > 0 {
94-
return fmt.Errorf("model in use")
95-
}
96-
if err := i.providerModelService.Delete(ctx, id); err != nil {
97-
return err
98-
}
99-
p.RemoveModel(id)
119+
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
120+
// check model in use
121+
count, err := i.aiApiService.CountByModel(ctx, id)
122+
if err != nil {
123+
return err
124+
}
125+
if count > 0 {
126+
return fmt.Errorf("model in use")
127+
}
128+
if err := i.providerModelService.Delete(ctx, id); err != nil {
129+
return err
130+
}
131+
err = i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
132+
{
133+
BasicItem: &gateway.BasicItem{
134+
ID: fmt.Sprintf("%s#%s", provider, modelInfo.Name),
135+
Resource: "ai-model",
136+
},
137+
Attr: nil,
138+
},
139+
}, false)
140+
if err != nil {
141+
return err
142+
}
143+
144+
p.RemoveModel(id)
145+
return nil
146+
})
100147

101-
return nil
102148
}
103149

104150
func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider string, input *model_dto.Model) (*model_dto.SimpleModel, error) {
@@ -115,21 +161,90 @@ func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider str
115161
return nil, fmt.Errorf("provider model already exist")
116162
}
117163
id := uuid.New().String()
118-
typeValue := "chat"
119-
if err := i.providerModelService.Save(ctx, id, &ai_model.Model{
120-
Name: &input.Name,
121-
Type: &typeValue,
122-
Provider: &provider,
123-
AccessConfiguration: &input.AccessConfiguration,
124-
ModelParameters: &input.ModelParameters,
125-
}); err != nil {
164+
err := i.transaction.Transaction(ctx, func(ctx context.Context) error {
165+
typeValue := "chat"
166+
err := i.providerModelService.Save(ctx, id, &ai_model.Model{
167+
Name: &input.Name,
168+
Type: &typeValue,
169+
Provider: &provider,
170+
AccessConfiguration: &input.AccessConfiguration,
171+
ModelParameters: &input.ModelParameters,
172+
})
173+
if err != nil {
174+
return err
175+
}
176+
// update provider model
177+
iModel, err := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
178+
if err != nil {
179+
return err
180+
}
181+
// 判断是否需要发布model
182+
if p.GetModelConfig().AccessConfigurationStatus {
183+
if err := i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{
184+
newModel(provider, input.Name, input.AccessConfiguration),
185+
}, true); err != nil {
186+
return err
187+
}
188+
}
189+
190+
p.SetModel(id, iModel)
191+
return nil
192+
})
193+
if err != nil {
126194
return nil, err
127195
}
128-
// update provider model
129-
iModel, _ := model_runtime.NewCustomizeModel(id, input.Name, p.Logo(), input.AccessConfiguration, input.ModelParameters)
130-
p.SetModel(id, iModel)
131196
return &model_dto.SimpleModel{
132197
Id: id,
133198
Name: input.Name,
134199
}, nil
135200
}
201+
202+
func newModel(provider string, model string, config string) *gateway.DynamicRelease {
203+
204+
return &gateway.DynamicRelease{
205+
BasicItem: &gateway.BasicItem{
206+
ID: fmt.Sprintf("%s$%s", provider, model),
207+
Description: fmt.Sprintf("auto generated model: %s, provider: %s", model, provider),
208+
Resource: "ai-model",
209+
Version: time.Now().Format("20060102150405"),
210+
MatchLabels: map[string]string{
211+
"module": "ai-model",
212+
},
213+
},
214+
Attr: map[string]interface{}{
215+
"provider": provider,
216+
"model": model,
217+
"access_config": config,
218+
},
219+
}
220+
}
221+
222+
func (i *imlProviderModelModule) syncGateway(ctx context.Context, clusterId string, releases []*gateway.DynamicRelease, online bool) error {
223+
client, err := i.clusterService.GatewayClient(ctx, clusterId)
224+
if err != nil {
225+
log.Errorf("get apinto client error: %v", err)
226+
return nil
227+
}
228+
defer func() {
229+
err := client.Close(ctx)
230+
if err != nil {
231+
log.Warn("close apinto client:", err)
232+
}
233+
}()
234+
for _, releaseInfo := range releases {
235+
dynamicClient, err := client.Dynamic(releaseInfo.Resource)
236+
if err != nil {
237+
return err
238+
}
239+
if online {
240+
err = dynamicClient.Online(ctx, releaseInfo)
241+
} else {
242+
err = dynamicClient.Offline(ctx, releaseInfo)
243+
}
244+
if err != nil {
245+
return err
246+
}
247+
}
248+
249+
return nil
250+
}

0 commit comments

Comments
 (0)