server: skip parsing initial <think> if provided in the prompt (#12024)
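Context for the change: some chat templates (including the test template added below) append the opening <think> tag to the rendered prompt, so the model's completion never repeats that tag. The handler therefore checks whether the trimmed prompt ends with the opening tag and, if so, pre-feeds the tag to the thinking parser so everything up to the closing tag is parsed as thinking. The following is a toy, self-contained sketch of that behavior, not the actual thinking.Parser; the whitespace handling simply mirrors the expectations in the new test:

    package main

    import (
        "fmt"
        "strings"
    )

    // splitPreFedThinking is a toy stand-in for the thinking parser after it has been
    // pre-fed the opening tag: everything before the closing tag counts as thinking,
    // everything after it as content. Leading whitespace is dropped on both sides,
    // matching the test expectations below.
    func splitPreFedThinking(raw, closingTag string) (thinking, content string) {
        before, after, found := strings.Cut(raw, closingTag)
        if !found {
            // Toy behavior only: with no closing tag, treat everything as thinking.
            return strings.TrimLeft(raw, " \t\r\n"), ""
        }
        return strings.TrimLeft(before, " \t\r\n"), strings.TrimLeft(after, " \t\r\n")
    }

    func main() {
        // The rendered prompt ends with the opening tag, so the server would
        // pre-feed it via thinkingState.AddContent(openingTag).
        prompt := "user: Help me solve this problem\n<think>"
        fmt.Println(strings.HasSuffix(strings.TrimSpace(prompt), "<think>")) // true

        raw := " Let me think about this step by step... </think> The answer is 42."
        thinking, content := splitPreFedThinking(raw, "</think>")
        fmt.Printf("thinking: %q\n", thinking) // "Let me think about this step by step... "
        fmt.Printf("content:  %q\n", content)  // "The answer is 42."
    }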
@@ -1673,6 +1673,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
            OpeningTag: openingTag,
            ClosingTag: closingTag,
        }

        if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
            thinkingState.AddContent(openingTag)
        }
    }

    var toolParser *tools.Parser

@@ -969,3 +969,233 @@ func TestGenerate(t *testing.T) {
        }
    })
}

func TestChatWithPromptEndingInThinkTag(t *testing.T) {
    gin.SetMode(gin.TestMode)

    // Helper to create a standard thinking test setup
    setupThinkingTest := func(t *testing.T) (*mockRunner, *Server) {
        mock := &mockRunner{
            CompletionResponse: llm.CompletionResponse{
                Done:               true,
                DoneReason:         llm.DoneReasonStop,
                PromptEvalCount:    1,
                PromptEvalDuration: 1,
                EvalCount:          1,
                EvalDuration:       1,
            },
        }

        s := &Server{
            sched: &Scheduler{
                pendingReqCh:  make(chan *LlmRequest, 1),
                finishedReqCh: make(chan *LlmRequest, 1),
                expiredCh:     make(chan *runnerRef, 1),
                unloadedCh:    make(chan any, 1),
                loaded:        make(map[string]*runnerRef),
                newServerFn:   newMockServer(mock),
                getGpuFn:      discover.GetGPUInfo,
                getCpuFn:      discover.GetCPUInfo,
                reschedDelay:  250 * time.Millisecond,
                loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
                    time.Sleep(time.Millisecond)
                    req.successCh <- &runnerRef{llama: mock}
                    return false
                },
            },
        }

        go s.sched.Run(t.Context())

        // Create a model with thinking support
        _, digest := createBinFile(t, ggml.KV{
            "general.architecture":          "llama",
            "llama.block_count":             uint32(1),
            "llama.context_length":          uint32(8192),
            "llama.embedding_length":        uint32(4096),
            "llama.attention.head_count":    uint32(32),
            "llama.attention.head_count_kv": uint32(8),
            "tokenizer.ggml.tokens":         []string{""},
            "tokenizer.ggml.scores":         []float32{0},
            "tokenizer.ggml.token_type":     []int32{0},
        }, []*ggml.Tensor{
            {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
            {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
        })

        // Create model with thinking template that adds <think> at the end
        w := createRequest(t, s.CreateHandler, api.CreateRequest{
            Model: "test-thinking",
            Files: map[string]string{"file.gguf": digest},
            Template: `{{- range .Messages }}
{{- if eq .Role "user" }}user: {{ .Content }}
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
{{ end }}{{ end }}<think>`,
            Stream: &stream,
        })
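        // Clarifying note (added, approximate rendering): with this template a
        // single user turn renders as something like
        //
        //   user: <content>
        //   <think>
        //
        // so the prompt handed to the runner ends with the opening tag, which is
        // what the streaming subtest below asserts via strings.HasSuffix.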

        if w.Code != http.StatusOK {
            t.Fatalf("expected status 200, got %d", w.Code)
        }

        return mock, s
    }

    mock, s := setupThinkingTest(t)

    // Helper to test chat responses
    testChatRequest := func(t *testing.T, name string, userContent string, modelResponse string, expectedThinking string, expectedContent string, think bool) {
        t.Run(name, func(t *testing.T) {
            mock.CompletionResponse = llm.CompletionResponse{
                Content:            modelResponse,
                Done:               true,
                DoneReason:         llm.DoneReasonStop,
                PromptEvalCount:    1,
                PromptEvalDuration: 1,
                EvalCount:          1,
                EvalDuration:       1,
            }
            mock.CompletionFn = nil

            streamRequest := false
            req := api.ChatRequest{
                Model: "test-thinking",
                Messages: []api.Message{
                    {Role: "user", Content: userContent},
                },
                Stream: &streamRequest,
            }
            if think {
                req.Think = &api.ThinkValue{Value: think}
            }

            w := createRequest(t, s.ChatHandler, req)
            if w.Code != http.StatusOK {
                t.Fatalf("expected status 200, got %d", w.Code)
            }

            var resp api.ChatResponse
            if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
                t.Fatal(err)
            }

            if resp.Message.Thinking != expectedThinking {
                t.Errorf("expected thinking %q, got %q", expectedThinking, resp.Message.Thinking)
            }

            if resp.Message.Content != expectedContent {
                t.Errorf("expected content %q, got %q", expectedContent, resp.Message.Content)
            }
        })
    }

    // Test cases - Note: Template adds <think> at the end, and leading whitespace after <think> is eaten by the parser
    testChatRequest(t, "basic thinking response",
        "Help me solve this problem",
        " Let me think about this step by step... </think> The answer is 42.",
        "Let me think about this step by step... ",
        "The answer is 42.",
        true)

    testChatRequest(t, "thinking with multiple sentences",
        "Explain quantum computing",
        " First, I need to understand the basics. Quantum bits can be in superposition. </think> Quantum computing uses quantum mechanics principles.",
        "First, I need to understand the basics. Quantum bits can be in superposition. ",
        "Quantum computing uses quantum mechanics principles.",
        true)

    testChatRequest(t, "no thinking content",
        "What is 2+2?",
        "</think> The answer is 4.",
        "",
        "The answer is 4.",
        true)

    testChatRequest(t, "thinking disabled but template still adds think tag",
        "Simple question",
        " My thoughts </think> The answer.",
        "",
        " My thoughts </think> The answer.",
        false)

    // Test streaming response with template-added <think>
    t.Run("streaming with thinking", func(t *testing.T) {
        var wg sync.WaitGroup
        wg.Add(1)

        mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
            defer wg.Done()

            // Verify the prompt ends with <think> due to template
            if !strings.HasSuffix(r.Prompt, "<think>") {
                t.Errorf("expected prompt to end with <think>, got: %q", r.Prompt)
            }

            // Simulate streaming chunks
            responses := []llm.CompletionResponse{
                {Content: " I need to consider", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
                {Content: " multiple factors here...", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
                {Content: " </think> Based on my analysis,", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
                {Content: " the solution is straightforward.", Done: true, DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, EvalDuration: 1},
            }

            for _, resp := range responses {
                select {
                case <-ctx.Done():
                    return ctx.Err()
                default:
                    fn(resp)
                    time.Sleep(10 * time.Millisecond)
                }
            }
            return nil
        }

        think := true
        w := createRequest(t, s.ChatHandler, api.ChatRequest{
            Model:    "test-thinking",
            Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
            Think:    &api.ThinkValue{Value: think},
            Stream:   &stream,
        })

        wg.Wait()

        if w.Code != http.StatusOK {
            t.Fatalf("expected status 200, got %d", w.Code)
        }

        // Parse streaming responses
        decoder := json.NewDecoder(w.Body)
        var allThinking, allContent strings.Builder

        for {
            var resp api.ChatResponse
            if err := decoder.Decode(&resp); err == io.EOF {
                break
            } else if err != nil {
                t.Fatal(err)
            }
            allThinking.WriteString(resp.Message.Thinking)
            allContent.WriteString(resp.Message.Content)
        }

        // Note: Leading whitespace after <think> is eaten by the parser
        if got := allThinking.String(); got != "I need to consider multiple factors here... " {
            t.Errorf("expected thinking %q, got %q", "I need to consider multiple factors here... ", got)
        }

        if got := allContent.String(); got != "Based on my analysis, the solution is straightforward." {
            t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
        }
    })
}
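For reference, the template registered by the test can be rendered outside the server to confirm that the prompt it produces ends with the opening tag. A minimal standalone sketch; the Message struct here is a stand-in for the fields the template reads, not the real api.Message type:

    package main

    import (
        "fmt"
        "strings"
        "text/template"
    )

    // Message carries only the fields the test template references.
    type Message struct {
        Role     string
        Content  string
        Thinking string
    }

    func main() {
        // Same template string the test passes to api.CreateRequest above.
        const tmpl = `{{- range .Messages }}
{{- if eq .Role "user" }}user: {{ .Content }}
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
{{ end }}{{ end }}<think>`

        t := template.Must(template.New("chat").Parse(tmpl))

        var sb strings.Builder
        data := map[string]any{
            "Messages": []Message{{Role: "user", Content: "Help me solve this problem"}},
        }
        if err := t.Execute(&sb, data); err != nil {
            panic(err)
        }

        prompt := sb.String()
        fmt.Printf("%q\n", prompt)
        // Expected output: "user: Help me solve this problem\n<think>"
        fmt.Println(strings.HasSuffix(strings.TrimSpace(prompt), "<think>")) // true
    }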