@@ -2,9 +2,9 @@ package localai
22
33import (
44 "context"
5- "fmt"
65 "net/http"
76 "path/filepath"
7+ "slices"
88 "strings"
99 "time"
1010
@@ -14,16 +14,10 @@ import (
1414)
1515
1616type vramEstimateRequest struct {
17- Model string `json:"model"` // model name (must be installed)
18- ContextSize uint32 `json:"context_size,omitempty"` // context length to estimate for (default 8192)
19- GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all)
20- KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16)
21- }
22-
23- type vramEstimateResponse struct {
24- vram.EstimateResult
25- ContextNote string `json:"context_note,omitempty"` // note when context_size was defaulted
26- ModelMaxContext uint64 `json:"model_max_context,omitempty"` // model's trained maximum context length
17+ Model string `json:"model"` // model name (must be installed)
18+ ContextSizes []uint32 `json:"context_sizes,omitempty"` // context sizes to estimate (default [8192])
19+ GPULayers int `json:"gpu_layers,omitempty"` // number of layers to offload to GPU (0 = all)
20+ KVQuantBits int `json:"kv_quant_bits,omitempty"` // KV cache quantization bits (0 = fp16)
2721}
2822
2923// resolveModelURI converts a relative model path to a file:// URI so the
@@ -36,8 +30,8 @@ func resolveModelURI(uri, modelsPath string) string {
3630 return "file://" + filepath .Join (modelsPath , uri )
3731}
3832
39- // addWeightFile appends a resolved weight file to files and tracks the first GGUF .
40- func addWeightFile (uri , modelsPath string , files * []vram.FileInput , firstGGUF * string , seen map [string ]bool ) {
33+ // addWeightFile appends a resolved weight file to files.
34+ func addWeightFile (uri , modelsPath string , files * []vram.FileInput , seen map [string ]bool ) {
4135 if ! vram .IsWeightFile (uri ) {
4236 return
4337 }
@@ -47,21 +41,17 @@ func addWeightFile(uri, modelsPath string, files *[]vram.FileInput, firstGGUF *s
4741 }
4842 seen [resolved ] = true
4943 * files = append (* files , vram.FileInput {URI : resolved , Size : 0 })
50- if * firstGGUF == "" && vram .IsGGUF (uri ) {
51- * firstGGUF = resolved
52- }
5344}
5445
5546// VRAMEstimateEndpoint returns a handler that estimates VRAM usage for an
56- // installed model configuration. For uninstalled models (gallery URLs), use
57- // the gallery-level estimates in /api/models instead.
47+ // installed model configuration at multiple context sizes.
5848// @Summary Estimate VRAM usage for a model
59- // @Description Estimates VRAM based on model weight files, context size, and GPU layers
49+ // @Description Estimates VRAM based on model weight files at multiple context sizes
6050// @Tags config
6151// @Accept json
6252// @Produce json
6353// @Param request body vramEstimateRequest true "VRAM estimation parameters"
64- // @Success 200 {object} vramEstimateResponse "VRAM estimate"
54+ // @Success 200 {object} vram.MultiContextEstimate "VRAM estimate"
6555// @Router /api/models/vram-estimate [post]
6656func VRAMEstimateEndpoint (cl * config.ModelConfigLoader , appConfig * config.ApplicationConfig ) echo.HandlerFunc {
6757 return func (c echo.Context ) error {
@@ -82,17 +72,16 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
8272 modelsPath := appConfig .SystemState .Model .ModelsPath
8373
8474 var files []vram.FileInput
85- var firstGGUF string
8675 seen := make (map [string ]bool )
8776
8877 for _ , f := range modelConfig .DownloadFiles {
89- addWeightFile (string (f .URI ), modelsPath , & files , & firstGGUF , seen )
78+ addWeightFile (string (f .URI ), modelsPath , & files , seen )
9079 }
9180 if modelConfig .Model != "" {
92- addWeightFile (modelConfig .Model , modelsPath , & files , & firstGGUF , seen )
81+ addWeightFile (modelConfig .Model , modelsPath , & files , seen )
9382 }
9483 if modelConfig .MMProj != "" {
95- addWeightFile (modelConfig .MMProj , modelsPath , & files , & firstGGUF , seen )
84+ addWeightFile (modelConfig .MMProj , modelsPath , & files , seen )
9685 }
9786
9887 if len (files ) == 0 {
@@ -101,45 +90,36 @@ func VRAMEstimateEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
10190 })
10291 }
10392
104- contextDefaulted := false
105- opts := vram.EstimateOptions {
106- ContextLength : req .ContextSize ,
107- GPULayers : req .GPULayers ,
108- KVQuantBits : req .KVQuantBits ,
109- }
110- if opts .ContextLength == 0 {
93+ contextSizes := req .ContextSizes
94+ if len (contextSizes ) == 0 {
11195 if modelConfig .ContextSize != nil {
112- opts . ContextLength = uint32 (* modelConfig .ContextSize )
96+ contextSizes = [] uint32 { uint32 (* modelConfig .ContextSize )}
11397 } else {
114- opts .ContextLength = 8192
115- contextDefaulted = true
98+ contextSizes = []uint32 {8192 }
99+ }
100+ }
101+
102+ // Include model's configured context size alongside requested sizes
103+ if modelConfig .ContextSize != nil {
104+ modelCtx := uint32 (* modelConfig .ContextSize )
105+ if ! slices .Contains (contextSizes , modelCtx ) {
106+ contextSizes = append (contextSizes , modelCtx )
116107 }
117108 }
118109
110+ opts := vram.EstimateOptions {
111+ GPULayers : req .GPULayers ,
112+ KVQuantBits : req .KVQuantBits ,
113+ }
114+
119115 ctx , cancel := context .WithTimeout (context .Background (), 10 * time .Second )
120116 defer cancel ()
121117
122- result , err := vram .Estimate (ctx , files , opts , vram .DefaultCachedSizeResolver (), vram .DefaultCachedGGUFReader ())
118+ result , err := vram .EstimateMultiContext (ctx , files , contextSizes , opts , vram .DefaultCachedSizeResolver (), vram .DefaultCachedGGUFReader ())
123119 if err != nil {
124120 return c .JSON (http .StatusInternalServerError , map [string ]any {"error" : err .Error ()})
125121 }
126122
127- resp := vramEstimateResponse {EstimateResult : result }
128-
129- // When context was defaulted to 8192, read the GGUF metadata to report
130- // the model's trained maximum context length so callers know the estimate
131- // may be conservative.
132- if contextDefaulted && firstGGUF != "" {
133- ggufMeta , err := vram .DefaultCachedGGUFReader ().ReadMetadata (ctx , firstGGUF )
134- if err == nil && ggufMeta != nil && ggufMeta .MaximumContextLength > 0 {
135- resp .ModelMaxContext = ggufMeta .MaximumContextLength
136- resp .ContextNote = fmt .Sprintf (
137- "Estimate used default context_size=8192. The model's trained maximum context is %d; VRAM usage will be higher at larger context sizes." ,
138- ggufMeta .MaximumContextLength ,
139- )
140- }
141- }
142-
143- return c .JSON (http .StatusOK , resp )
123+ return c .JSON (http .StatusOK , result )
144124 }
145125}
0 commit comments