Skip to content

Commit 2c425e9

Browse files
authored
feat(loader): enhance single active backend by treating as singleton (#5107)
feat(loader): enhance single active backend by treating as singleton Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent c59975a commit 2c425e9

24 files changed

+92
-71
lines changed

core/application/application.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type Application struct {
1616
func newApplication(appConfig *config.ApplicationConfig) *Application {
1717
return &Application{
1818
backendLoader: config.NewBackendConfigLoader(appConfig.ModelPath),
19-
modelLoader: model.NewModelLoader(appConfig.ModelPath),
19+
modelLoader: model.NewModelLoader(appConfig.ModelPath, appConfig.SingleBackend),
2020
applicationConfig: appConfig,
2121
templatesEvaluator: templates.NewEvaluator(appConfig.ModelPath),
2222
}

core/application/startup.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func New(opts ...config.AppOption) (*Application, error) {
143143
}()
144144
}
145145

146-
if options.LoadToMemory != nil {
146+
if options.LoadToMemory != nil && !options.SingleBackend {
147147
for _, m := range options.LoadToMemory {
148148
cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
149149
if err != nil {

core/backend/embeddings.go

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendCo
1717
if err != nil {
1818
return nil, err
1919
}
20+
defer loader.Close()
2021

2122
var fn func() ([]float32, error)
2223
switch model := inferenceModel.(type) {

core/backend/image.go

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
1616
if err != nil {
1717
return nil, err
1818
}
19+
defer loader.Close()
1920

2021
fn := func() error {
2122
_, err := inferenceModel.GenerateImage(

core/backend/llm.go

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
5353
if err != nil {
5454
return nil, err
5555
}
56+
defer loader.Close()
5657

5758
var protoMessages []*proto.Message
5859
// if we are using the tokenizer template, we need to convert the messages to proto messages

core/backend/options.go

+28-32
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,6 @@ func ModelOptions(c config.BackendConfig, so *config.ApplicationConfig, opts ...
4040
grpcOpts := grpcModelOpts(c)
4141
defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts))
4242

43-
if so.SingleBackend {
44-
defOpts = append(defOpts, model.WithSingleActiveBackend())
45-
}
46-
4743
if so.ParallelBackendRequests {
4844
defOpts = append(defOpts, model.EnableParallelRequests)
4945
}
@@ -121,7 +117,7 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
121117
triggers := make([]*pb.GrammarTrigger, 0)
122118
for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers {
123119
triggers = append(triggers, &pb.GrammarTrigger{
124-
Word: t.Word,
120+
Word: t.Word,
125121
})
126122

127123
}
@@ -161,33 +157,33 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
161157
DisableLogStatus: c.DisableLogStatus,
162158
DType: c.DType,
163159
// LimitMMPerPrompt vLLM
164-
LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt),
165-
LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
166-
LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
167-
MMProj: c.MMProj,
168-
FlashAttention: c.FlashAttention,
169-
CacheTypeKey: c.CacheTypeK,
170-
CacheTypeValue: c.CacheTypeV,
171-
NoKVOffload: c.NoKVOffloading,
172-
YarnExtFactor: c.YarnExtFactor,
173-
YarnAttnFactor: c.YarnAttnFactor,
174-
YarnBetaFast: c.YarnBetaFast,
175-
YarnBetaSlow: c.YarnBetaSlow,
176-
NGQA: c.NGQA,
177-
RMSNormEps: c.RMSNormEps,
178-
MLock: mmlock,
179-
RopeFreqBase: c.RopeFreqBase,
180-
RopeScaling: c.RopeScaling,
181-
Type: c.ModelType,
182-
RopeFreqScale: c.RopeFreqScale,
183-
NUMA: c.NUMA,
184-
Embeddings: embeddings,
185-
LowVRAM: lowVRAM,
186-
NGPULayers: int32(nGPULayers),
187-
MMap: mmap,
188-
MainGPU: c.MainGPU,
189-
Threads: int32(*c.Threads),
190-
TensorSplit: c.TensorSplit,
160+
LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt),
161+
LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt),
162+
LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt),
163+
MMProj: c.MMProj,
164+
FlashAttention: c.FlashAttention,
165+
CacheTypeKey: c.CacheTypeK,
166+
CacheTypeValue: c.CacheTypeV,
167+
NoKVOffload: c.NoKVOffloading,
168+
YarnExtFactor: c.YarnExtFactor,
169+
YarnAttnFactor: c.YarnAttnFactor,
170+
YarnBetaFast: c.YarnBetaFast,
171+
YarnBetaSlow: c.YarnBetaSlow,
172+
NGQA: c.NGQA,
173+
RMSNormEps: c.RMSNormEps,
174+
MLock: mmlock,
175+
RopeFreqBase: c.RopeFreqBase,
176+
RopeScaling: c.RopeScaling,
177+
Type: c.ModelType,
178+
RopeFreqScale: c.RopeFreqScale,
179+
NUMA: c.NUMA,
180+
Embeddings: embeddings,
181+
LowVRAM: lowVRAM,
182+
NGPULayers: int32(nGPULayers),
183+
MMap: mmap,
184+
MainGPU: c.MainGPU,
185+
Threads: int32(*c.Threads),
186+
TensorSplit: c.TensorSplit,
191187
// AutoGPTQ
192188
ModelBaseName: c.AutoGPTQ.ModelBaseName,
193189
Device: c.AutoGPTQ.Device,

core/backend/rerank.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ import (
1212
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
1313
opts := ModelOptions(backendConfig, appConfig)
1414
rerankModel, err := loader.Load(opts...)
15-
1615
if err != nil {
1716
return nil, err
1817
}
18+
defer loader.Close()
1919

2020
if rerankModel == nil {
2121
return nil, fmt.Errorf("could not load rerank model")

core/backend/soundgeneration.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ func SoundGeneration(
2626

2727
opts := ModelOptions(backendConfig, appConfig)
2828
soundGenModel, err := loader.Load(opts...)
29-
3029
if err != nil {
3130
return "", nil, err
3231
}
32+
defer loader.Close()
3333

3434
if soundGenModel == nil {
3535
return "", nil, fmt.Errorf("could not load sound generation model")

core/backend/token_metrics.go

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ func TokenMetrics(
2020
if err != nil {
2121
return nil, err
2222
}
23+
defer loader.Close()
2324

2425
if model == nil {
2526
return nil, fmt.Errorf("could not loadmodel model")

core/backend/tokenize.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac
1414

1515
opts := ModelOptions(backendConfig, appConfig)
1616
inferenceModel, err = loader.Load(opts...)
17-
1817
if err != nil {
1918
return schema.TokenizeResponse{}, err
2019
}
20+
defer loader.Close()
2121

2222
predictOptions := gRPCPredictOpts(backendConfig, loader.ModelPath)
2323
predictOptions.Prompt = s

core/backend/transcript.go

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
2424
if err != nil {
2525
return nil, err
2626
}
27+
defer ml.Close()
2728

2829
if transcriptionModel == nil {
2930
return nil, fmt.Errorf("could not load transcription model")

core/backend/tts.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ func ModelTTS(
2323
) (string, *proto.Result, error) {
2424
opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend))
2525
ttsModel, err := loader.Load(opts...)
26-
2726
if err != nil {
2827
return "", nil, err
2928
}
29+
defer loader.Close()
3030

3131
if ttsModel == nil {
3232
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)

core/backend/vad.go

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ func VAD(request *schema.VADRequest,
1919
if err != nil {
2020
return nil, err
2121
}
22+
defer ml.Close()
23+
2224
req := proto.VADRequest{
2325
Audio: request.Audio,
2426
}

core/cli/soundgeneration.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
7474
AssetsDestination: t.BackendAssetsPath,
7575
ExternalGRPCBackends: externalBackends,
7676
}
77-
ml := model.NewModelLoader(opts.ModelPath)
77+
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
7878

7979
defer func() {
8080
err := ml.StopAllGRPC()

core/cli/transcript.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
3232
}
3333

3434
cl := config.NewBackendConfigLoader(t.ModelsPath)
35-
ml := model.NewModelLoader(opts.ModelPath)
35+
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
3636
if err := cl.LoadBackendConfigsFromPath(t.ModelsPath); err != nil {
3737
return err
3838
}

core/cli/tts.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
4141
AudioDir: outputDir,
4242
AssetsDestination: t.BackendAssetsPath,
4343
}
44-
ml := model.NewModelLoader(opts.ModelPath)
44+
ml := model.NewModelLoader(opts.ModelPath, opts.SingleBackend)
4545

4646
defer func() {
4747
err := ml.StopAllGRPC()

core/http/endpoints/localai/stores.go

+4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
2121
if err != nil {
2222
return err
2323
}
24+
defer sl.Close()
2425

2526
vals := make([][]byte, len(input.Values))
2627
for i, v := range input.Values {
@@ -48,6 +49,7 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
4849
if err != nil {
4950
return err
5051
}
52+
defer sl.Close()
5153

5254
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
5355
return err
@@ -69,6 +71,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
6971
if err != nil {
7072
return err
7173
}
74+
defer sl.Close()
7275

7376
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
7477
if err != nil {
@@ -100,6 +103,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
100103
if err != nil {
101104
return err
102105
}
106+
defer sl.Close()
103107

104108
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
105109
if err != nil {

core/http/endpoints/openai/assistant_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestAssistantEndpoints(t *testing.T) {
4040
cl := &config.BackendConfigLoader{}
4141
//configsDir := "/tmp/localai/configs"
4242
modelPath := "/tmp/localai/model"
43-
var ml = model.NewModelLoader(modelPath)
43+
var ml = model.NewModelLoader(modelPath, false)
4444

4545
appConfig := &config.ApplicationConfig{
4646
ConfigsDir: configsDir,

core/http/routes/localai.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,10 @@ func RegisterLocalAIRoutes(router *fiber.App,
5050
router.Post("/v1/vad", vadChain...)
5151

5252
// Stores
53-
sl := model.NewModelLoader("")
54-
router.Post("/stores/set", localai.StoresSetEndpoint(sl, appConfig))
55-
router.Post("/stores/delete", localai.StoresDeleteEndpoint(sl, appConfig))
56-
router.Post("/stores/get", localai.StoresGetEndpoint(sl, appConfig))
57-
router.Post("/stores/find", localai.StoresFindEndpoint(sl, appConfig))
53+
router.Post("/stores/set", localai.StoresSetEndpoint(ml, appConfig))
54+
router.Post("/stores/delete", localai.StoresDeleteEndpoint(ml, appConfig))
55+
router.Post("/stores/get", localai.StoresGetEndpoint(ml, appConfig))
56+
router.Post("/stores/find", localai.StoresFindEndpoint(ml, appConfig))
5857

5958
if !appConfig.DisableMetrics {
6059
router.Get("/metrics", localai.LocalAIMetricsEndpoint())

pkg/model/initializers.go

+20-1
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,23 @@ func (ml *ModelLoader) stopActiveBackends(modelID string, singleActiveBackend bo
509509
}
510510
}
511511

512+
func (ml *ModelLoader) Close() {
513+
if !ml.singletonMode {
514+
return
515+
}
516+
ml.singletonLock.Unlock()
517+
}
518+
519+
func (ml *ModelLoader) lockBackend() {
520+
if !ml.singletonMode {
521+
return
522+
}
523+
ml.singletonLock.Lock()
524+
}
525+
512526
func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
527+
ml.lockBackend() // grab the singleton lock if needed
528+
513529
o := NewOptions(opts...)
514530

515531
// Return earlier if we have a model already loaded
@@ -520,7 +536,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
520536
return m.GRPC(o.parallelRequests, ml.wd), nil
521537
}
522538

523-
ml.stopActiveBackends(o.modelID, o.singleActiveBackend)
539+
ml.stopActiveBackends(o.modelID, ml.singletonMode)
524540

525541
// if a backend is defined, return the loader directly
526542
if o.backendString != "" {
@@ -533,6 +549,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
533549
// get backends embedded in the binary
534550
autoLoadBackends, err := ml.ListAvailableBackends(o.assetDir)
535551
if err != nil {
552+
ml.Close() // we failed, release the lock
536553
return nil, err
537554
}
538555

@@ -564,5 +581,7 @@ func (ml *ModelLoader) Load(opts ...Option) (grpc.Backend, error) {
564581
}
565582
}
566583

584+
ml.Close() // make sure to release the lock in case of failure
585+
567586
return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
568587
}

pkg/model/loader.go

+10-7
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,19 @@ import (
1818

1919
// TODO: Split ModelLoader and TemplateLoader? Just to keep things more organized. Left together to share a mutex until I look into that. Would split if we seperate directories for .bin/.yaml and .tmpl
2020
type ModelLoader struct {
21-
ModelPath string
22-
mu sync.Mutex
23-
models map[string]*Model
24-
wd *WatchDog
21+
ModelPath string
22+
mu sync.Mutex
23+
singletonLock sync.Mutex
24+
singletonMode bool
25+
models map[string]*Model
26+
wd *WatchDog
2527
}
2628

27-
func NewModelLoader(modelPath string) *ModelLoader {
29+
func NewModelLoader(modelPath string, singleActiveBackend bool) *ModelLoader {
2830
nml := &ModelLoader{
29-
ModelPath: modelPath,
30-
models: make(map[string]*Model),
31+
ModelPath: modelPath,
32+
models: make(map[string]*Model),
33+
singletonMode: singleActiveBackend,
3134
}
3235

3336
return nml

pkg/model/loader_options.go

+3-10
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@ type Options struct {
1717

1818
externalBackends map[string]string
1919

20-
grpcAttempts int
21-
grpcAttemptsDelay int
22-
singleActiveBackend bool
23-
parallelRequests bool
20+
grpcAttempts int
21+
grpcAttemptsDelay int
22+
parallelRequests bool
2423
}
2524

2625
type Option func(*Options)
@@ -88,12 +87,6 @@ func WithContext(ctx context.Context) Option {
8887
}
8988
}
9089

91-
func WithSingleActiveBackend() Option {
92-
return func(o *Options) {
93-
o.singleActiveBackend = true
94-
}
95-
}
96-
9790
func WithModelID(id string) Option {
9891
return func(o *Options) {
9992
o.modelID = id

pkg/model/loader_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ var _ = Describe("ModelLoader", func() {
2121
// Setup the model loader with a test directory
2222
modelPath = "/tmp/test_model_path"
2323
os.Mkdir(modelPath, 0755)
24-
modelLoader = model.NewModelLoader(modelPath)
24+
modelLoader = model.NewModelLoader(modelPath, false)
2525
})
2626

2727
AfterEach(func() {

0 commit comments

Comments
 (0)