Skip to content

Commit e98d30f

Browse files
committed
feat(ui): Asynchronous VRAM estimates with multi-context support and use of known_usecases
Signed-off-by: Richard Palethorpe <io@richiejp.com>
1 parent c855cd7 commit e98d30f

File tree

13 files changed

+404
-308
lines changed

13 files changed

+404
-308
lines changed

core/application/startup.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/mudler/LocalAI/core/services/jobs"
1818
"github.com/mudler/LocalAI/core/services/nodes"
1919
"github.com/mudler/LocalAI/core/services/storage"
20+
"github.com/mudler/LocalAI/pkg/vram"
2021
coreStartup "github.com/mudler/LocalAI/core/startup"
2122
"github.com/mudler/LocalAI/internal"
2223

@@ -231,6 +232,10 @@ func New(opts ...config.AppOption) (*Application, error) {
231232
xlog.Error("error registering external backends", "error", err)
232233
}
233234

235+
// Wire gallery generation counter into VRAM caches so they invalidate
236+
// when gallery data refreshes instead of using a fixed TTL.
237+
vram.SetGalleryGenerationFunc(gallery.GalleryGeneration)
238+
234239
if options.ConfigFile != "" {
235240
if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil {
236241
xlog.Error("error loading config file", "error", err)

core/gallery/gallery.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,14 @@ var (
301301
availableModelsMu sync.RWMutex
302302
availableModelsCache GalleryElements[*GalleryModel]
303303
refreshing atomic.Bool
304+
galleryGeneration atomic.Uint64
304305
)
305306

307+
// GalleryGeneration returns a counter that increments each time the gallery
308+
// model list is refreshed from upstream. VRAM estimation caches use this to
309+
// invalidate entries when the gallery data changes.
310+
func GalleryGeneration() uint64 { return galleryGeneration.Load() }
311+
306312
// AvailableGalleryModelsCached returns gallery models from an in-memory cache.
307313
// Local-only fields (installed status) are refreshed on every call. A background
308314
// goroutine is triggered to re-fetch the full model list (including network
@@ -335,6 +341,7 @@ func AvailableGalleryModelsCached(galleries []config.Gallery, systemState *syste
335341

336342
availableModelsMu.Lock()
337343
availableModelsCache = models
344+
galleryGeneration.Add(1)
338345
availableModelsMu.Unlock()
339346

340347
return models, nil
@@ -356,6 +363,7 @@ func triggerGalleryRefresh(galleries []config.Gallery, systemState *system.Syste
356363
}
357364
availableModelsMu.Lock()
358365
availableModelsCache = models
366+
galleryGeneration.Add(1)
359367
availableModelsMu.Unlock()
360368
}()
361369
}

core/http/endpoints/localai/import_model.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,17 @@ func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.Appl
5151
}
5252
estCtx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second)
5353
defer cancel()
54-
result, err := vram.EstimateModel(estCtx, vram.ModelEstimateInput{
55-
Files: files,
56-
Options: vram.EstimateOptions{ContextLength: 8192},
57-
})
54+
result, err := vram.EstimateModelMultiContext(estCtx, vram.ModelEstimateInput{
55+
Files: files,
56+
}, []uint32{8192})
5857
if err == nil {
5958
if result.SizeBytes > 0 {
6059
resp.EstimatedSizeBytes = result.SizeBytes
6160
resp.EstimatedSizeDisplay = result.SizeDisplay
6261
}
63-
if result.VRAMBytes > 0 {
64-
resp.EstimatedVRAMBytes = result.VRAMBytes
65-
resp.EstimatedVRAMDisplay = result.VRAMDisplay
62+
if v := result.VRAMForContext(8192); v > 0 {
63+
resp.EstimatedVRAMBytes = v
64+
resp.EstimatedVRAMDisplay = vram.FormatBytes(v)
6665
}
6766
}
6867
}

core/http/endpoints/localai/vram.go

Lines changed: 32 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package localai
22

33
import (
44
"context"
5-
"fmt"
65
"net/http"
76
"path/filepath"
7+
"slices"
88
"strings"
99
"time"
1010

@@ -14,16 +14,10 @@ import (
1414
)
1515

1616
type vramEstimateRequest struct {
17-
Model string `json:"model"` // model name (must be installed)
18-
ContextSize uint32 `json:"context_size,omitempty"` // context length to estimate for (default 8192)
19-
GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all)
20-
KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16)
21-
}
22-
23-
type vramEstimateResponse struct {
24-
vram.EstimateResult
25-
ContextNote string `json:"context_note,omitempty"` // note when context_size was defaulted
26-
ModelMaxContext uint64 `json:"model_max_context,omitempty"` // model's trained maximum context length
17+
Model string `json:"model"` // model name (must be installed)
18+
ContextSizes []uint32 `json:"context_sizes,omitempty"` // context sizes to estimate (default [8192])
19+
GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all)
20+
KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16)
2721
}
2822

