New engine: vision models and auto-fallback (#9113)

* Include unified vision layers in memory prediction

For newer vision models that ship as a single GGUF, include
the vision projector estimates in the memory prediction.

* Adjust CLI to handle both styles of vision model metadata

* Wire up new tokenizers for new engine

If we're loading the new engine, use the new model text
processor instead of calling into the cgo wrappers for
llama.cpp. This also cleans up some tech debt from the
older tokenization flow for the C++ server, which was
no longer used.

It also adjusts the grammar handling logic to pass
through to the new engine instead of using the cgo
schema-to-grammar call.

* Lay foundation for auto selection of new engine
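
The last bullet is the key behavioral change: rather than keying purely off the envconfig.NewEngine() opt-in, NewLlamaServer now probes whether the Ollama engine can actually handle the model and quietly falls back to the llama.cpp runner when it cannot. A condensed sketch of that selection flow (the pickEngine wrapper is illustrative only; in the diff below the same logic is inlined in NewLlamaServer):

```go
package llm

import (
	"log/slog"

	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/llama"
	"github.com/ollama/ollama/model"
)

// pickEngine is an illustrative restatement of the selection logic added to
// NewLlamaServer; the helper itself is not part of the patch.
func pickEngine(modelPath string) (*llama.Model, model.TextProcessor, error) {
	var llamaModel *llama.Model
	var textProcessor model.TextProcessor
	var err error

	if envconfig.NewEngine() {
		// Try the Ollama engine first.
		textProcessor, err = model.NewTextProcessor(modelPath)
		if err != nil {
			// Unsupported architecture: fall back instead of failing.
			slog.Debug("falling back to llama.cpp runner", "model", modelPath, "error", err)
		}
	}
	if textProcessor == nil {
		// Compatibility mode: load only the vocabulary via cgo so that
		// Tokenize/Detokenize keep working in-process.
		llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
		if err != nil {
			return nil, nil, err
		}
	}

	// Exactly one of the two results is non-nil; the "--ollama-engine" flag is
	// only passed to the runner subprocess when textProcessor is non-nil.
	return llamaModel, textProcessor, nil
}
```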
Daniel Hiltgen, 2025-03-04 09:03:46 -08:00 (committed by GitHub)
parent 7a01ad7614
commit 1fdb351c37
10 changed files with 249 additions and 170 deletions


@@ -339,10 +339,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	// TODO(jessegross): We should either find another way to know if this is
-	// a vision model or remove the logic. Also consider that other modalities will
-	// need different behavior anyways.
-	opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
+	if len(info.ProjectorInfo) != 0 {
+		opts.MultiModal = true
+	}
+	for k := range info.ModelInfo {
+		if strings.Contains(k, ".vision.") {
+			opts.MultiModal = true
+			break
+		}
+	}
+
 	opts.ParentModel = info.Details.ParentModel
 
 	if interactive {
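
In other words, the CLI now treats a model as multimodal if it either ships a separate projector (the older split-file style) or embeds vision metadata in the main model's key/value store (the newer unified-GGUF style). A hypothetical helper that mirrors the check above (isMultiModal does not exist in the patch; the key name in the comment is only an example):

```go
// isMultiModal mirrors the RunHandler change above; illustrative only.
func isMultiModal(info api.ShowResponse) bool {
	// Older vision models ship a separate projector GGUF.
	if len(info.ProjectorInfo) != 0 {
		return true
	}
	// Newer unified GGUFs carry the vision hyperparameters in the main
	// model's metadata, e.g. a key such as "mllama.vision.block_count".
	for k := range info.ModelInfo {
		if strings.Contains(k, ".vision.") {
			return true
		}
	}
	return false
}
```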


@@ -565,6 +565,43 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 	return
 }
 
+func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
+	switch llm.KV().Architecture() {
+	case "mllama":
+		for _, layer := range llm.Tensors().GroupLayers()["v"] {
+			weights += layer.Size()
+		}
+
+		kv := func(n string) uint64 {
+			if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
+				return uint64(v)
+			}
+
+			return 0
+		}
+
+		imageSize := kv("image_size")
+		maxNumTiles := kv("max_num_tiles")
+
+		embeddingLength := kv("embedding_length")
+		headCount := kv("attention.head_count")
+
+		numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
+		if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+			numPatches++
+		}
+
+		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
+
+		graphSize = 4 * (8 +
+			imageSize*imageSize*kv("num_channels")*maxNumTiles +
+			embeddingLength*numPatches*maxNumTiles +
+			9*embeddingLength*numPaddedPatches*maxNumTiles +
+			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+	}
+	return weights, graphSize
+}
+
 // SupportsKVCacheType checks if the requested cache type is supported
 func (f GGML) SupportsKVCacheType(cacheType string) bool {
 	return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
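
To get a feel for the magnitudes, here is a standalone sketch that plugs illustrative hyperparameters into the same formula (the values below are invented for the example rather than read from any particular model; the real ones come from the mllama.vision.* keys):

```go
package main

import "fmt"

func main() {
	// Illustrative hyperparameters only.
	var (
		imageSize       uint64 = 560
		patchSize       uint64 = 14
		maxNumTiles     uint64 = 4
		numChannels     uint64 = 3
		embeddingLength uint64 = 1280
		headCount       uint64 = 16
	)

	numPatches := (imageSize / patchSize) * (imageSize / patchSize) // 40*40 = 1600
	numPatches++                                                    // class embedding -> 1601

	// Same padding expression as VisionGraphSize above.
	numPaddedPatches := numPatches + 8 - (numPatches%8)%8 // 1608

	graphSize := 4 * (8 +
		imageSize*imageSize*numChannels*maxNumTiles +
		embeddingLength*numPatches*maxNumTiles +
		9*embeddingLength*numPaddedPatches*maxNumTiles +
		numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)

	fmt.Printf("padded patches: %d, graph estimate: %.1f GiB\n",
		numPaddedPatches, float64(graphSize)/(1<<30))
}
```

With these example numbers the graph estimate works out to roughly 2.8 GiB, which is the kind of per-model overhead EstimateGPULayers would otherwise miss for unified vision models (see the next hunk).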


@@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 		// multimodal models require at least 2048 context
 		opts.NumCtx = max(opts.NumCtx, 2048)
 	}
+	if projectorWeights == 0 && projectorGraph == 0 {
+		projectorWeights, projectorGraph = f.VisionGraphSize()
+	}
 
 	layers := f.Tensors().GroupLayers()
 	// add one layer worth of memory as a buffer


@@ -30,6 +30,7 @@ import (
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llama"
+	"github.com/ollama/ollama/model"
 )
 
 type LlamaServer interface {
@@ -54,8 +55,15 @@ type llmServer struct {
 	options     api.Options
 	numParallel int
 	modelPath   string
-	modelLock   sync.Mutex   // Temporary until we switch fully to Go server
-	model       *llama.Model // If non-nil, the runner is a new Go server
+
+	// llamaModel is an instance of the cgo llama.cpp model definition
+	// nil if this server is running the new engine
+	llamaModel     *llama.Model
+	llamaModelLock sync.Mutex
+
+	// textProcessor handles text encoding/decoding for the model in the Ollama engine
+	// nil if this server is running the llama.cpp based engine
+	textProcessor model.TextProcessor
 
 	estimate    MemoryEstimate
 	totalLayers uint64
@@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
 
 // NewLlamaServer will run a server for the given GPUs
 // The gpu list must be a single family.
-func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
+func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
 	systemInfo := discover.GetSystemInfo()
 	systemTotalMemory := systemInfo.System.TotalMemory
 	systemFreeMemory := systemInfo.System.FreeMemory
@@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 	slog.Info("offload", "", estimate)
 
 	params := []string{
-		"--model", model,
+		"--model", modelPath,
 		"--ctx-size", strconv.Itoa(opts.NumCtx),
 		"--batch-size", strconv.Itoa(opts.NumBatch),
 	}
@@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 		}
 	}
 
-	if len(projectors) > 0 {
-		// TODO: applying multiple projectors is not supported by the llama.cpp server yet
-		params = append(params, "--mmproj", projectors[0])
-	}
-
 	defaultThreads := systemInfo.GetOptimalThreadCount()
 	if opts.NumThread > 0 {
 		params = append(params, "--threads", strconv.Itoa(opts.NumThread))
@@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 		}
 	}
 	slog.Debug("compatible gpu libraries", "compatible", compatible)
 
+	exe, err := os.Executable()
+	if err != nil {
+		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
+	}
+
+	if eval, err := filepath.EvalSymlinks(exe); err == nil {
+		exe = eval
+	}
+
+	var llamaModel *llama.Model
+	var textProcessor model.TextProcessor
+	if envconfig.NewEngine() {
+		textProcessor, err = model.NewTextProcessor(modelPath)
+		if err != nil {
+			// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
+			slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
+		}
+	}
+	if textProcessor == nil {
+		llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	if len(projectors) > 0 && llamaModel != nil {
+		params = append(params, "--mmproj", projectors[0])
+	}
+
 	// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
 	// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
@@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
 	}
 	finalParams := []string{"runner"}
-	if envconfig.NewEngine() {
+	if textProcessor != nil {
+		// New engine
+		// TODO - if we have failure to load scenarios, add logic to retry with the old runner
 		finalParams = append(finalParams, "--ollama-engine")
 	}
 	finalParams = append(finalParams, params...)
@@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 	// finally, add the root library path
 	libraryPaths = append(libraryPaths, discover.LibOllamaPath)
 
-	exe, err := os.Executable()
-	if err != nil {
-		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
-	}
-
-	if eval, err := filepath.EvalSymlinks(exe); err == nil {
-		exe = eval
-	}
-
-	// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
 	s := &llmServer{
 		port:          port,
 		cmd:           exec.Command(exe, finalParams...),
 		status:        NewStatusWriter(os.Stderr),
 		options:       opts,
-		modelPath:     model,
-		estimate:      estimate,
-		numParallel:   numParallel,
-		sem:           semaphore.NewWeighted(int64(numParallel)),
-		totalLayers:   f.KV().BlockCount() + 1,
-		gpus:          gpus,
-		done:          make(chan error, 1),
+		modelPath:     modelPath,
+		llamaModel:    llamaModel,
+		textProcessor: textProcessor,
+		estimate:      estimate,
+		numParallel:   numParallel,
+		sem:           semaphore.NewWeighted(int64(numParallel)),
+		totalLayers:   f.KV().BlockCount() + 1,
+		gpus:          gpus,
+		done:          make(chan error, 1),
 	}
 
 	s.cmd.Env = os.Environ()
@@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
 		}
 		err := fmt.Errorf("error starting runner: %v %s", err, msg)
 		if len(compatible) == 0 {
+			if llamaModel != nil {
+				llama.FreeModel(llamaModel)
+			}
 			return nil, err
 		}
@@ -701,24 +729,29 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	}
 
 	if len(req.Format) > 0 {
-		switch string(req.Format) {
-		case `null`, `""`:
-			// Field was set, but "missing" a value. We accept
-			// these as "not set".
-			break
-		case `"json"`:
-			request["grammar"] = grammarJSON
-		default:
-			if req.Format[0] != '{' {
-				return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
-			}
-
-			// User provided a JSON schema
-			g := llama.SchemaToGrammar(req.Format)
-			if g == nil {
-				return fmt.Errorf("invalid JSON schema in format")
-			}
-
-			request["grammar"] = string(g)
+		format := string(req.Format)
+		if format != `null` && format != `""` {
+			if s.textProcessor != nil {
+				// New engine handles this on the backend
+				request["format"] = req.Format
+			} else {
+				// old engine
+				switch format {
+				case `"json"`:
+					request["grammar"] = grammarJSON
+				default:
+					if req.Format[0] != '{' {
+						return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
+					}
+
+					// User provided a JSON schema
+					g := llama.SchemaToGrammar(req.Format)
+					if g == nil {
+						return fmt.Errorf("invalid JSON schema in format")
+					}
+
+					request["grammar"] = string(g)
+				}
+			}
 		}
 	}
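
To make the two paths concrete: on the llama.cpp runner a JSON-Schema format is converted to a grammar before being handed to the C++ server, while on the Ollama engine the schema is forwarded untouched under the format key and enforced on the Go side. A small standalone sketch (not part of the patch; the schema is invented for the example):

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/llama"
)

func main() {
	// An illustrative JSON schema a client might pass as "format".
	format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}},"required":["answer"]}`)

	// Old (llama.cpp) path: convert the schema to a grammar, as Completion
	// does above before posting to the C++ server.
	if g := llama.SchemaToGrammar(format); g != nil {
		fmt.Println("grammar for the llama.cpp runner:\n" + string(g))
	}

	// New (Ollama engine) path: the schema goes through unchanged and the
	// Go backend applies the constraint during decoding.
	fmt.Println("format for the Ollama engine:", string(format))
}
```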
@@ -933,64 +966,25 @@ type TokenizeResponse struct {
 }
 
 func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
-	s.modelLock.Lock()
-	defer s.modelLock.Unlock()
-	if s.model != nil {
-		return s.model.Tokenize(content, false, true)
-	}
-
-	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
-	if err != nil {
-		return nil, err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
-		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
-	}
-
-	data, err := json.Marshal(TokenizeRequest{Content: content})
-	if err != nil {
-		return nil, fmt.Errorf("marshaling encode data: %w", err)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
-	if err != nil {
-		return nil, fmt.Errorf("encode request: %w", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("do encode request: %w", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode == http.StatusNotFound {
-		if s.model == nil {
-			slog.Debug("new runner detected, loading model for cgo tokenization")
-			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
-			if err != nil {
-				return nil, err
-			}
-			s.model = m
-		}
-		return s.model.Tokenize(content, false, true)
-	}
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("read encode request: %w", err)
-	}
-
-	if resp.StatusCode >= 400 {
-		log.Printf("llm encode error: %s", body)
-		return nil, fmt.Errorf("%s", body)
-	}
-
-	var encoded TokenizeResponse
-	if err := json.Unmarshal(body, &encoded); err != nil {
-		return nil, fmt.Errorf("unmarshal encode response: %w", err)
-	}
-
-	return encoded.Tokens, nil
+	s.llamaModelLock.Lock()
+	defer s.llamaModelLock.Unlock()
+
+	if s.llamaModel != nil {
+		return s.llamaModel.Tokenize(content, false, true)
+	}
+	if s.textProcessor != nil {
+		tokens, err := s.textProcessor.Encode(content)
+		if err != nil {
+			return nil, err
+		}
+		toks := make([]int, len(tokens))
+		for i, t := range tokens {
+			toks[i] = int(t)
+		}
+		return toks, nil
+	}
+	// not reached
+	return nil, fmt.Errorf("no tokenizer configured")
 }
 
 type DetokenizeRequest struct {
@@ -1002,80 +996,38 @@ type DetokenizeResponse struct {
 }
 
 func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
-	s.modelLock.Lock()
-	defer s.modelLock.Unlock()
-	if s.model != nil {
-		var resp string
-		for _, token := range tokens {
-			resp += s.model.TokenToPiece(token)
-		}
-		return resp, nil
-	}
-
-	// Make sure the server is ready
-	status, err := s.getServerStatus(ctx)
-	if err != nil {
-		return "", err
-	} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
-		return "", fmt.Errorf("unexpected server status: %s", status.ToString())
-	}
-
-	data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
-	if err != nil {
-		return "", fmt.Errorf("marshaling decode data: %w", err)
-	}
-
-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
-	if err != nil {
-		return "", fmt.Errorf("decode request: %w", err)
-	}
-	req.Header.Set("Content-Type", "application/json")
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return "", fmt.Errorf("do decode request: %w", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode == http.StatusNotFound {
-		if s.model == nil {
-			slog.Debug("new runner detected, loading model for cgo tokenization")
-			m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
-			if err != nil {
-				return "", err
-			}
-			s.model = m
-		}
-		var resp string
-		for _, token := range tokens {
-			resp += s.model.TokenToPiece(token)
-		}
-		return resp, nil
-	}
-
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return "", fmt.Errorf("read decode request: %w", err)
-	}
-
-	if resp.StatusCode >= 400 {
-		log.Printf("llm decode error: %s", body)
-		return "", fmt.Errorf("%s", body)
-	}
-
-	var decoded DetokenizeResponse
-	if err := json.Unmarshal(body, &decoded); err != nil {
-		return "", fmt.Errorf("unmarshal encode response: %w", err)
-	}
-
-	return decoded.Content, nil
+	s.llamaModelLock.Lock()
+	defer s.llamaModelLock.Unlock()
+
+	if s.llamaModel != nil {
+		var resp string
+		for _, token := range tokens {
+			resp += s.llamaModel.TokenToPiece(token)
+		}
+		return resp, nil
+	}
+	if s.textProcessor != nil {
+		toks := make([]int32, len(tokens))
+		for i, t := range tokens {
+			toks[i] = int32(t)
+		}
+		content, err := s.textProcessor.Decode(toks)
+		if err != nil {
+			return "", err
+		}
+		return content, nil
+	}
+	// not reached
+	return "", fmt.Errorf("no tokenizer configured")
 }
 
 func (s *llmServer) Close() error {
-	s.modelLock.Lock()
-	if s.model != nil {
-		llama.FreeModel(s.model)
-		s.model = nil
-	}
-	s.modelLock.Unlock()
+	s.llamaModelLock.Lock()
+	if s.llamaModel != nil {
+		llama.FreeModel(s.llamaModel)
+		s.llamaModel = nil
+	}
+	s.llamaModelLock.Unlock()
 
 	if s.cmd != nil {
 		slog.Debug("stopping llama server")


@@ -16,6 +16,7 @@ import (
 	_ "golang.org/x/image/tiff"
 	_ "golang.org/x/image/webp"
 
+	fs "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
@@ -100,6 +101,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
 	return m, nil
 }
 
+func NewTextProcessor(s string) (TextProcessor, error) {
+	r, err := os.Open(s)
+	if err != nil {
+		return nil, err
+	}
+	defer r.Close()
+	meta, _, err := fs.Decode(r, -1)
+	if err != nil {
+		return nil, err
+	}
+	return getTextProcessor(meta.KV())
+}
+
+func getTextProcessor(kv fs.KV) (TextProcessor, error) {
+	arch := kv.Architecture()
+	f, ok := models[arch]
+	if !ok {
+		return nil, fmt.Errorf("unsupported model architecture %q", arch)
+	}
+	m, err := f(kv)
+	if err != nil {
+		return nil, err
+	}
+	tp, ok := m.(TextProcessor)
+	if !ok {
+		return nil, fmt.Errorf("%v is not a TextProcessor", m)
+	}
+	return tp, nil
+}
+
 func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
 	t := v.Type()
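
For callers outside the server, the new entry point can be exercised directly. A minimal usage sketch, with the model path as a placeholder and the blank import standing in for however the target architecture gets registered (that wiring is an assumption here, not shown in this patch):

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/model"
	// Assumed: importing a model package registers its architecture.
	_ "github.com/ollama/ollama/model/models/llama"
)

func main() {
	// Placeholder path to a GGUF whose architecture the Ollama engine supports.
	tp, err := model.NewTextProcessor("/path/to/model.gguf")
	if err != nil {
		// Unsupported architectures end up here; NewLlamaServer treats this
		// error as the signal to fall back to the llama.cpp runner.
		log.Fatal(err)
	}

	tokens, err := tp.Encode("hello world") // token ids as []int32
	if err != nil {
		log.Fatal(err)
	}

	text, err := tp.Decode(tokens)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(tokens, text)
}
```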


@@ -3,9 +3,11 @@ package model
 import (
 	"reflect"
 	"slices"
+	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
+	fs "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/backend/ggml"
 	"github.com/ollama/ollama/ml/nn"
@@ -134,3 +136,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
 		t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
 	}
 }
+
+func TestGetTextProcessor(t *testing.T) {
+	tp, err := getTextProcessor(fs.KV{})
+	if err == nil {
+		t.Error("expected error")
+	} else if !strings.Contains(err.Error(), "unsupported model architecture") {
+		t.Errorf("unexpected error: %v", err)
+	} else if tp != nil {
+		t.Error("expected nil tp")
+	}
+
+	models["dummy"] = func(ml.Config) (Model, error) {
+		return notTextProcessorModel{}, nil
+	}
+	tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
+	if err == nil {
+		t.Error("expected error")
+	} else if !strings.Contains(err.Error(), "not a TextProcessor") {
+		t.Errorf("unexpected error: %v", err)
+	} else if tp != nil {
+		t.Error("expected nil tp")
+	}
+}
+
+type notTextProcessorModel struct{}
+
+func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
+	panic("unimplemented")
+}
+
+func (notTextProcessorModel) Backend() ml.Backend {
+	panic("unimplemented")
+}
+
+func (notTextProcessorModel) Config() config {
+	panic("unimplemented")
+}


@@ -1,7 +1,9 @@
 package llama
 
 import (
+	"fmt"
 	"math"
+	"strings"
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -29,6 +31,10 @@ type Model struct {
 }
 
 func New(c ml.Config) (model.Model, error) {
+	if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
+		return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
+	}
+
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),


@@ -1,6 +1,8 @@
 package mllama
 
 import (
+	"fmt"
+
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
@@ -25,6 +27,10 @@ const (
 )
 
 func New(c ml.Config) (model.Model, error) {
+	// Verify unified config
+	if c.Uint("vision.block_count") == 0 {
+		return nil, fmt.Errorf("non-unified vision model not supported")
+	}
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),


@@ -10,7 +10,6 @@ import (
 	"strings"
 
 	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/template"
@@ -93,7 +92,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 			var imgData llm.ImageData
 			if isMllama {
-				if envconfig.NewEngine() {
+				if len(m.ProjectorPaths) == 0 {
 					imgData = llm.ImageData{
 						ID:   len(images),
 						Data: i,


@@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
-		if isMllama && !envconfig.NewEngine() {
+		if isMllama && len(model.ProjectorPaths) > 0 {
 			data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
 			if err != nil {
 				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})