diff --git a/api/types.go b/api/types.go index a38b335b7..b4a65fe5c 100644 --- a/api/types.go +++ b/api/types.go @@ -12,6 +12,7 @@ import ( "time" "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/types/model" ) // StatusError is an error with an HTTP status code and message. @@ -340,17 +341,18 @@ type ShowRequest struct { // ShowResponse is the response returned from [Client.Show]. type ShowResponse struct { - License string `json:"license,omitempty"` - Modelfile string `json:"modelfile,omitempty"` - Parameters string `json:"parameters,omitempty"` - Template string `json:"template,omitempty"` - System string `json:"system,omitempty"` - Details ModelDetails `json:"details,omitempty"` - Messages []Message `json:"messages,omitempty"` - ModelInfo map[string]any `json:"model_info,omitempty"` - ProjectorInfo map[string]any `json:"projector_info,omitempty"` - Tensors []Tensor `json:"tensors,omitempty"` - ModifiedAt time.Time `json:"modified_at,omitempty"` + License string `json:"license,omitempty"` + Modelfile string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + System string `json:"system,omitempty"` + Details ModelDetails `json:"details,omitempty"` + Messages []Message `json:"messages,omitempty"` + ModelInfo map[string]any `json:"model_info,omitempty"` + ProjectorInfo map[string]any `json:"projector_info,omitempty"` + Tensors []Tensor `json:"tensors,omitempty"` + Capabilities []model.Capability `json:"capabilities,omitempty"` + ModifiedAt time.Time `json:"modified_at,omitempty"` } // CopyRequest is the request passed to [Client.Copy]. 
diff --git a/cmd/cmd.go b/cmd/cmd.go index abb4806b5..36d7e6cfc 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -18,6 +18,7 @@ import ( "os/signal" "path/filepath" "runtime" + "slices" "sort" "strconv" "strings" @@ -339,6 +340,11 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } + opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) + + // TODO: remove the projector info and vision info checks below, + // these are left in for backwards compatibility with older servers + // that don't have the capabilities field in the model info if len(info.ProjectorInfo) != 0 { opts.MultiModal = true } @@ -669,6 +675,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { return }) + if len(resp.Capabilities) > 0 { + tableRender("Capabilities", func() (rows [][]string) { + for _, capability := range resp.Capabilities { + rows = append(rows, []string{"", capability.String()}) + } + return + }) + } + if resp.ProjectorInfo != nil { tableRender("Projector", func() (rows [][]string) { arch := resp.ProjectorInfo["general.architecture"].(string) diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index ea3bdffe8..e6a542d09 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -16,6 +16,7 @@ import ( "github.com/spf13/cobra" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/types/model" ) func TestShowInfo(t *testing.T) { @@ -260,6 +261,34 @@ Weigh anchor! 
t.Errorf("unexpected output (-want +got):\n%s", diff) } }) + + t.Run("capabilities", func(t *testing.T) { + var b bytes.Buffer + if err := showInfo(&api.ShowResponse{ + Details: api.ModelDetails{ + Family: "test", + ParameterSize: "7B", + QuantizationLevel: "FP16", + }, + Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools}, + }, false, &b); err != nil { + t.Fatal(err) + } + + expect := " Model\n" + + " architecture test \n" + + " parameters 7B \n" + + " quantization FP16 \n" + + "\n" + + " Capabilities\n" + + " vision \n" + + " tools \n" + + "\n" + + if diff := cmp.Diff(expect, b.String()); diff != "" { + t.Errorf("unexpected output (-want +got):\n%s", diff) + } + }) } func TestDeleteHandler(t *testing.T) { diff --git a/docs/api.md b/docs/api.md index fe044d794..04ee299d2 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1217,7 +1217,7 @@ Show information about a model including details, modelfile, template, parameter ```shell curl http://localhost:11434/api/show -d '{ - "model": "llama3.2" + "model": "llava" }' ``` @@ -1260,7 +1260,11 @@ curl http://localhost:11434/api/show -d '{ "tokenizer.ggml.pre": "llama-bpe", "tokenizer.ggml.token_type": [], // populates if `verbose=true` "tokenizer.ggml.tokens": [] // populates if `verbose=true` - } + }, + "capabilities": [ + "completion", + "vision" + ] } ``` diff --git a/server/images.go b/server/images.go index 290e68ba9..2ef9e5d02 100644 --- a/server/images.go +++ b/server/images.go @@ -35,17 +35,11 @@ var ( errCapabilityCompletion = errors.New("completion") errCapabilityTools = errors.New("tools") errCapabilityInsert = errors.New("insert") + errCapabilityVision = errors.New("vision") + errCapabilityEmbedding = errors.New("embedding") errInsecureProtocol = errors.New("insecure protocol http") ) -type Capability string - -const ( - CapabilityCompletion = Capability("completion") - CapabilityTools = Capability("tools") - CapabilityInsert = Capability("insert") -) - type registryOptions struct {
Insecure bool Username string @@ -72,46 +66,77 @@ type Model struct { Template *template.Template } +// Capabilities returns the capabilities that the model supports +func (m *Model) Capabilities() []model.Capability { + capabilities := []model.Capability{} + + // Check for completion capability + r, err := os.Open(m.ModelPath) + if err == nil { + defer r.Close() + + f, _, err := ggml.Decode(r, 0) + if err == nil { + if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { + capabilities = append(capabilities, model.CapabilityEmbedding) + } else { + capabilities = append(capabilities, model.CapabilityCompletion) + } + if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok { + capabilities = append(capabilities, model.CapabilityVision) + } + } else { + slog.Error("couldn't decode ggml", "error", err) + } + } else { + slog.Error("couldn't open model file", "error", err) + } + + if m.Template == nil { + return capabilities + } + + // Check for tools capability + if slices.Contains(m.Template.Vars(), "tools") { + capabilities = append(capabilities, model.CapabilityTools) + } + + // Check for insert capability + if slices.Contains(m.Template.Vars(), "suffix") { + capabilities = append(capabilities, model.CapabilityInsert) + } + + return capabilities +} + // CheckCapabilities checks if the model has the specified capabilities returning an error describing // any missing or unknown capabilities -func (m *Model) CheckCapabilities(caps ...Capability) error { +func (m *Model) CheckCapabilities(want ...model.Capability) error { + available := m.Capabilities() var errs []error - for _, cap := range caps { - switch cap { - case CapabilityCompletion: - r, err := os.Open(m.ModelPath) - if err != nil { - slog.Error("couldn't open model file", "error", err) - continue - } - defer r.Close() - // TODO(mxyng): decode the GGML into model to avoid doing this multiple times - f, _, err := ggml.Decode(r, 0) - if err != nil { - 
slog.Error("couldn't decode ggml", "error", err) - continue - } + // Map capabilities to their corresponding error + capToErr := map[model.Capability]error{ + model.CapabilityCompletion: errCapabilityCompletion, + model.CapabilityTools: errCapabilityTools, + model.CapabilityInsert: errCapabilityInsert, + model.CapabilityVision: errCapabilityVision, + model.CapabilityEmbedding: errCapabilityEmbedding, + } - if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { - errs = append(errs, errCapabilityCompletion) - } - case CapabilityTools: - if !slices.Contains(m.Template.Vars(), "tools") { - errs = append(errs, errCapabilityTools) - } - case CapabilityInsert: - vars := m.Template.Vars() - if !slices.Contains(vars, "suffix") { - errs = append(errs, errCapabilityInsert) - } - default: + for _, cap := range want { + err, ok := capToErr[cap] + if !ok { slog.Error("unknown capability", "capability", cap) return fmt.Errorf("unknown capability: %s", cap) } + + if !slices.Contains(available, cap) { + errs = append(errs, err) + } } - if err := errors.Join(errs...); err != nil { + if len(errs) > 0 { return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) } diff --git a/server/images_test.go b/server/images_test.go new file mode 100644 index 000000000..22e5b7e6a --- /dev/null +++ b/server/images_test.go @@ -0,0 +1,360 @@ +package server + +import ( + "bytes" + "encoding/binary" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/ollama/ollama/template" + "github.com/ollama/ollama/types/model" +) + +// Constants for GGUF magic bytes and version +var ( + ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF" + ggufVer = uint32(3) // Version 3 +) + +// Helper function to create mock GGUF data +func createMockGGUFData(architecture string, vision bool) []byte { + var buf bytes.Buffer + + // Write GGUF header + buf.Write(ggufMagic) + binary.Write(&buf, binary.LittleEndian, ggufVer) + + // Write tensor count (0 for our test) + var numTensors 
uint64 = 0 + binary.Write(&buf, binary.LittleEndian, numTensors) + + // Calculate number of metadata entries + numMetaEntries := uint64(1) // architecture entry + if vision { + numMetaEntries++ + } + // Add embedding entry if architecture is "bert" + if architecture == "bert" { + numMetaEntries++ + } + binary.Write(&buf, binary.LittleEndian, numMetaEntries) + + // Write architecture metadata + archKey := "general.architecture" + keyLen := uint64(len(archKey)) + binary.Write(&buf, binary.LittleEndian, keyLen) + buf.WriteString(archKey) + + // String type (8) + var strType uint32 = 8 + binary.Write(&buf, binary.LittleEndian, strType) + + // String length + strLen := uint64(len(architecture)) + binary.Write(&buf, binary.LittleEndian, strLen) + buf.WriteString(architecture) + + if vision { + visionKey := architecture + ".vision.block_count" + keyLen = uint64(len(visionKey)) + binary.Write(&buf, binary.LittleEndian, keyLen) + buf.WriteString(visionKey) + + // uint32 type (4) + var uint32Type uint32 = 4 + binary.Write(&buf, binary.LittleEndian, uint32Type) + + // uint32 value (1) + var countVal uint32 = 1 + binary.Write(&buf, binary.LittleEndian, countVal) + } + // Write embedding metadata if architecture is "bert" + if architecture == "bert" { + poolKey := architecture + ".pooling_type" + keyLen = uint64(len(poolKey)) + binary.Write(&buf, binary.LittleEndian, keyLen) + buf.WriteString(poolKey) + + // uint32 type (4) + var uint32Type uint32 = 4 + binary.Write(&buf, binary.LittleEndian, uint32Type) + + // uint32 value (1) + var poolingVal uint32 = 1 + binary.Write(&buf, binary.LittleEndian, poolingVal) + } + + return buf.Bytes() +} + +func TestModelCapabilities(t *testing.T) { + // Create a temporary directory for test files + tempDir, err := os.MkdirTemp("", "model_capabilities_test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + // Create different types of mock model files + completionModelPath := 
filepath.Join(tempDir, "model.bin") + visionModelPath := filepath.Join(tempDir, "vision_model.bin") + embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") + // Create a simple model file for tests that don't depend on GGUF content + simpleModelPath := filepath.Join(tempDir, "simple_model.bin") + + err = os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644) + if err != nil { + t.Fatalf("Failed to create completion model file: %v", err) + } + err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644) + if err != nil { + t.Fatalf("Failed to create vision model file: %v", err) + } + err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644) + if err != nil { + t.Fatalf("Failed to create embedding model file: %v", err) + } + err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644) + if err != nil { + t.Fatalf("Failed to create simple model file: %v", err) + } + + toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + chatTemplate, err := template.Parse("{{ .prompt }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + testModels := []struct { + name string + model Model + expectedCaps []model.Capability + }{ + { + name: "model with completion capability", + model: Model{ + ModelPath: completionModelPath, + Template: chatTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion}, + }, + + { + name: "model with completion, tools, and insert capability", + model: Model{ + ModelPath: completionModelPath, + Template: toolsInsertTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools,
model.CapabilityInsert}, + }, + { + name: "model with tools and insert capability", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsInsertTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert}, + }, + { + name: "model with tools capability", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityTools}, + }, + { + name: "model with vision capability", + model: Model{ + ModelPath: visionModelPath, + Template: chatTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision}, + }, + { + name: "model with vision, tools, and insert capability", + model: Model{ + ModelPath: visionModelPath, + Template: toolsInsertTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert}, + }, + { + name: "model with embedding capability", + model: Model{ + ModelPath: embeddingModelPath, + Template: chatTemplate, + }, + expectedCaps: []model.Capability{model.CapabilityEmbedding}, + }, + } + + // compare two slices of model.Capability regardless of order + compareCapabilities := func(a, b []model.Capability) bool { + if len(a) != len(b) { + return false + } + + aCount := make(map[model.Capability]int) + for _, cap := range a { + aCount[cap]++ + } + + bCount := make(map[model.Capability]int) + for _, cap := range b { + bCount[cap]++ + } + + for cap, count := range aCount { + if bCount[cap] != count { + return false + } + } + + return true + } + + for _, tt := range testModels { + t.Run(tt.name, func(t *testing.T) { + // Test Capabilities method + caps := tt.model.Capabilities() + if !compareCapabilities(caps, tt.expectedCaps) { + t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps) + } + }) + } +} + +func TestModelCheckCapabilities(t *testing.T) { + // Create a temporary directory for test files + tempDir, err := 
os.MkdirTemp("", "model_check_capabilities_test") + if err != nil { + t.Fatalf("Failed to create temp directory: %v", err) + } + defer os.RemoveAll(tempDir) + + visionModelPath := filepath.Join(tempDir, "vision_model.bin") + simpleModelPath := filepath.Join(tempDir, "model.bin") + embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") + + err = os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644) + if err != nil { + t.Fatalf("Failed to create simple model file: %v", err) + } + err = os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644) + if err != nil { + t.Fatalf("Failed to create vision model file: %v", err) + } + err = os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644) + if err != nil { + t.Fatalf("Failed to create embedding model file: %v", err) + } + + toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + chatTemplate, err := template.Parse("{{ .prompt }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + tests := []struct { + name string + model Model + checkCaps []model.Capability + expectedErrMsg string + }{ + { + name: "completion model without tools capability", + model: Model{ + ModelPath: simpleModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{model.CapabilityTools}, + expectedErrMsg: "does not support tools", + }, + { + name: "model with all needed capabilities", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsInsertTemplate, + }, + checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert}, + }, + { + name: "model missing insert capability", + model: Model{ + ModelPath: simpleModelPath, + 
Template: toolsTemplate, + }, + checkCaps: []model.Capability{model.CapabilityInsert}, + expectedErrMsg: "does not support insert", + }, + { + name: "model missing vision capability", + model: Model{ + ModelPath: simpleModelPath, + Template: toolsTemplate, + }, + checkCaps: []model.Capability{model.CapabilityVision}, + expectedErrMsg: "does not support vision", + }, + { + name: "model with vision capability", + model: Model{ + ModelPath: visionModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{model.CapabilityVision}, + }, + { + name: "model with embedding capability", + model: Model{ + ModelPath: embeddingModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{model.CapabilityEmbedding}, + }, + { + name: "unknown capability", + model: Model{ + ModelPath: simpleModelPath, + Template: chatTemplate, + }, + checkCaps: []model.Capability{"unknown"}, + expectedErrMsg: "unknown capability", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test CheckCapabilities method + err := tt.model.CheckCapabilities(tt.checkCaps...) + if tt.expectedErrMsg == "" { + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + } else { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.expectedErrMsg) + } else if !strings.Contains(err.Error(), tt.expectedErrMsg) { + t.Errorf("Expected error containing %q, got: %v", tt.expectedErrMsg, err) + } + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 92336af00..95e498202 100644 --- a/server/routes.go +++ b/server/routes.go @@ -87,7 +87,7 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options // scheduleRunner schedules a runner after validating inputs such as capabilities and model options. // It returns the allocated runner, model instance, and consolidated options if successful and error otherwise. 
-func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { +func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.Capability, requestOpts map[string]any, keepAlive *api.Duration) (llm.LlamaServer, *Model, *api.Options, error) { if name == "" { return nil, nil, nil, fmt.Errorf("model %w", errRequired) } @@ -144,7 +144,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - model, err := GetModel(name.String()) + m, err := GetModel(name.String()) if err != nil { switch { case errors.Is(err, fs.ErrNotExist): @@ -159,7 +159,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { // expire the runner if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 { - s.sched.expireRunner(model) + s.sched.expireRunner(m) c.JSON(http.StatusOK, api.GenerateResponse{ Model: req.Model, @@ -176,9 +176,9 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - caps := []Capability{CapabilityCompletion} + caps := []model.Capability{model.CapabilityCompletion} if req.Suffix != "" { - caps = append(caps, CapabilityInsert) + caps = append(caps, model.CapabilityInsert) } r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) @@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - isMllama := checkMllamaModelFamily(model) + isMllama := checkMllamaModelFamily(m) if isMllama && len(req.Images) > 1 { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"}) return @@ -211,7 +211,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { - if isMllama && len(model.ProjectorPaths) > 0 { + if isMllama && len(m.ProjectorPaths) > 0 { data, opts, err := 
mllama.Preprocess(bytes.NewReader(req.Images[i])) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"}) @@ -422,7 +422,7 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } - r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive) + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -530,7 +530,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { return } - r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive) + r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive) if err != nil { handleScheduleError(c, req.Model, err) return @@ -813,12 +813,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { } resp := &api.ShowResponse{ - License: strings.Join(m.License, "\n"), - System: m.System, - Template: m.Template.String(), - Details: modelDetails, - Messages: msgs, - ModifiedAt: manifest.fi.ModTime(), + License: strings.Join(m.License, "\n"), + System: m.System, + Template: m.Template.String(), + Details: modelDetails, + Messages: msgs, + Capabilities: m.Capabilities(), + ModifiedAt: manifest.fi.ModTime(), } var params []string @@ -1468,9 +1469,9 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - caps := []Capability{CapabilityCompletion} + caps := []model.Capability{model.CapabilityCompletion} if len(req.Tools) > 0 { - caps = append(caps, CapabilityTools) + caps = append(caps, model.CapabilityTools) } name := model.ParseName(req.Model) diff --git a/server/sched.go b/server/sched.go index 9126c2969..e6cefa5ac 100644 --- a/server/sched.go +++ b/server/sched.go @@ -20,6 +20,7 @@ import ( "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" 
"github.com/ollama/ollama/llm" + "github.com/ollama/ollama/types/model" ) type LlmRequest struct { @@ -195,7 +196,7 @@ func (s *Scheduler) processPending(ctx context.Context) { } // Embedding models should always be loaded with parallel=1 - if pending.model.CheckCapabilities(CapabilityCompletion) != nil { + if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil { numParallel = 1 } diff --git a/types/model/capability.go b/types/model/capability.go new file mode 100644 index 000000000..fb8689403 --- /dev/null +++ b/types/model/capability.go @@ -0,0 +1,15 @@ +package model + +type Capability string + +const ( + CapabilityCompletion = Capability("completion") + CapabilityTools = Capability("tools") + CapabilityInsert = Capability("insert") + CapabilityVision = Capability("vision") + CapabilityEmbedding = Capability("embedding") +) + +func (c Capability) String() string { + return string(c) +}