2923
// resolveModelURI converts a relative model path to a file:// URI so the
@@ -36,8 +30,8 @@ func resolveModelURI(uri, modelsPath string) string {
3630
return "file://" + filepath.Join(modelsPath, uri)
3731
}
3832

39-
// addWeightFile appends a resolved weight file to files and tracks the first GGUF.
40-
func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *string, seen map[string]bool) {
33+
// addWeightFile appends a resolved weight file to files.
34+
func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, seen map[string]bool) {
4135
if !vram.IsWeightFile(uri) {
4236
return
4337
}
@@ -47,21 +41,17 @@ func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *s
4741
}
4842
seen[resolved] = true
4943
*files = append(*files, vram.FileInput{URI: resolved, Size: 0})
50-
if *firstGGUF == "" && vram.IsGGUF(uri) {
51-
*firstGGUF = resolved
52-
}
5344
}
5445

5546
// VRAMEstimateEndpoint returns a handler that estimates VRAM usage for an
56-
// installed model configuration. For uninstalled models (gallery URLs), use
57-
// the gallery-level estimates in /api/models instead.
47+
// installed model configuration at multiple context sizes.
5848
// @Summary Estimate VRAM usage for a model
59-
// @Description Estimates VRAM based on model weight files, context size, and GPU layers
49+
// @Description Estimates VRAM based on model weight files at multiple context sizes
6050
// @Tags config
6151
// @Accept json
6252
// @Produce json
6353
// @Param request body vramEstimateRequest true "VRAM estimation parameters"
64-
// @Success 200 {object} vramEstimateResponse "VRAM estimate"
54+
// @Success 200 {object} vram.MultiContextEstimate "VRAM estimate"
6555
// @Router /api/models/vram-estimate [post]
6656
func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
6757
return func(c echo.Context) error {
@@ -82,17 +72,16 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
8272
modelsPath := appConfig.SystemState.Model.ModelsPath
8373

8474
var files []vram.FileInput
85-
var firstGGUF string
8675
seen := make(map[string]bool)
8776

8877
for _, f := range modelConfig.DownloadFiles {
89-
addWeightFile(string(f.URI), modelsPath, &files, &firstGGUF, seen)
78+
addWeightFile(string(f.URI), modelsPath, &files, seen)
9079
}
9180
if modelConfig.Model != "" {
92-
addWeightFile(modelConfig.Model, modelsPath, &files, &firstGGUF, seen)
81+
addWeightFile(modelConfig.Model, modelsPath, &files, seen)
9382
}
9483
if modelConfig.MMProj != "" {
95-
addWeightFile(modelConfig.MMProj, modelsPath, &files, &firstGGUF, seen)
84+
addWeightFile(modelConfig.MMProj, modelsPath, &files, seen)
9685
}
9786

9887
if len(files) == 0 {
@@ -101,45 +90,36 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
10190
})
10291
}
10392

104-
contextDefaulted := false
105-
opts := vram.EstimateOptions{
106-
ContextLength: req.ContextSize,
107-
GPULayers: req.GPULayers,
108-
KVQuantBits: req.KVQuantBits,
109-
}
110-
if opts.ContextLength == 0 {
93+
contextSizes := req.ContextSizes
94+
if len(contextSizes) == 0 {
11195
if modelConfig.ContextSize != nil {
112-
opts.ContextLength = uint32(*modelConfig.ContextSize)
96+
contextSizes = []uint32{uint32(*modelConfig.ContextSize)}
11397
} else {
114-
opts.ContextLength = 8192
115-
contextDefaulted = true
98+
contextSizes = []uint32{8192}
99+
}
100+
}
101+
102+
// Include model's configured context size alongside requested sizes
103+
if modelConfig.ContextSize != nil {
104+
modelCtx := uint32(*modelConfig.ContextSize)
105+
if !slices.Contains(contextSizes, modelCtx) {
106+
contextSizes = append(contextSizes, modelCtx)
116107
}
117108
}
118109

110+
opts := vram.EstimateOptions{
111+
GPULayers: req.GPULayers,
112+
KVQuantBits: req.KVQuantBits,
113+
}
114+
119115
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
120116
defer cancel()
121117

122-
result, err := vram.Estimate(ctx, files, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader())
118+
result, err := vram.EstimateMultiContext(ctx, files, contextSizes, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader())
123119
if err != nil {
124120
return c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
125121
}
126122

127-
resp := vramEstimateResponse{EstimateResult: result}
128-
129-
// When context was defaulted to 8192, read the GGUF metadata to report
130-
// the model's trained maximum context length so callers know the estimate
131-
// may be conservative.
132-
if contextDefaulted && firstGGUF != "" {
133-
ggufMeta, err := vram.DefaultCachedGGUFReader().ReadMetadata(ctx, firstGGUF)
134-
if err == nil && ggufMeta != nil && ggufMeta.MaximumContextLength > 0 {
135-
resp.ModelMaxContext = ggufMeta.MaximumContextLength
136-
resp.ContextNote = fmt.Sprintf(
137-
"Estimate used default context_size=8192. The model's trained maximum context is %d; VRAM usage will be higher at larger context sizes.",
138-
ggufMeta.MaximumContextLength,
139-
)
140-
}
141-
}
142-
143-
return c.JSON(http.StatusOK, resp)
123+
return c.JSON(http.StatusOK, result)
144124
}
145125
}

