
Commit cc9aa9e

feat: add /models/apply endpoint to prepare models (#286)

1 parent 5617e50

23 files changed: +556 -33 lines

Makefile (+3 -3)

@@ -211,11 +211,11 @@ test-models/testmodel:
 	wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
 	wget https://huggingface.co/imxcstar/rwkv-4-raven-ggml/resolve/main/RWKV-4-Raven-1B5-v11-Eng99%25-Other1%25-20230425-ctx4096-16_Q4_2.bin -O test-models/rwkv
 	wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json
-	cp tests/fixtures/* test-models
+	cp tests/models_fixtures/* test-models
 
 test: prepare test-models/testmodel
-	cp tests/fixtures/* test-models
-	@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api
+	cp tests/models_fixtures/* test-models
+	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api ./pkg
 
 ## Help:
 help: ## Show this help.

README.md (+5 -2)

@@ -11,11 +11,12 @@
 
 **LocalAI** is a drop-in replacement REST API compatible with OpenAI API specifications for local inferencing. It allows to run models locally or on-prem with consumer grade hardware, supporting multiple models families compatible with the `ggml` format. For a list of the supported model families, see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table).
 
-- OpenAI drop-in alternative REST API
+- Local, OpenAI drop-in alternative REST API. You own your data.
 - Supports multiple models, Audio transcription, Text generation with GPTs, Image generation with stable diffusion (experimental)
 - Once loaded the first time, it keep models loaded in memory for faster inference
 - Support for prompt templates
 - Doesn't shell-out, but uses C++ bindings for a faster inference and better performance.
+- NO GPU required. NO Internet access is required either. Optional, GPU Acceleration is available in `llama.cpp`-compatible LLMs. [See building instructions](https://github.com/go-skynet/LocalAI#cublas).
 
 LocalAI is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).
 
@@ -434,7 +435,7 @@ local-ai --models-path <model_path> [--address <address>] [--threads <num_thread
 | debug | DEBUG | false | Enable debug mode. |
 | config-file | CONFIG_FILE | empty | Path to a LocalAI config file. |
 | upload_limit | UPLOAD_LIMIT | 5MB | Upload limit for whisper. |
-| image-dir | CONFIG_FILE | empty | Image directory to store and serve processed images. |
+| image-path | IMAGE_PATH | empty | Image directory to store and serve processed images. |
 
 </details>
 
@@ -567,6 +568,8 @@ Note: CuBLAS support is experimental, and has not been tested on real HW. please
 make BUILD_TYPE=cublas build
 ```
 
+More informations available in the upstream PR: https://github.com/ggerganov/llama.cpp/pull/1412
+
 </details>
 
 ### Windows compatibility

api/api.go (+20 -7)

@@ -1,6 +1,7 @@
 package api
 
 import (
+	"context"
 	"errors"
 
 	model "github.com/go-skynet/LocalAI/pkg/model"
@@ -12,7 +13,7 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
-func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
+func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
 	zerolog.SetGlobalLevel(zerolog.InfoLevel)
 	if debug {
 		zerolog.SetGlobalLevel(zerolog.DebugLevel)
@@ -48,7 +49,7 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
 		}))
 	}
 
-	cm := make(ConfigMerger)
+	cm := NewConfigMerger()
 	if err := cm.LoadConfigs(loader.ModelPath); err != nil {
 		log.Error().Msgf("error loading config files: %s", err.Error())
 	}
@@ -60,39 +61,51 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
 	}
 
 	if debug {
-		for k, v := range cm {
-			log.Debug().Msgf("Model: %s (config: %+v)", k, v)
+		for _, v := range cm.ListConfigs() {
+			cfg, _ := cm.GetConfig(v)
+			log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
 		}
 	}
 	// Default middleware config
 	app.Use(recover.New())
 	app.Use(cors.New())
 
+	// LocalAI API endpoints
+	applier := newGalleryApplier(loader.ModelPath)
+	applier.start(c, cm)
+	app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
+	app.Get("/models/jobs/:uid", getOpStatus(applier))
+
 	// openAI compatible API endpoint
+
+	// chat
 	app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
 	app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
+	// edit
 	app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
 	app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
+	// completion
 	app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
 	app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
+	// embeddings
 	app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
 	app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
-
-	// /v1/engines/{engine_id}/embeddings
-
 	app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
+	// audio
 	app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))
 
+	// images
 	app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))
 
 	if imageDir != "" {
 		app.Static("/generated-images", imageDir)
 	}
 
+	// models
 	app.Get("/v1/models", listModels(loader, cm))
 	app.Get("/models", listModels(loader, cm))
 
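The two new routes are the heart of this commit: POST /models/apply hands a model-preparation job to the gallery applier over its channel (applier.C) and returns immediately, while GET /models/jobs/:uid reports the job's progress. The handler bodies (newGalleryApplier, applyModelGallery, getOpStatus) are not part of this diff, so the JSON field names, the gallery URL, and the default :8080 address in the sketch below are assumptions for illustration only; a client would apply a model and poll roughly like this:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// Illustrative payloads: the applyModelGallery/getOpStatus handler bodies
// are not shown in this diff, so these field names are assumptions.
type applyRequest struct {
	URL  string `json:"url"`
	Name string `json:"name"`
}

type applyResponse struct {
	UUID string `json:"uuid"`
}

type jobStatus struct {
	Processed bool   `json:"processed"`
	Message   string `json:"message"`
}

func main() {
	body, _ := json.Marshal(applyRequest{
		URL:  "github:go-skynet/model-gallery/gpt4all-j.yaml", // hypothetical gallery entry
		Name: "gpt4all-j",
	})
	resp, err := http.Post("http://127.0.0.1:8080/models/apply", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var ar applyResponse
	if err := json.NewDecoder(resp.Body).Decode(&ar); err != nil {
		panic(err)
	}

	// Poll the job endpoint until the model has been prepared.
	for {
		r, err := http.Get("http://127.0.0.1:8080/models/jobs/" + ar.UUID)
		if err != nil {
			panic(err)
		}
		var st jobStatus
		json.NewDecoder(r.Body).Decode(&st)
		r.Body.Close()
		if st.Processed {
			fmt.Println("model ready:", st.Message)
			return
		}
		time.Sleep(time.Second)
	}
}
```

Returning a job id instead of blocking keeps the HTTP request cheap even when preparing a model means a long download; the goroutine started by applier.start(c, cm) does the actual work in the background.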
api/api_test.go (+10 -3)

@@ -22,10 +22,14 @@ var _ = Describe("API test", func() {
 	var modelLoader *model.ModelLoader
 	var client *openai.Client
 	var client2 *openaigo.Client
+	var c context.Context
+	var cancel context.CancelFunc
 	Context("API query", func() {
 		BeforeEach(func() {
 			modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
-			app = App("", modelLoader, 15, 1, 512, false, true, true, "")
+			c, cancel = context.WithCancel(context.Background())
+
+			app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "")
 			go app.Listen("127.0.0.1:9090")
 
 			defaultConfig := openai.DefaultConfig("")
@@ -42,6 +46,7 @@ var _ = Describe("API test", func() {
 			}, "2m").ShouldNot(HaveOccurred())
 		})
 		AfterEach(func() {
+			cancel()
 			app.Shutdown()
 		})
 		It("returns the models list", func() {
@@ -140,7 +145,9 @@ var _ = Describe("API test", func() {
 	Context("Config file", func() {
 		BeforeEach(func() {
 			modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
-			app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
+			c, cancel = context.WithCancel(context.Background())
+
+			app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "")
 			go app.Listen("127.0.0.1:9090")
 
 			defaultConfig := openai.DefaultConfig("")
@@ -155,10 +162,10 @@ var _ = Describe("API test", func() {
 			}, "2m").ShouldNot(HaveOccurred())
 		})
 		AfterEach(func() {
+			cancel()
 			app.Shutdown()
 		})
 		It("can generate chat completions from config file", func() {
-
 			models, err := client.ListModels(context.TODO())
 			Expect(err).ToNot(HaveOccurred())
 			Expect(len(models.Models)).To(Equal(12))
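Both test contexts now build the app with a cancellable context and call cancel() in AfterEach, so the background applier goroutine started by applier.start(c, cm) shuts down with each suite run. A minimal, stand-alone sketch of that worker lifecycle (not LocalAI code, names are illustrative):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// startWorker mimics applier.start: it consumes jobs from a channel in a
// background goroutine and exits when the context is cancelled.
func startWorker(ctx context.Context, jobs <-chan string) {
	go func() {
		for {
			select {
			case <-ctx.Done():
				fmt.Println("worker: context cancelled, exiting")
				return
			case j := <-jobs:
				fmt.Println("worker: processing", j)
			}
		}
	}()
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	jobs := make(chan string, 1)
	startWorker(ctx, jobs)

	jobs <- "apply-model"
	time.Sleep(100 * time.Millisecond) // let the worker pick up the job

	cancel() // the equivalent of the tests' AfterEach cancel()
	time.Sleep(100 * time.Millisecond)
}
```

This is the usual Go pattern for tying a goroutine's lifetime to its owner: the context flows from the caller into App, and cancelling it is enough to stop the applier without extra plumbing.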

api/config.go (+38 -6)

@@ -7,6 +7,7 @@ import (
 	"os"
 	"path/filepath"
 	"strings"
+	"sync"
 
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/gofiber/fiber/v2"
@@ -43,8 +44,16 @@ type TemplateConfig struct {
 	Edit string `yaml:"edit"`
 }
 
-type ConfigMerger map[string]Config
+type ConfigMerger struct {
+	configs map[string]Config
+	sync.Mutex
+}
 
+func NewConfigMerger() *ConfigMerger {
+	return &ConfigMerger{
+		configs: make(map[string]Config),
+	}
+}
 func ReadConfigFile(file string) ([]*Config, error) {
 	c := &[]*Config{}
 	f, err := os.ReadFile(file)
@@ -72,28 +81,51 @@ func ReadConfig(file string) (*Config, error) {
 }
 
 func (cm ConfigMerger) LoadConfigFile(file string) error {
+	cm.Lock()
+	defer cm.Unlock()
 	c, err := ReadConfigFile(file)
 	if err != nil {
 		return fmt.Errorf("cannot load config file: %w", err)
 	}
 
 	for _, cc := range c {
-		cm[cc.Name] = *cc
+		cm.configs[cc.Name] = *cc
 	}
 	return nil
 }
 
 func (cm ConfigMerger) LoadConfig(file string) error {
+	cm.Lock()
+	defer cm.Unlock()
 	c, err := ReadConfig(file)
 	if err != nil {
 		return fmt.Errorf("cannot read config file: %w", err)
 	}
 
-	cm[c.Name] = *c
+	cm.configs[c.Name] = *c
 	return nil
 }
 
+func (cm ConfigMerger) GetConfig(m string) (Config, bool) {
+	cm.Lock()
+	defer cm.Unlock()
+	v, exists := cm.configs[m]
+	return v, exists
+}
+
+func (cm ConfigMerger) ListConfigs() []string {
+	cm.Lock()
+	defer cm.Unlock()
+	var res []string
+	for k := range cm.configs {
+		res = append(res, k)
+	}
+	return res
+}
+
 func (cm ConfigMerger) LoadConfigs(path string) error {
+	cm.Lock()
+	defer cm.Unlock()
 	files, err := ioutil.ReadDir(path)
 	if err != nil {
 		return err
@@ -106,7 +138,7 @@ func (cm ConfigMerger) LoadConfigs(path string) error {
 		}
 		c, err := ReadConfig(filepath.Join(path, file.Name()))
 		if err == nil {
-			cm[c.Name] = *c
+			cm.configs[c.Name] = *c
 		}
 	}
 
@@ -253,7 +285,7 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin
 	return modelFile, input, nil
 }
 
-func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
+func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
 	// Load a config file if present after the model name
 	modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
 	if _, err := os.Stat(modelConfig); err == nil {
@@ -263,7 +295,7 @@ func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader
 	}
 
 	var config *Config
-	cfg, exists := cm.GetConfig(modelFile)
+	cfg, exists := cm.GetConfig(modelFile)
 	if !exists {
 		config = &Config{
 			OpenAIRequest: defaultRequest(modelFile),
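Turning ConfigMerger from a bare map[string]Config into a struct with an embedded sync.Mutex is what makes the new endpoint safe: plain Go maps must not be read and written concurrently, and the applier goroutine now updates configs while HTTP handlers read them. A stand-alone sketch of the same guarded-map pattern (illustrative names, not LocalAI code; pointer receivers are used here so the embedded mutex is never copied):

```go
package main

import (
	"fmt"
	"sync"
)

type Config struct {
	Name string
}

// ConfigStore mirrors the ConfigMerger pattern: an embedded mutex guards a
// plain map, and every accessor locks around the map operation.
type ConfigStore struct {
	sync.Mutex
	configs map[string]Config
}

func NewConfigStore() *ConfigStore {
	return &ConfigStore{configs: make(map[string]Config)}
}

func (s *ConfigStore) Set(c Config) {
	s.Lock()
	defer s.Unlock()
	s.configs[c.Name] = c
}

func (s *ConfigStore) Get(name string) (Config, bool) {
	s.Lock()
	defer s.Unlock()
	c, ok := s.configs[name]
	return c, ok
}

func main() {
	store := NewConfigStore()
	var wg sync.WaitGroup
	// Concurrent writers (like the gallery applier) and readers (handlers).
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			store.Set(Config{Name: fmt.Sprintf("model-%d", i)})
			_, _ = store.Get("model-0")
		}(i)
	}
	wg.Wait()
	fmt.Println("configs stored without data races")
}
```

Callers go through accessors such as GetConfig and ListConfigs, so no map access escapes the lock.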
