1
1
package ai_model
2
2
3
3
import (
4
+ "context"
5
+ "errors"
4
6
"fmt"
7
+ "slices"
8
+ "time"
9
+
10
+ "gorm.io/gorm"
11
+
12
+ "github.com/APIParkLab/APIPark/service/cluster"
13
+
14
+ "github.com/APIParkLab/APIPark/gateway"
15
+
5
16
model_runtime "github.com/APIParkLab/APIPark/ai-provider/model-runtime"
6
17
model_dto "github.com/APIParkLab/APIPark/module/ai-model/dto"
7
18
"github.com/APIParkLab/APIPark/service/ai"
8
19
ai_api "github.com/APIParkLab/APIPark/service/ai-api"
9
20
ai_model "github.com/APIParkLab/APIPark/service/ai-model"
10
21
"github.com/gin-gonic/gin"
11
22
"github.com/google/uuid"
12
- "slices"
13
23
24
+ "github.com/eolinker/eosc/log"
14
25
"github.com/eolinker/go-common/store"
15
26
)
16
27
@@ -22,6 +33,7 @@ type imlProviderModelModule struct {
22
33
providerService ai.IProviderService `autowired:""`
23
34
aiApiService ai_api.IAPIService `autowired:""`
24
35
providerModelService ai_model.IProviderModelService `autowired:""`
36
+ clusterService cluster.IClusterService `autowired:""`
25
37
transaction store.ITransaction `autowired:""`
26
38
}
27
39
@@ -50,55 +62,89 @@ func (i *imlProviderModelModule) UpdateProviderModel(ctx *gin.Context, provider
50
62
return fmt .Errorf ("ai provider not found" )
51
63
}
52
64
// check provider exist
53
- providerInfo , err := i .providerService .Get (ctx , provider )
65
+ _ , err := i .providerService .Get (ctx , provider )
54
66
if err != nil {
55
67
return err
56
68
}
57
- if providerInfo == nil {
58
- return fmt .Errorf ("provider not found" )
59
- }
60
69
modelInfo , _ := i .providerModelService .Get (ctx , input .Id )
61
70
if modelInfo == nil || modelInfo .Provider != provider {
62
71
return fmt .Errorf ("model not found" )
63
72
}
64
- if err := i .providerModelService .Save (ctx , input .Id , & ai_model.Model {
65
- AccessConfiguration : & input .AccessConfiguration ,
66
- ModelParameters : & input .ModelParameters ,
67
- }); err != nil {
68
- return err
69
- }
70
- // update provider model
71
- iModel , _ := model_runtime .NewCustomizeModel (input .Id , input .Name , p .Logo (), input .AccessConfiguration , input .ModelParameters )
72
- p .SetModel (input .Id , iModel )
73
+ return i .transaction .Transaction (ctx , func (ctx context.Context ) error {
74
+ if err = i .providerModelService .Save (ctx , input .Id , & ai_model.Model {
75
+ AccessConfiguration : & input .AccessConfiguration ,
76
+ ModelParameters : & input .ModelParameters ,
77
+ }); err != nil {
78
+ return err
79
+ }
80
+
81
+ // update provider model
82
+ iModel , err := model_runtime .NewCustomizeModel (input .Id , input .Name , p .Logo (), input .AccessConfiguration , input .ModelParameters )
83
+ if err != nil {
84
+ return err
85
+ }
86
+ // 判断是否需要发布model
87
+ if p .GetModelConfig ().AccessConfigurationStatus {
88
+ if err := i .syncGateway (ctx , cluster .DefaultClusterID , []* gateway.DynamicRelease {
89
+ newModel (provider , input .Name , input .AccessConfiguration ),
90
+ }, true ); err != nil {
91
+ return err
92
+ }
93
+ }
94
+
95
+ p .SetModel (input .Id , iModel )
96
+ return nil
97
+ })
73
98
74
99
return nil
75
100
}
76
101
77
102
func (i * imlProviderModelModule ) DeleteProviderModel (ctx * gin.Context , provider string , id string ) error {
78
103
p , has := model_runtime .GetProvider (provider )
104
+ if ! has {
105
+ return fmt .Errorf ("ai provider not found" )
106
+ }
79
107
// check provider exist
80
- providerInfo , err := i .providerService .Get (ctx , provider )
108
+ _ , err := i .providerService .Get (ctx , provider )
81
109
if err != nil {
110
+ if ! errors .Is (err , gorm .ErrRecordNotFound ) {
111
+ return fmt .Errorf ("provider not found" )
112
+ }
82
113
return err
83
114
}
84
- if providerInfo == nil || ! has {
85
- return fmt .Errorf ("provider not found" )
86
- }
87
115
modelInfo , _ := i .providerModelService .Get (ctx , id )
88
116
if modelInfo == nil || modelInfo .Provider != provider {
89
117
return fmt .Errorf ("model not found" )
90
118
}
91
- // check model in use
92
- countMapByModel , _ := i .aiApiService .CountMapByModel (ctx , "" , map [string ]interface {}{"model" : id })
93
- if countValue , has := countMapByModel [id ]; has && countValue > 0 {
94
- return fmt .Errorf ("model in use" )
95
- }
96
- if err := i .providerModelService .Delete (ctx , id ); err != nil {
97
- return err
98
- }
99
- p .RemoveModel (id )
119
+ return i .transaction .Transaction (ctx , func (ctx context.Context ) error {
120
+ // check model in use
121
+ count , err := i .aiApiService .CountByModel (ctx , id )
122
+ if err != nil {
123
+ return err
124
+ }
125
+ if count > 0 {
126
+ return fmt .Errorf ("model in use" )
127
+ }
128
+ if err := i .providerModelService .Delete (ctx , id ); err != nil {
129
+ return err
130
+ }
131
+ err = i .syncGateway (ctx , cluster .DefaultClusterID , []* gateway.DynamicRelease {
132
+ {
133
+ BasicItem : & gateway.BasicItem {
134
+ ID : fmt .Sprintf ("%s#%s" , provider , modelInfo .Name ),
135
+ Resource : "ai-model" ,
136
+ },
137
+ Attr : nil ,
138
+ },
139
+ }, false )
140
+ if err != nil {
141
+ return err
142
+ }
143
+
144
+ p .RemoveModel (id )
145
+ return nil
146
+ })
100
147
101
- return nil
102
148
}
103
149
104
150
func (i * imlProviderModelModule ) AddProviderModel (ctx * gin.Context , provider string , input * model_dto.Model ) (* model_dto.SimpleModel , error ) {
@@ -115,21 +161,90 @@ func (i *imlProviderModelModule) AddProviderModel(ctx *gin.Context, provider str
115
161
return nil , fmt .Errorf ("provider model already exist" )
116
162
}
117
163
id := uuid .New ().String ()
118
- typeValue := "chat"
119
- if err := i .providerModelService .Save (ctx , id , & ai_model.Model {
120
- Name : & input .Name ,
121
- Type : & typeValue ,
122
- Provider : & provider ,
123
- AccessConfiguration : & input .AccessConfiguration ,
124
- ModelParameters : & input .ModelParameters ,
125
- }); err != nil {
164
+ err := i .transaction .Transaction (ctx , func (ctx context.Context ) error {
165
+ typeValue := "chat"
166
+ err := i .providerModelService .Save (ctx , id , & ai_model.Model {
167
+ Name : & input .Name ,
168
+ Type : & typeValue ,
169
+ Provider : & provider ,
170
+ AccessConfiguration : & input .AccessConfiguration ,
171
+ ModelParameters : & input .ModelParameters ,
172
+ })
173
+ if err != nil {
174
+ return err
175
+ }
176
+ // update provider model
177
+ iModel , err := model_runtime .NewCustomizeModel (id , input .Name , p .Logo (), input .AccessConfiguration , input .ModelParameters )
178
+ if err != nil {
179
+ return err
180
+ }
181
+ // 判断是否需要发布model
182
+ if p .GetModelConfig ().AccessConfigurationStatus {
183
+ if err := i .syncGateway (ctx , cluster .DefaultClusterID , []* gateway.DynamicRelease {
184
+ newModel (provider , input .Name , input .AccessConfiguration ),
185
+ }, true ); err != nil {
186
+ return err
187
+ }
188
+ }
189
+
190
+ p .SetModel (id , iModel )
191
+ return nil
192
+ })
193
+ if err != nil {
126
194
return nil , err
127
195
}
128
- // update provider model
129
- iModel , _ := model_runtime .NewCustomizeModel (id , input .Name , p .Logo (), input .AccessConfiguration , input .ModelParameters )
130
- p .SetModel (id , iModel )
131
196
return & model_dto.SimpleModel {
132
197
Id : id ,
133
198
Name : input .Name ,
134
199
}, nil
135
200
}
201
+
202
+ func newModel (provider string , model string , config string ) * gateway.DynamicRelease {
203
+
204
+ return & gateway.DynamicRelease {
205
+ BasicItem : & gateway.BasicItem {
206
+ ID : fmt .Sprintf ("%s$%s" , provider , model ),
207
+ Description : fmt .Sprintf ("auto generated model: %s, provider: %s" , model , provider ),
208
+ Resource : "ai-model" ,
209
+ Version : time .Now ().Format ("20060102150405" ),
210
+ MatchLabels : map [string ]string {
211
+ "module" : "ai-model" ,
212
+ },
213
+ },
214
+ Attr : map [string ]interface {}{
215
+ "provider" : provider ,
216
+ "model" : model ,
217
+ "access_config" : config ,
218
+ },
219
+ }
220
+ }
221
+
222
+ func (i * imlProviderModelModule ) syncGateway (ctx context.Context , clusterId string , releases []* gateway.DynamicRelease , online bool ) error {
223
+ client , err := i .clusterService .GatewayClient (ctx , clusterId )
224
+ if err != nil {
225
+ log .Errorf ("get apinto client error: %v" , err )
226
+ return nil
227
+ }
228
+ defer func () {
229
+ err := client .Close (ctx )
230
+ if err != nil {
231
+ log .Warn ("close apinto client:" , err )
232
+ }
233
+ }()
234
+ for _ , releaseInfo := range releases {
235
+ dynamicClient , err := client .Dynamic (releaseInfo .Resource )
236
+ if err != nil {
237
+ return err
238
+ }
239
+ if online {
240
+ err = dynamicClient .Online (ctx , releaseInfo )
241
+ } else {
242
+ err = dynamicClient .Offline (ctx , releaseInfo )
243
+ }
244
+ if err != nil {
245
+ return err
246
+ }
247
+ }
248
+
249
+ return nil
250
+ }
0 commit comments