core/http/react-ui/src/pages/Models.jsx

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ function GalleryLoader() {
8686
}
8787

8888

89+
const CONTEXT_SIZES = [8192, 16384, 32768, 65536, 131072, 262144]
90+
const CONTEXT_LABELS = ['8K', '16K', '32K', '64K', '128K', '256K']
91+
8992
const FILTERS = [
9093
{ key: '', label: 'All', icon: 'fa-layer-group' },
9194
{ key: 'chat', label: 'Chat', icon: 'fa-brain' },
@@ -119,6 +122,7 @@ export default function Models() {
119122
const [allBackends, setAllBackends] = useState([])
120123
const [backendUsecases, setBackendUsecases] = useState({})
121124
const [estimates, setEstimates] = useState({})
125+
const [contextSize, setContextSize] = useState(CONTEXT_SIZES[0])
122126
const debounceRef = useRef(null)
123127
const [confirmDialog, setConfirmDialog] = useState(null)
124128

@@ -190,9 +194,9 @@ export default function Models() {
190194
models.forEach(model => {
191195
const id = model.name || model.id
192196
if (estimates[id]) return
193-
modelsApi.estimate(id).then(est => {
197+
modelsApi.estimate(id, CONTEXT_SIZES).then(est => {
194198
if (cancelled) return
195-
if (est && (est.SizeBytes || est.VRAMBytes)) {
199+
if (est && (est.sizeBytes || est.estimates)) {
196200
setEstimates(prev => ({ ...prev, [id]: est }))
197201
}
198202
}).catch(() => {})
@@ -371,6 +375,25 @@ export default function Models() {
371375
)}
372376
</div>
373377

378+
{/* Context size slider for VRAM estimates */}
379+
<div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)', marginBottom: 'var(--spacing-md)', fontSize: '0.8125rem' }}>
380+
<label style={{ color: 'var(--color-text-muted)', whiteSpace: 'nowrap' }}>
381+
<i className="fas fa-memory" style={{ marginRight: 4 }} />
382+
Context:
383+
</label>
384+
<input
385+
type="range"
386+
min={0}
387+
max={CONTEXT_SIZES.length - 1}
388+
value={CONTEXT_SIZES.indexOf(contextSize)}
389+
onChange={(e) => setContextSize(CONTEXT_SIZES[e.target.value])}
390+
style={{ width: 140, accentColor: 'var(--color-primary)' }}
391+
/>
392+
<span style={{ fontWeight: 600, minWidth: '3em' }}>
393+
{CONTEXT_LABELS[CONTEXT_SIZES.indexOf(contextSize)]}
394+
</span>
395+
</div>
396+
374397
{/* Table */}
375398
{loading ? (
376399
<GalleryLoader />
@@ -415,10 +438,11 @@ export default function Models() {
415438
<tbody>
416439
{models.map((model, idx) => {
417440
const name = model.name || model.id
418-
const est = estimates[name] || {}
419-
const sizeDisplay = est.SizeDisplay || model.estimated_size_display
420-
const vramDisplay = est.VRAMDisplay || model.estimated_vram_display
421-
const vramBytes = est.VRAMBytes || model.estimated_vram_bytes
441+
const estData = estimates[name]
442+
const sizeDisplay = estData?.sizeDisplay
443+
const ctxEst = estData?.estimates?.[String(contextSize)]
444+
const vramDisplay = ctxEst?.vramDisplay
445+
const vramBytes = ctxEst?.vramBytes
422446
const installing = isInstalling(name)
423447
const progress = getOperationProgress(name)
424448
const fit = fitsGpu(vramBytes)

core/http/react-ui/src/utils/api.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,10 @@ export const modelsApi = {
7979
listCapabilities: () => fetchJSON(API_CONFIG.endpoints.modelsCapabilities),
8080
install: (id) => postJSON(API_CONFIG.endpoints.installModel(id), {}),
8181
delete: (id) => postJSON(API_CONFIG.endpoints.deleteModel(id), {}),
82-
estimate: (id) => fetchJSON(API_CONFIG.endpoints.modelEstimate(id)),
82+
estimate: (id, contexts) => fetchJSON(
83+
buildUrl(API_CONFIG.endpoints.modelEstimate(id),
84+
contexts?.length ? { contexts: contexts.join(',') } : {})
85+
),
8386
getConfig: (id) => postJSON(API_CONFIG.endpoints.modelConfig(id), {}),
8487
getConfigJson: (name) => fetchJSON(API_CONFIG.endpoints.modelConfigJson(name)),
8588
getJob: (uid) => fetchJSON(API_CONFIG.endpoints.modelJob(uid)),

0 commit comments

Comments (0)