mirror of
https://github.com/ollama/ollama.git
synced 2025-03-18 05:41:43 +01:00
New engine: vision models and auto-fallback (#9113)
* Include unified vision layers in memory prediction For newer vision models with a single gguf, include the projection estimates. * Adjust CLI to handle both styles of vision model metadata * Wire up new tokenizers for new engine If we're loading the new engine, utilize the new model text processor instead of calling into cgo wrappers for llama.cpp. This also cleans up some tech debt from the older tokenization flow for the C++ server which was no longer used. This also adjusts the grammar handling logic to pass through to the new engine instead of utilizing the cgo schema to grammar call. * Lay foundation for auto selection of new engine
This commit is contained in:
parent
7a01ad7614
commit
1fdb351c37
14
cmd/cmd.go
14
cmd/cmd.go
@ -339,10 +339,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(jessegross): We should either find another way to know if this is
|
||||
// a vision model or remove the logic. Also consider that other modalities will
|
||||
// need different behavior anyways.
|
||||
opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
|
||||
if len(info.ProjectorInfo) != 0 {
|
||||
opts.MultiModal = true
|
||||
}
|
||||
for k := range info.ModelInfo {
|
||||
if strings.Contains(k, ".vision.") {
|
||||
opts.MultiModal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
opts.ParentModel = info.Details.ParentModel
|
||||
|
||||
if interactive {
|
||||
|
@ -565,6 +565,43 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
|
||||
return
|
||||
}
|
||||
|
||||
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
||||
switch llm.KV().Architecture() {
|
||||
case "mllama":
|
||||
for _, layer := range llm.Tensors().GroupLayers()["v"] {
|
||||
weights += layer.Size()
|
||||
}
|
||||
|
||||
kv := func(n string) uint64 {
|
||||
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
|
||||
return uint64(v)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
imageSize := kv("image_size")
|
||||
|
||||
maxNumTiles := kv("max_num_tiles")
|
||||
embeddingLength := kv("embedding_length")
|
||||
headCount := kv("attention.head_count")
|
||||
|
||||
numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
|
||||
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
|
||||
numPatches++
|
||||
}
|
||||
|
||||
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
|
||||
|
||||
graphSize = 4 * (8 +
|
||||
imageSize*imageSize*kv("num_channels")*maxNumTiles +
|
||||
embeddingLength*numPatches*maxNumTiles +
|
||||
9*embeddingLength*numPaddedPatches*maxNumTiles +
|
||||
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
|
||||
}
|
||||
return weights, graphSize
|
||||
}
|
||||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||
|
@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
if projectorWeights == 0 && projectorGraph == 0 {
|
||||
projectorWeights, projectorGraph = f.VisionGraphSize()
|
||||
}
|
||||
|
||||
layers := f.Tensors().GroupLayers()
|
||||
// add one layer worth of memory as a buffer
|
||||
|
278
llm/server.go
278
llm/server.go
@ -30,6 +30,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llama"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type LlamaServer interface {
|
||||
@ -54,8 +55,15 @@ type llmServer struct {
|
||||
options api.Options
|
||||
numParallel int
|
||||
modelPath string
|
||||
modelLock sync.Mutex // Temporary until we switch fully to Go server
|
||||
model *llama.Model // If non-nil, the runner is a new Go server
|
||||
|
||||
// llamaModel is an instance of the cgo llama.cpp model definition
|
||||
// nil if this server is running the new engine
|
||||
llamaModel *llama.Model
|
||||
llamaModelLock sync.Mutex
|
||||
|
||||
// textProcessor handles text encoding/decoding for the model in the Ollama engine
|
||||
// nil if this server is running the llama.cpp based engine
|
||||
textProcessor model.TextProcessor
|
||||
|
||||
estimate MemoryEstimate
|
||||
totalLayers uint64
|
||||
@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
|
||||
|
||||
// NewLlamaServer will run a server for the given GPUs
|
||||
// The gpu list must be a single family.
|
||||
func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
|
||||
systemInfo := discover.GetSystemInfo()
|
||||
systemTotalMemory := systemInfo.System.TotalMemory
|
||||
systemFreeMemory := systemInfo.System.FreeMemory
|
||||
@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
slog.Info("offload", "", estimate)
|
||||
|
||||
params := []string{
|
||||
"--model", model,
|
||||
"--model", modelPath,
|
||||
"--ctx-size", strconv.Itoa(opts.NumCtx),
|
||||
"--batch-size", strconv.Itoa(opts.NumBatch),
|
||||
}
|
||||
@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
}
|
||||
}
|
||||
|
||||
if len(projectors) > 0 {
|
||||
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
defaultThreads := systemInfo.GetOptimalThreadCount()
|
||||
if opts.NumThread > 0 {
|
||||
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
|
||||
@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
}
|
||||
}
|
||||
slog.Debug("compatible gpu libraries", "compatible", compatible)
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
var llamaModel *llama.Model
|
||||
var textProcessor model.TextProcessor
|
||||
if envconfig.NewEngine() {
|
||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||
if err != nil {
|
||||
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
|
||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||
}
|
||||
}
|
||||
if textProcessor == nil {
|
||||
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(projectors) > 0 && llamaModel != nil {
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
|
||||
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
|
||||
@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||
}
|
||||
finalParams := []string{"runner"}
|
||||
if envconfig.NewEngine() {
|
||||
if textProcessor != nil {
|
||||
// New engine
|
||||
// TODO - if we have failure to load scenarios, add logic to retry with the old runner
|
||||
finalParams = append(finalParams, "--ollama-engine")
|
||||
}
|
||||
finalParams = append(finalParams, params...)
|
||||
@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
// finally, add the root library path
|
||||
libraryPaths = append(libraryPaths, discover.LibOllamaPath)
|
||||
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
|
||||
}
|
||||
|
||||
if eval, err := filepath.EvalSymlinks(exe); err == nil {
|
||||
exe = eval
|
||||
}
|
||||
|
||||
// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
|
||||
s := &llmServer{
|
||||
port: port,
|
||||
cmd: exec.Command(exe, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
modelPath: model,
|
||||
estimate: estimate,
|
||||
numParallel: numParallel,
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: f.KV().BlockCount() + 1,
|
||||
gpus: gpus,
|
||||
done: make(chan error, 1),
|
||||
port: port,
|
||||
cmd: exec.Command(exe, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
modelPath: modelPath,
|
||||
llamaModel: llamaModel,
|
||||
textProcessor: textProcessor,
|
||||
estimate: estimate,
|
||||
numParallel: numParallel,
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: f.KV().BlockCount() + 1,
|
||||
gpus: gpus,
|
||||
done: make(chan error, 1),
|
||||
}
|
||||
|
||||
s.cmd.Env = os.Environ()
|
||||
@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
|
||||
}
|
||||
err := fmt.Errorf("error starting runner: %v %s", err, msg)
|
||||
if len(compatible) == 0 {
|
||||
if llamaModel != nil {
|
||||
llama.FreeModel(llamaModel)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -701,24 +729,29 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
}
|
||||
|
||||
if len(req.Format) > 0 {
|
||||
switch string(req.Format) {
|
||||
case `null`, `""`:
|
||||
// Field was set, but "missing" a value. We accept
|
||||
// these as "not set".
|
||||
break
|
||||
case `"json"`:
|
||||
request["grammar"] = grammarJSON
|
||||
default:
|
||||
if req.Format[0] != '{' {
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||
}
|
||||
format := string(req.Format)
|
||||
if format != `null` && format != `""` {
|
||||
if s.textProcessor != nil {
|
||||
// New engine handles this on the backend
|
||||
request["format"] = req.Format
|
||||
} else {
|
||||
// old engine
|
||||
switch format {
|
||||
case `"json"`:
|
||||
request["grammar"] = grammarJSON
|
||||
default:
|
||||
if req.Format[0] != '{' {
|
||||
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
|
||||
}
|
||||
|
||||
// User provided a JSON schema
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
// User provided a JSON schema
|
||||
g := llama.SchemaToGrammar(req.Format)
|
||||
if g == nil {
|
||||
return fmt.Errorf("invalid JSON schema in format")
|
||||
}
|
||||
request["grammar"] = string(g)
|
||||
}
|
||||
}
|
||||
request["grammar"] = string(g)
|
||||
}
|
||||
}
|
||||
|
||||
@ -933,64 +966,25 @@ type TokenizeResponse struct {
|
||||
}
|
||||
|
||||
func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
s.modelLock.Lock()
|
||||
defer s.modelLock.Unlock()
|
||||
if s.model != nil {
|
||||
return s.model.Tokenize(content, false, true)
|
||||
}
|
||||
s.llamaModelLock.Lock()
|
||||
defer s.llamaModelLock.Unlock()
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
if s.llamaModel != nil {
|
||||
return s.llamaModel.Tokenize(content, false, true)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do encode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
if s.model == nil {
|
||||
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.model = m
|
||||
if s.textProcessor != nil {
|
||||
tokens, err := s.textProcessor.Encode(content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.model.Tokenize(content, false, true)
|
||||
toks := make([]int, len(tokens))
|
||||
for i, t := range tokens {
|
||||
toks[i] = int(t)
|
||||
}
|
||||
return toks, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read encode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm encode error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var encoded TokenizeResponse
|
||||
if err := json.Unmarshal(body, &encoded); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return encoded.Tokens, nil
|
||||
// not reached
|
||||
return nil, fmt.Errorf("no tokenizer configured")
|
||||
}
|
||||
|
||||
type DetokenizeRequest struct {
|
||||
@ -1002,80 +996,38 @@ type DetokenizeResponse struct {
|
||||
}
|
||||
|
||||
func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
s.modelLock.Lock()
|
||||
defer s.modelLock.Unlock()
|
||||
if s.model != nil {
|
||||
s.llamaModelLock.Lock()
|
||||
defer s.llamaModelLock.Unlock()
|
||||
|
||||
if s.llamaModel != nil {
|
||||
var resp string
|
||||
for _, token := range tokens {
|
||||
resp += s.model.TokenToPiece(token)
|
||||
resp += s.llamaModel.TokenToPiece(token)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
|
||||
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling decode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("do decode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
if s.model == nil {
|
||||
slog.Debug("new runner detected, loading model for cgo tokenization")
|
||||
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
s.model = m
|
||||
if s.textProcessor != nil {
|
||||
toks := make([]int32, len(tokens))
|
||||
for i, t := range tokens {
|
||||
toks[i] = int32(t)
|
||||
}
|
||||
var resp string
|
||||
for _, token := range tokens {
|
||||
resp += s.model.TokenToPiece(token)
|
||||
content, err := s.textProcessor.Decode(toks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp, nil
|
||||
return content, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read decode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm decode error: %s", body)
|
||||
return "", fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var decoded DetokenizeResponse
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
return "", fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return decoded.Content, nil
|
||||
// not reached
|
||||
return "", fmt.Errorf("no tokenizer configured")
|
||||
}
|
||||
|
||||
func (s *llmServer) Close() error {
|
||||
s.modelLock.Lock()
|
||||
if s.model != nil {
|
||||
llama.FreeModel(s.model)
|
||||
s.model = nil
|
||||
s.llamaModelLock.Lock()
|
||||
if s.llamaModel != nil {
|
||||
llama.FreeModel(s.llamaModel)
|
||||
s.llamaModel = nil
|
||||
}
|
||||
s.modelLock.Unlock()
|
||||
s.llamaModelLock.Unlock()
|
||||
|
||||
if s.cmd != nil {
|
||||
slog.Debug("stopping llama server")
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
_ "golang.org/x/image/tiff"
|
||||
_ "golang.org/x/image/webp"
|
||||
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
_ "github.com/ollama/ollama/ml/backend"
|
||||
@ -100,6 +101,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func NewTextProcessor(s string) (TextProcessor, error) {
|
||||
r, err := os.Open(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
meta, _, err := fs.Decode(r, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return getTextProcessor(meta.KV())
|
||||
}
|
||||
|
||||
func getTextProcessor(kv fs.KV) (TextProcessor, error) {
|
||||
arch := kv.Architecture()
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
}
|
||||
m, err := f(kv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tp, ok := m.(TextProcessor)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%v is not a TextProcessor", m)
|
||||
}
|
||||
return tp, nil
|
||||
}
|
||||
|
||||
func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
t := v.Type()
|
||||
|
||||
|
@ -3,9 +3,11 @@ package model
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
fs "github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/backend/ggml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@ -134,3 +136,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
|
||||
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTextProcessor(t *testing.T) {
|
||||
tp, err := getTextProcessor(fs.KV{})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
}
|
||||
|
||||
models["dummy"] = func(ml.Config) (Model, error) {
|
||||
return notTextProcessorModel{}, nil
|
||||
}
|
||||
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
|
||||
if err == nil {
|
||||
t.Error("expected error")
|
||||
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
} else if tp != nil {
|
||||
t.Error("expected nil tp")
|
||||
}
|
||||
}
|
||||
|
||||
type notTextProcessorModel struct{}
|
||||
|
||||
func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Backend() ml.Backend {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
func (notTextProcessorModel) Config() config {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package llama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@ -29,6 +31,10 @@ type Model struct {
|
||||
}
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
||||
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
||||
}
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
|
@ -1,6 +1,8 @@
|
||||
package mllama
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
@ -25,6 +27,10 @@ const (
|
||||
)
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
// Verify unified config
|
||||
if c.Uint("vision.block_count") == 0 {
|
||||
return nil, fmt.Errorf("non-unified vision model not supported")
|
||||
}
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/model/models/mllama"
|
||||
"github.com/ollama/ollama/template"
|
||||
@ -93,7 +92,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
var imgData llm.ImageData
|
||||
|
||||
if isMllama {
|
||||
if envconfig.NewEngine() {
|
||||
if len(m.ProjectorPaths) == 0 {
|
||||
imgData = llm.ImageData{
|
||||
ID: len(images),
|
||||
Data: i,
|
||||
|
@ -205,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
images := make([]llm.ImageData, len(req.Images))
|
||||
for i := range req.Images {
|
||||
if isMllama && !envconfig.NewEngine() {
|
||||
if isMllama && len(model.ProjectorPaths) > 0 {
|
||||
data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
|
||||
|
Loading…
x
Reference in New Issue
Block a user