mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 20:37:52 +01:00
As far as I know, gpt-oss is the first model that meaningfully transforms tool function definitions in its template. We found that relatively common definitions that include `anyOf` were not working because the template assumed that types were always defined via a `type` field. `anyOf` allows for fully recursive types, so I exposed a `toTypeScriptType()` function to handle this recursive logic in Go and keep the templates cleaner. The gpt-oss templates will need to be updated to use this. We should keep building out our function definition support to more fully cover the parts of JSON Schema that make sense for this use case, but in the meantime this will unblock some users (e.g., Zed's Ollama integration with gpt-oss). Probably the most urgent follow-up is proper array support.
970 lines
27 KiB
Go
970 lines
27 KiB
Go
package server
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/go-cmp/cmp"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/discover"
|
|
"github.com/ollama/ollama/fs/ggml"
|
|
"github.com/ollama/ollama/llm"
|
|
)
|
|
|
|
// mockRunner is a stub llm.LlamaServer used by the handler tests. It records
// the most recent completion request and replies with either a canned
// response or a caller-supplied callback.
type mockRunner struct {
	llm.LlamaServer

	// CompletionRequest is only valid until the next call to Completion
	llm.CompletionRequest

	// CompletionResponse is the canned reply sent when CompletionFn is nil.
	llm.CompletionResponse

	// CompletionFn, when set, fully overrides Completion's behavior.
	CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
}
|
|
|
|
func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
|
m.CompletionRequest = r
|
|
if m.CompletionFn != nil {
|
|
return m.CompletionFn(ctx, r, fn)
|
|
}
|
|
fn(m.CompletionResponse)
|
|
return nil
|
|
}
|
|
|
|
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
|
|
for range strings.Fields(s) {
|
|
tokens = append(tokens, len(tokens))
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
|
|
return func(_ discover.GpuInfoList, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
|
|
return mock, nil
|
|
}
|
|
}
|
|
|
|
func TestGenerateChat(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
mock := mockRunner{
|
|
CompletionResponse: llm.CompletionResponse{
|
|
Done: true,
|
|
DoneReason: llm.DoneReasonStop,
|
|
PromptEvalCount: 1,
|
|
PromptEvalDuration: 1,
|
|
EvalCount: 1,
|
|
EvalDuration: 1,
|
|
},
|
|
}
|
|
|
|
s := Server{
|
|
sched: &Scheduler{
|
|
pendingReqCh: make(chan *LlmRequest, 1),
|
|
finishedReqCh: make(chan *LlmRequest, 1),
|
|
expiredCh: make(chan *runnerRef, 1),
|
|
unloadedCh: make(chan any, 1),
|
|
loaded: make(map[string]*runnerRef),
|
|
newServerFn: newMockServer(&mock),
|
|
getGpuFn: discover.GetGPUInfo,
|
|
getCpuFn: discover.GetCPUInfo,
|
|
reschedDelay: 250 * time.Millisecond,
|
|
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
|
// add small delay to simulate loading
|
|
time.Sleep(time.Millisecond)
|
|
req.successCh <- &runnerRef{
|
|
llama: &mock,
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
go s.sched.Run(t.Context())
|
|
|
|
_, digest := createBinFile(t, ggml.KV{
|
|
"general.architecture": "llama",
|
|
"llama.block_count": uint32(1),
|
|
"llama.context_length": uint32(8192),
|
|
"llama.embedding_length": uint32(4096),
|
|
"llama.attention.head_count": uint32(32),
|
|
"llama.attention.head_count_kv": uint32(8),
|
|
"tokenizer.ggml.tokens": []string{""},
|
|
"tokenizer.ggml.scores": []float32{0},
|
|
"tokenizer.ggml.token_type": []int32{0},
|
|
}, []*ggml.Tensor{
|
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
})
|
|
|
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test",
|
|
Files: map[string]string{"file.gguf": digest},
|
|
Template: `
|
|
{{- if .Tools }}
|
|
{{ .Tools }}
|
|
{{ end }}
|
|
{{- range .Messages }}
|
|
{{- .Role }}: {{ .Content }}
|
|
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
|
{{- end }}
|
|
{{ end }}`,
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("missing body", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, nil)
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing thinking capability", func(t *testing.T) {
|
|
think := true
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello!"},
|
|
},
|
|
Think: &api.ThinkValue{Value: think},
|
|
})
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing model", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing capabilities chat", func(t *testing.T) {
|
|
_, digest := createBinFile(t, ggml.KV{
|
|
"general.architecture": "bert",
|
|
"bert.pooling_type": uint32(0),
|
|
}, []*ggml.Tensor{})
|
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "bert",
|
|
Files: map[string]string{"bert.gguf": digest},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "bert",
|
|
})
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("load model", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
var actual api.ChatResponse
|
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if actual.Model != "test" {
|
|
t.Errorf("expected model test, got %s", actual.Model)
|
|
}
|
|
|
|
if !actual.Done {
|
|
t.Errorf("expected done true, got false")
|
|
}
|
|
|
|
if actual.DoneReason != "load" {
|
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
|
}
|
|
})
|
|
|
|
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
|
|
t.Helper()
|
|
|
|
var actual api.ChatResponse
|
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if actual.Model != model {
|
|
t.Errorf("expected model test, got %s", actual.Model)
|
|
}
|
|
|
|
if !actual.Done {
|
|
t.Errorf("expected done false, got true")
|
|
}
|
|
|
|
if actual.DoneReason != "stop" {
|
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
|
}
|
|
|
|
if diff := cmp.Diff(actual.Message, api.Message{
|
|
Role: "assistant",
|
|
Content: content,
|
|
}); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
if actual.PromptEvalCount == 0 {
|
|
t.Errorf("expected prompt eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.PromptEvalDuration == 0 {
|
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalCount == 0 {
|
|
t.Errorf("expected eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalDuration == 0 {
|
|
t.Errorf("expected eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.LoadDuration == 0 {
|
|
t.Errorf("expected load duration > 0, got 0")
|
|
}
|
|
|
|
if actual.TotalDuration == 0 {
|
|
t.Errorf("expected total duration > 0, got 0")
|
|
}
|
|
}
|
|
|
|
mock.CompletionResponse.Content = "Hi!"
|
|
t.Run("messages", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello!"},
|
|
},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test", "Hi!")
|
|
})
|
|
|
|
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test-system",
|
|
From: "test",
|
|
System: "You are a helpful assistant.",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("messages with model system", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello!"},
|
|
},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test-system", "Hi!")
|
|
})
|
|
|
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
|
t.Run("messages with system", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "system", Content: "You can perform magic tricks."},
|
|
{Role: "user", Content: "Hello!"},
|
|
},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
t.Run("messages with interleaved system", func(t *testing.T) {
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "Hello!"},
|
|
{Role: "assistant", Content: "I can help you with that."},
|
|
{Role: "system", Content: "You can perform magic tricks."},
|
|
{Role: "user", Content: "Help me write tests."},
|
|
},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
t.Run("messages with tools (non-streaming)", func(t *testing.T) {
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("failed to create test-system model: %d", w.Code)
|
|
}
|
|
|
|
tools := []api.Tool{
|
|
{
|
|
Type: "function",
|
|
Function: api.ToolFunction{
|
|
Name: "get_weather",
|
|
Description: "Get the current weather",
|
|
Parameters: struct {
|
|
Type string `json:"type"`
|
|
Defs any `json:"$defs,omitempty"`
|
|
Items any `json:"items,omitempty"`
|
|
Required []string `json:"required"`
|
|
Properties map[string]api.ToolProperty `json:"properties"`
|
|
}{
|
|
Type: "object",
|
|
Required: []string{"location"},
|
|
Properties: map[string]api.ToolProperty{
|
|
"location": {
|
|
Type: api.PropertyType{"string"},
|
|
Description: "The city and state",
|
|
},
|
|
"unit": {
|
|
Type: api.PropertyType{"string"},
|
|
Enum: []any{"celsius", "fahrenheit"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
mock.CompletionResponse = llm.CompletionResponse{
|
|
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
|
|
Done: true,
|
|
DoneReason: llm.DoneReasonStop,
|
|
PromptEvalCount: 1,
|
|
PromptEvalDuration: 1,
|
|
EvalCount: 1,
|
|
EvalDuration: 1,
|
|
}
|
|
|
|
streamRequest := true
|
|
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "What's the weather in Seattle?"},
|
|
},
|
|
Tools: tools,
|
|
Stream: &streamRequest,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
var errResp struct {
|
|
Error string `json:"error"`
|
|
}
|
|
if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
|
|
t.Logf("Failed to decode error response: %v", err)
|
|
} else {
|
|
t.Logf("Error response: %s", errResp.Error)
|
|
}
|
|
}
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
var resp api.ChatResponse
|
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if resp.Message.ToolCalls == nil {
|
|
t.Error("expected tool calls, got nil")
|
|
}
|
|
|
|
expectedToolCall := api.ToolCall{
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: api.ToolCallFunctionArguments{
|
|
"location": "Seattle, WA",
|
|
"unit": "celsius",
|
|
},
|
|
},
|
|
}
|
|
|
|
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
|
|
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("messages with tools (streaming)", func(t *testing.T) {
|
|
tools := []api.Tool{
|
|
{
|
|
Type: "function",
|
|
Function: api.ToolFunction{
|
|
Name: "get_weather",
|
|
Description: "Get the current weather",
|
|
Parameters: struct {
|
|
Type string `json:"type"`
|
|
Defs any `json:"$defs,omitempty"`
|
|
Items any `json:"items,omitempty"`
|
|
Required []string `json:"required"`
|
|
Properties map[string]api.ToolProperty `json:"properties"`
|
|
}{
|
|
Type: "object",
|
|
Required: []string{"location"},
|
|
Properties: map[string]api.ToolProperty{
|
|
"location": {
|
|
Type: api.PropertyType{"string"},
|
|
Description: "The city and state",
|
|
},
|
|
"unit": {
|
|
Type: api.PropertyType{"string"},
|
|
Enum: []any{"celsius", "fahrenheit"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// Simulate streaming response with multiple chunks
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
|
|
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
|
defer wg.Done()
|
|
|
|
// Send chunks with small delays to simulate streaming
|
|
responses := []llm.CompletionResponse{
|
|
{
|
|
Content: `{"name":"get_`,
|
|
Done: false,
|
|
PromptEvalCount: 1,
|
|
PromptEvalDuration: 1,
|
|
},
|
|
{
|
|
Content: `weather","arguments":{"location":"Seattle`,
|
|
Done: false,
|
|
PromptEvalCount: 2,
|
|
PromptEvalDuration: 1,
|
|
},
|
|
{
|
|
Content: `, WA","unit":"celsius"}}`,
|
|
Done: true,
|
|
DoneReason: llm.DoneReasonStop,
|
|
PromptEvalCount: 3,
|
|
PromptEvalDuration: 1,
|
|
},
|
|
}
|
|
|
|
for _, resp := range responses {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
fn(resp)
|
|
time.Sleep(10 * time.Millisecond) // Small delay between chunks
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
|
Model: "test-system",
|
|
Messages: []api.Message{
|
|
{Role: "user", Content: "What's the weather in Seattle?"},
|
|
},
|
|
Tools: tools,
|
|
Stream: &stream,
|
|
})
|
|
|
|
wg.Wait()
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
// Read and validate the streamed responses
|
|
decoder := json.NewDecoder(w.Body)
|
|
var finalToolCall api.ToolCall
|
|
|
|
for {
|
|
var resp api.ChatResponse
|
|
if err := decoder.Decode(&resp); err == io.EOF {
|
|
break
|
|
} else if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if resp.Done {
|
|
if len(resp.Message.ToolCalls) != 1 {
|
|
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
|
|
}
|
|
finalToolCall = resp.Message.ToolCalls[0]
|
|
}
|
|
}
|
|
|
|
expectedToolCall := api.ToolCall{
|
|
Function: api.ToolCallFunction{
|
|
Name: "get_weather",
|
|
Arguments: api.ToolCallFunctionArguments{
|
|
"location": "Seattle, WA",
|
|
"unit": "celsius",
|
|
},
|
|
},
|
|
}
|
|
|
|
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
|
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestGenerate(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
mock := mockRunner{
|
|
CompletionResponse: llm.CompletionResponse{
|
|
Done: true,
|
|
DoneReason: llm.DoneReasonStop,
|
|
PromptEvalCount: 1,
|
|
PromptEvalDuration: 1,
|
|
EvalCount: 1,
|
|
EvalDuration: 1,
|
|
},
|
|
}
|
|
|
|
s := Server{
|
|
sched: &Scheduler{
|
|
pendingReqCh: make(chan *LlmRequest, 1),
|
|
finishedReqCh: make(chan *LlmRequest, 1),
|
|
expiredCh: make(chan *runnerRef, 1),
|
|
unloadedCh: make(chan any, 1),
|
|
loaded: make(map[string]*runnerRef),
|
|
newServerFn: newMockServer(&mock),
|
|
getGpuFn: discover.GetGPUInfo,
|
|
getCpuFn: discover.GetCPUInfo,
|
|
reschedDelay: 250 * time.Millisecond,
|
|
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
|
// add small delay to simulate loading
|
|
time.Sleep(time.Millisecond)
|
|
req.successCh <- &runnerRef{
|
|
llama: &mock,
|
|
}
|
|
},
|
|
},
|
|
}
|
|
|
|
go s.sched.Run(t.Context())
|
|
|
|
_, digest := createBinFile(t, ggml.KV{
|
|
"general.architecture": "llama",
|
|
"llama.block_count": uint32(1),
|
|
"llama.context_length": uint32(8192),
|
|
"llama.embedding_length": uint32(4096),
|
|
"llama.attention.head_count": uint32(32),
|
|
"llama.attention.head_count_kv": uint32(8),
|
|
"tokenizer.ggml.tokens": []string{""},
|
|
"tokenizer.ggml.scores": []float32{0},
|
|
"tokenizer.ggml.token_type": []int32{0},
|
|
}, []*ggml.Tensor{
|
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
|
})
|
|
|
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test",
|
|
Files: map[string]string{"file.gguf": digest},
|
|
Template: `
|
|
{{- if .System }}System: {{ .System }} {{ end }}
|
|
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
|
{{- if .Response }}Assistant: {{ .Response }} {{ end }}
|
|
`,
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("missing body", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, nil)
|
|
if w.Code != http.StatusNotFound {
|
|
t.Errorf("expected status 404, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing model", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
|
|
if w.Code != http.StatusNotFound {
|
|
t.Errorf("expected status 404, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"model '' not found"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing capabilities generate", func(t *testing.T) {
|
|
_, digest := createBinFile(t, ggml.KV{
|
|
"general.architecture": "bert",
|
|
"bert.pooling_type": uint32(0),
|
|
}, []*ggml.Tensor{})
|
|
|
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "bert",
|
|
Files: map[string]string{"file.gguf": digest},
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "bert",
|
|
})
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("missing capabilities suffix", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test",
|
|
Prompt: "def add(",
|
|
Suffix: " return c",
|
|
})
|
|
|
|
if w.Code != http.StatusBadRequest {
|
|
t.Errorf("expected status 400, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("load model", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
var actual api.GenerateResponse
|
|
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if actual.Model != "test" {
|
|
t.Errorf("expected model test, got %s", actual.Model)
|
|
}
|
|
|
|
if !actual.Done {
|
|
t.Errorf("expected done true, got false")
|
|
}
|
|
|
|
if actual.DoneReason != "load" {
|
|
t.Errorf("expected done reason load, got %s", actual.DoneReason)
|
|
}
|
|
})
|
|
|
|
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
|
|
t.Helper()
|
|
|
|
var actual api.GenerateResponse
|
|
if err := json.NewDecoder(body).Decode(&actual); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if actual.Model != model {
|
|
t.Errorf("expected model test, got %s", actual.Model)
|
|
}
|
|
|
|
if !actual.Done {
|
|
t.Errorf("expected done false, got true")
|
|
}
|
|
|
|
if actual.DoneReason != "stop" {
|
|
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
|
|
}
|
|
|
|
if actual.Response != content {
|
|
t.Errorf("expected response %s, got %s", content, actual.Response)
|
|
}
|
|
|
|
if actual.Context == nil {
|
|
t.Errorf("expected context not nil")
|
|
}
|
|
|
|
if actual.PromptEvalCount == 0 {
|
|
t.Errorf("expected prompt eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.PromptEvalDuration == 0 {
|
|
t.Errorf("expected prompt eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalCount == 0 {
|
|
t.Errorf("expected eval count > 0, got 0")
|
|
}
|
|
|
|
if actual.EvalDuration == 0 {
|
|
t.Errorf("expected eval duration > 0, got 0")
|
|
}
|
|
|
|
if actual.LoadDuration == 0 {
|
|
t.Errorf("expected load duration > 0, got 0")
|
|
}
|
|
|
|
if actual.TotalDuration == 0 {
|
|
t.Errorf("expected total duration > 0, got 0")
|
|
}
|
|
}
|
|
|
|
mock.CompletionResponse.Content = "Hi!"
|
|
t.Run("prompt", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test",
|
|
Prompt: "Hello!",
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test", "Hi!")
|
|
})
|
|
|
|
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test-system",
|
|
From: "test",
|
|
System: "You are a helpful assistant.",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("prompt with model system", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Hello!",
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
|
|
})
|
|
|
|
mock.CompletionResponse.Content = "Abra kadabra!"
|
|
t.Run("prompt with system", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Hello!",
|
|
System: "You can perform magic tricks.",
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
t.Run("prompt with template", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Help me write tests.",
|
|
System: "You can perform magic tricks.",
|
|
Template: `{{- if .System }}{{ .System }} {{ end }}
|
|
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
|
|
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
|
|
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
|
|
})
|
|
|
|
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
|
Model: "test-suffix",
|
|
Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
|
{{- else }}{{ .Prompt }}
|
|
{{- end }}`,
|
|
From: "test",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
t.Run("prompt with suffix", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-suffix",
|
|
Prompt: "def add(",
|
|
Suffix: " return c",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("prompt without suffix", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-suffix",
|
|
Prompt: "def add(",
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
|
|
t.Run("raw", func(t *testing.T) {
|
|
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
|
Model: "test-system",
|
|
Prompt: "Help me write tests.",
|
|
Raw: true,
|
|
Stream: &stream,
|
|
})
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", w.Code)
|
|
}
|
|
|
|
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
|
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
|
}
|
|
})
|
|
}
|