package openai import ( "encoding/base64" "testing" "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" ) const ( prefix = `data:image/jpeg;base64,` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) func TestFromChatRequest_Basic(t *testing.T) { req := ChatCompletionRequest{ Model: "test-model", Messages: []Message{ {Role: "user", Content: "Hello"}, }, } result, err := FromChatRequest(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if result.Model != "test-model" { t.Errorf("expected model 'test-model', got %q", result.Model) } if len(result.Messages) != 1 { t.Fatalf("expected 1 message, got %d", len(result.Messages)) } if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" { t.Errorf("unexpected message: %+v", result.Messages[0]) } } func TestFromChatRequest_WithImage(t *testing.T) { imgData, _ := base64.StdEncoding.DecodeString(image) req := ChatCompletionRequest{ Model: "test-model", Messages: []Message{ { Role: "user", Content: []any{ map[string]any{"type": "text", "text": "Hello"}, map[string]any{ "type": "image_url", "image_url": map[string]any{"url": prefix + image}, }, }, }, }, } result, err := FromChatRequest(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if len(result.Messages) != 2 { t.Fatalf("expected 2 messages, got %d", len(result.Messages)) } if result.Messages[0].Content != "Hello" { t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content) } if len(result.Messages[1].Images) != 1 { t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images)) } if string(result.Messages[1].Images[0]) != string(imgData) { t.Error("image data mismatch") } } func TestFromCompleteRequest_Basic(t *testing.T) { temp := float32(0.8) req := CompletionRequest{ Model: "test-model", Prompt: "Hello", Temperature: &temp, } result, err := FromCompleteRequest(req) if err != nil { t.Fatalf("unexpected error: %v", err) } if result.Model != "test-model" { t.Errorf("expected model 'test-model', got %q", result.Model) } if result.Prompt != "Hello" { t.Errorf("expected prompt 'Hello', got %q", result.Prompt) } if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 { t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"]) } } func TestToUsage(t *testing.T) { resp := api.ChatResponse{ Metrics: api.Metrics{ PromptEvalCount: 10, EvalCount: 20, }, } usage := ToUsage(resp) if usage.PromptTokens != 10 { t.Errorf("expected PromptTokens 10, got %d", usage.PromptTokens) } if usage.CompletionTokens != 20 { t.Errorf("expected CompletionTokens 20, got %d", usage.CompletionTokens) } if usage.TotalTokens != 30 { t.Errorf("expected TotalTokens 30, got %d", usage.TotalTokens) } } func TestNewError(t *testing.T) { tests := []struct { code int want string }{ {400, "invalid_request_error"}, {404, "not_found_error"}, {500, "api_error"}, } for _, tt := range tests { result := NewError(tt.code, "test message") if result.Error.Type != tt.want { t.Errorf("NewError(%d) type = %q, want %q", tt.code, result.Error.Type, tt.want) } if result.Error.Message != "test message" { t.Errorf("NewError(%d) message = %q, want %q", tt.code, result.Error.Message, "test message") } } } func TestToToolCallsPreservesIDs(t *testing.T) { original := []api.ToolCall{ { ID: "call_abc123", Function: api.ToolCallFunction{ Index: 2, Name: "get_weather", Arguments: api.ToolCallFunctionArguments{ "location": "Seattle", }, }, }, { ID: "call_def456", Function: api.ToolCallFunction{ Index: 7, Name: "get_time", Arguments: api.ToolCallFunctionArguments{ "timezone": "UTC", }, }, }, } toolCalls := make([]api.ToolCall, len(original)) copy(toolCalls, original) got := ToToolCalls(toolCalls) if len(got) != len(original) { t.Fatalf("expected %d tool calls, got %d", len(original), len(got)) } expected := []ToolCall{ { ID: "call_abc123", Type: "function", Index: 2, Function: struct { Name string `json:"name"` Arguments string `json:"arguments"` }{ Name: "get_weather", Arguments: `{"location":"Seattle"}`, }, }, { ID: "call_def456", Type: "function", Index: 7, Function: struct { Name string `json:"name"` Arguments string `json:"arguments"` }{ Name: "get_time", Arguments: `{"timezone":"UTC"}`, }, }, } if diff := cmp.Diff(expected, got); diff != "" { t.Errorf("tool calls mismatch (-want +got):\n%s", diff) } if diff := cmp.Diff(original, toolCalls); diff != "" { t.Errorf("input tool calls mutated (-want +got):\n%s", diff) } }