diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index 040c2562af..ea92a9156c 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -16,7 +16,28 @@ type Parser interface { HasThinkingSupport() bool } +type ParserConstructor func() Parser + +type ParserRegistry struct { + constructors map[string]ParserConstructor +} + +func (r *ParserRegistry) Register(name string, constructor ParserConstructor) { + r.constructors[name] = constructor +} + +var registry = ParserRegistry{ + constructors: make(map[string]ParserConstructor), +} + +func Register(name string, constructor ParserConstructor) { + registry.Register(name, constructor) +} + func ParserForName(name string) Parser { + if parser, ok := registry.constructors[name]; ok { + return parser() + } switch name { case "qwen3-coder": parser := &Qwen3CoderParser{} diff --git a/model/parsers/parsers_test.go b/model/parsers/parsers_test.go new file mode 100644 index 0000000000..8a64a23589 --- /dev/null +++ b/model/parsers/parsers_test.go @@ -0,0 +1,97 @@ +package parsers + +import ( + "testing" + + "github.com/ollama/ollama/api" +) + +type mockParser struct { + name string +} + +func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + return tools +} + +func (m *mockParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + return "mock:" + s, "", nil, nil +} + +func (m *mockParser) HasToolSupport() bool { + return false +} + +func (m *mockParser) HasThinkingSupport() bool { + return false +} + +func TestRegisterCustomParser(t *testing.T) { + // Register a custom parser + Register("custom-parser", func() Parser { + return &mockParser{name: "custom"} + }) + + // Retrieve it + parser := ParserForName("custom-parser") + if parser == nil { + t.Fatal("expected parser to be registered") + } + + // Test it works + content, _, _, err := parser.Add("test", false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "mock:test" { + t.Errorf("expected 'mock:test', got %q", content) + } +} + +func TestBuiltInParsersStillWork(t *testing.T) { + tests := []struct { + name string + }{ + {"passthrough"}, + {"qwen3-coder"}, + {"harmony"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := ParserForName(tt.name) + if parser == nil { + t.Fatalf("expected built-in parser %q to exist", tt.name) + } + }) + } +} + +func TestOverrideBuiltInParser(t *testing.T) { + // Override a built-in parser + Register("passthrough", func() Parser { + return &mockParser{name: "override"} + }) + + // Should get the override + parser := ParserForName("passthrough") + if parser == nil { + t.Fatal("expected parser to exist") + } + + // Test it's the override + content, _, _, err := parser.Add("test", false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if content != "mock:test" { + t.Errorf("expected 'mock:test' from override, got %q", content) + } +} + +func TestUnknownParserReturnsNil(t *testing.T) { + parser := ParserForName("nonexistent-parser") + if parser != nil { + t.Error("expected nil for unknown parser") + } +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 01e0c6ee2a..e97b4581db 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -1,12 +1,46 @@ package renderers -import "github.com/ollama/ollama/api" +import ( + "fmt" + + "github.com/ollama/ollama/api" +) type Renderer interface { Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) } -func RendererForName(name string) Renderer { +type ( + RendererConstructor func() Renderer + RendererRegistry struct { + renderers map[string]RendererConstructor + } +) + +func (r *RendererRegistry) Register(name string, renderer RendererConstructor) { + r.renderers[name] = renderer +} + +var registry = RendererRegistry{ + renderers: make(map[string]RendererConstructor), +} + +func Register(name string, renderer RendererConstructor) { + registry.Register(name, renderer) +} + +func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + renderer := rendererForName(name) + if renderer == nil { + return "", fmt.Errorf("unknown renderer %q", name) + } + return renderer.Render(msgs, tools, think) +} + +func rendererForName(name string) Renderer { + if constructor, ok := registry.renderers[name]; ok { + return constructor() + } switch name { case "qwen3-coder": renderer := &Qwen3CoderRenderer{} diff --git a/model/renderers/renderer_test.go b/model/renderers/renderer_test.go new file mode 100644 index 0000000000..8625634cb0 --- /dev/null +++ b/model/renderers/renderer_test.go @@ -0,0 +1,67 @@ +package renderers + +import ( + "testing" + + "github.com/ollama/ollama/api" +) + +type mockRenderer struct{} + +func (m *mockRenderer) Render(msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { + return "mock-output", nil +} + +func TestRegisterCustomRenderer(t *testing.T) { + // Register a custom renderer + Register("custom-renderer", func() Renderer { + return &mockRenderer{} + }) + + // Retrieve and use it + result, err := RenderWithRenderer("custom-renderer", nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "mock-output" { + t.Errorf("expected 'mock-output', got %q", result) + } +} + +func TestBuiltInRendererStillWorks(t *testing.T) { + // Test that qwen3-coder still works + messages := []api.Message{ + {Role: "user", Content: "Hello"}, + } + + result, err := RenderWithRenderer("qwen3-coder", messages, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result == "" { + t.Error("expected non-empty result from qwen3-coder renderer") + } +} + +func TestOverrideBuiltInRenderer(t *testing.T) { + // Override the built-in renderer + Register("qwen3-coder", func() Renderer { + return &mockRenderer{} + }) + + // Should get the override + result, err := RenderWithRenderer("qwen3-coder", nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != "mock-output" { + t.Errorf("expected 'mock-output' from override, got %q", result) + } +} + +func TestUnknownRendererReturnsError(t *testing.T) { + _, err := RenderWithRenderer("nonexistent-renderer", nil, nil, nil) + if err == nil { + t.Error("expected error for unknown renderer") + } +} diff --git a/server/prompt.go b/server/prompt.go index f1f7cec4de..2175919821 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -106,8 +106,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) { if m.Config.Renderer != "" { - renderer := renderers.RendererForName(m.Config.Renderer) - rendered, err := renderer.Render(msgs, tools, think) + rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think) if err != nil { return "", err }