mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 21:37:14 +01:00
This makes the core openai compat layer independent of the middleware that adapts it to our particular gin routes
151 lines
3.4 KiB
Go
package openai
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
const (
	// prefix is the data-URI header an OpenAI-style client prepends to an
	// inline base64 image payload.
	// NOTE(review): the MIME type says image/jpeg, but the payload below is
	// a PNG (base64 "iVBORw0KGgo" is the PNG magic). Presumably the compat
	// layer only strips the prefix and never inspects the MIME type — confirm.
	prefix = `data:image/jpeg;base64,`
	// image is a base64-encoded 1x1 PNG used as fixture data for the
	// image-handling tests.
	image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
|
|
|
|
func TestFromChatRequest_Basic(t *testing.T) {
|
|
req := ChatCompletionRequest{
|
|
Model: "test-model",
|
|
Messages: []Message{
|
|
{Role: "user", Content: "Hello"},
|
|
},
|
|
}
|
|
|
|
result, err := FromChatRequest(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if result.Model != "test-model" {
|
|
t.Errorf("expected model 'test-model', got %q", result.Model)
|
|
}
|
|
|
|
if len(result.Messages) != 1 {
|
|
t.Fatalf("expected 1 message, got %d", len(result.Messages))
|
|
}
|
|
|
|
if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
|
|
t.Errorf("unexpected message: %+v", result.Messages[0])
|
|
}
|
|
}
|
|
|
|
func TestFromChatRequest_WithImage(t *testing.T) {
|
|
imgData, _ := base64.StdEncoding.DecodeString(image)
|
|
|
|
req := ChatCompletionRequest{
|
|
Model: "test-model",
|
|
Messages: []Message{
|
|
{
|
|
Role: "user",
|
|
Content: []any{
|
|
map[string]any{"type": "text", "text": "Hello"},
|
|
map[string]any{
|
|
"type": "image_url",
|
|
"image_url": map[string]any{"url": prefix + image},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
result, err := FromChatRequest(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if len(result.Messages) != 2 {
|
|
t.Fatalf("expected 2 messages, got %d", len(result.Messages))
|
|
}
|
|
|
|
if result.Messages[0].Content != "Hello" {
|
|
t.Errorf("expected first message content 'Hello', got %q", result.Messages[0].Content)
|
|
}
|
|
|
|
if len(result.Messages[1].Images) != 1 {
|
|
t.Fatalf("expected 1 image, got %d", len(result.Messages[1].Images))
|
|
}
|
|
|
|
if string(result.Messages[1].Images[0]) != string(imgData) {
|
|
t.Error("image data mismatch")
|
|
}
|
|
}
|
|
|
|
func TestFromCompleteRequest_Basic(t *testing.T) {
|
|
temp := float32(0.8)
|
|
req := CompletionRequest{
|
|
Model: "test-model",
|
|
Prompt: "Hello",
|
|
Temperature: &temp,
|
|
}
|
|
|
|
result, err := FromCompleteRequest(req)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
|
|
if result.Model != "test-model" {
|
|
t.Errorf("expected model 'test-model', got %q", result.Model)
|
|
}
|
|
|
|
if result.Prompt != "Hello" {
|
|
t.Errorf("expected prompt 'Hello', got %q", result.Prompt)
|
|
}
|
|
|
|
if tempVal, ok := result.Options["temperature"].(float32); !ok || tempVal != 0.8 {
|
|
t.Errorf("expected temperature 0.8, got %v", result.Options["temperature"])
|
|
}
|
|
}
|
|
|
|
func TestToUsage(t *testing.T) {
|
|
resp := api.ChatResponse{
|
|
Metrics: api.Metrics{
|
|
PromptEvalCount: 10,
|
|
EvalCount: 20,
|
|
},
|
|
}
|
|
|
|
usage := ToUsage(resp)
|
|
|
|
if usage.PromptTokens != 10 {
|
|
t.Errorf("expected PromptTokens 10, got %d", usage.PromptTokens)
|
|
}
|
|
|
|
if usage.CompletionTokens != 20 {
|
|
t.Errorf("expected CompletionTokens 20, got %d", usage.CompletionTokens)
|
|
}
|
|
|
|
if usage.TotalTokens != 30 {
|
|
t.Errorf("expected TotalTokens 30, got %d", usage.TotalTokens)
|
|
}
|
|
}
|
|
|
|
func TestNewError(t *testing.T) {
|
|
tests := []struct {
|
|
code int
|
|
want string
|
|
}{
|
|
{400, "invalid_request_error"},
|
|
{404, "not_found_error"},
|
|
{500, "api_error"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
result := NewError(tt.code, "test message")
|
|
if result.Error.Type != tt.want {
|
|
t.Errorf("NewError(%d) type = %q, want %q", tt.code, result.Error.Type, tt.want)
|
|
}
|
|
if result.Error.Message != "test message" {
|
|
t.Errorf("NewError(%d) message = %q, want %q", tt.code, result.Error.Message, "test message")
|
|
}
|
|
}
|
|
}
|