diff --git a/llm/server.go b/llm/server.go
index 83690fcdc1..c4b84950e2 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -84,13 +84,12 @@ type LlamaServer interface {
 
 // llmServer is an instance of a runner hosting a single model
 type llmServer struct {
-	port        int
-	cmd         *exec.Cmd
-	done        chan error // Channel to signal when the process exits
-	status      *StatusWriter
-	options     api.Options
-	numParallel int
-	modelPath   string
+	port      int
+	cmd       *exec.Cmd
+	done      chan error // Channel to signal when the process exits
+	status    *StatusWriter
+	options   api.Options
+	modelPath string
 
 	loadRequest LoadRequest       // Parameters used to initialize the runner
 	mem         *ml.BackendMemory // Memory allocations for this model
@@ -100,10 +99,6 @@ type llmServer struct {
 	llamaModel     *llama.Model
 	llamaModelLock *sync.Mutex
 
-	// textProcessor handles text encoding/decoding for the model in the Ollama engine
-	// nil if this server is running the llama.cpp based engine
-	textProcessor model.TextProcessor
-
 	totalLayers  uint64
 	loadStart    time.Time // Record how long it took the model to load
 	loadProgress float32
@@ -119,6 +114,8 @@ type llamaServer struct {
 
 type ollamaServer struct {
 	llmServer
+
+	textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
 }
 
 // LoadModel will load a model from disk. The model must be in the GGML format.
@@ -242,8 +239,6 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 		loadRequest:    loadRequest,
 		llamaModel:     llamaModel,
 		llamaModelLock: &sync.Mutex{},
-		textProcessor:  textProcessor,
-		numParallel:    numParallel,
 		sem:            semaphore.NewWeighted(int64(numParallel)),
 		totalLayers:    f.KV().BlockCount() + 1,
 		loadStart:      time.Now(),
@@ -278,7 +273,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 	}()
 
 	if textProcessor != nil {
-		return &ollamaServer{llmServer: s}, nil
+		return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
 	} else {
 		return &llamaServer{llmServer: s, ggml: f}, nil
 	}
@@ -1681,68 +1676,59 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
 	return e.Embedding, nil
 }
 
-type TokenizeRequest struct {
-	Content string `json:"content"`
-}
-
-type TokenizeResponse struct {
-	Tokens []int `json:"tokens"`
-}
-
-func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
 	s.llamaModelLock.Lock()
 	defer s.llamaModelLock.Unlock()
 
-	if s.llamaModel != nil {
-		return s.llamaModel.Tokenize(content, false, true)
+	if s.llamaModel == nil {
+		return nil, fmt.Errorf("no tokenizer configured")
 	}
-	if s.textProcessor != nil {
-		tokens, err := s.textProcessor.Encode(content, false)
-		if err != nil {
-			return nil, err
-		}
-		toks := make([]int, len(tokens))
-		for i, t := range tokens {
-			toks[i] = int(t)
-		}
-		return toks, nil
+
+	return s.llamaModel.Tokenize(content, false, true)
+}
+
+func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
+	tokens, err := s.textProcessor.Encode(content, false)
+	if err != nil {
+		return nil, err
 	}
-	// not reached
-	return nil, fmt.Errorf("no tokenizer configured")
+
+	toks := make([]int, len(tokens))
+	for i, t := range tokens {
+		toks[i] = int(t)
+	}
+
+	return toks, nil
 }
 
-type DetokenizeRequest struct {
-	Tokens []int `json:"tokens"`
-}
-
-type DetokenizeResponse struct {
-	Content string `json:"content"`
-}
-
-func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+func (s *llamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
 	s.llamaModelLock.Lock()
 	defer s.llamaModelLock.Unlock()
 
-	if s.llamaModel != nil {
-		var resp string
-		for _, token := range tokens {
-			resp += s.llamaModel.TokenToPiece(token)
-		}
-		return resp, nil
+	if s.llamaModel == nil {
+		return "", fmt.Errorf("no tokenizer configured")
 	}
-	if s.textProcessor != nil {
-		toks := make([]int32, len(tokens))
-		for i, t := range tokens {
-			toks[i] = int32(t)
-		}
-		content, err := s.textProcessor.Decode(toks)
-		if err != nil {
-			return "", err
-		}
-		return content, nil
+
+	var resp string
+	for _, token := range tokens {
+		resp += s.llamaModel.TokenToPiece(token)
 	}
-	// not reached
-	return "", fmt.Errorf("no tokenizer configured")
+
+	return resp, nil
+}
+
+func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
+	toks := make([]int32, len(tokens))
+	for i, t := range tokens {
+		toks[i] = int32(t)
+	}
+
+	content, err := s.textProcessor.Decode(toks)
+	if err != nil {
+		return "", err
+	}
+
+	return content, nil
}
 
 func (s *llmServer) Close() error {
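
Not part of the diff: a minimal, self-contained sketch of the pattern this change moves to — per-engine `Tokenize`/`Detokenize` methods selected once via Go interface dispatch, instead of runtime nil-checks on shared `llmServer` fields. The names below (`tokenizer`, `byteTokenizer`, `runeTokenizer`) are hypothetical stand-ins, not Ollama's API:

```go
package main

import (
	"context"
	"fmt"
)

// tokenizer mirrors the Tokenize/Detokenize subset of the LlamaServer interface.
type tokenizer interface {
	Tokenize(ctx context.Context, content string) ([]int, error)
	Detokenize(ctx context.Context, tokens []int) (string, error)
}

// byteTokenizer plays the role of llamaServer: one token per byte.
type byteTokenizer struct{}

func (byteTokenizer) Tokenize(_ context.Context, content string) ([]int, error) {
	toks := make([]int, len(content))
	for i := 0; i < len(content); i++ {
		toks[i] = int(content[i])
	}
	return toks, nil
}

func (byteTokenizer) Detokenize(_ context.Context, tokens []int) (string, error) {
	buf := make([]byte, len(tokens))
	for i, t := range tokens {
		buf[i] = byte(t)
	}
	return string(buf), nil
}

// runeTokenizer plays the role of ollamaServer: one token per rune, with the
// int <-> int32 widening the real methods perform around TextProcessor.
type runeTokenizer struct{}

func (runeTokenizer) Tokenize(_ context.Context, content string) ([]int, error) {
	var toks []int
	for _, r := range content {
		toks = append(toks, int(int32(r)))
	}
	return toks, nil
}

func (runeTokenizer) Detokenize(_ context.Context, tokens []int) (string, error) {
	var resp string
	for _, t := range tokens {
		resp += string(rune(int32(t)))
	}
	return resp, nil
}

func main() {
	ctx := context.Background()
	// Callers pick an implementation once, at construction time, and never
	// branch again — the same shape as NewLlamaServer returning either
	// &llamaServer{...} or &ollamaServer{...}.
	for _, tk := range []tokenizer{byteTokenizer{}, runeTokenizer{}} {
		toks, _ := tk.Tokenize(ctx, "héllo")
		out, _ := tk.Detokenize(ctx, toks)
		fmt.Printf("%T: %v -> %q\n", tk, toks, out)
	}
}
```

As in the diff, only the llama.cpp-style implementation needs a lock and a nil guard; the text-processor path has no shared mutable state, so its methods stay guard-free.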