From e0ead1adee0a36f8aecf0df9747996354ee1ed8c Mon Sep 17 00:00:00 2001 From: nicole pardal <109545900+npardal@users.noreply.github.com> Date: Wed, 22 Oct 2025 11:27:44 -0700 Subject: [PATCH] embeddings: base64 encoding fix (#12715) --- middleware/openai.go | 19 +- middleware/openai_encoding_format_test.go | 220 ++++++++++++++++++++++ openai/openai.go | 34 +++- openai/openai_encoding_format_test.go | 125 ++++++++++++ 4 files changed, 386 insertions(+), 12 deletions(-) create mode 100644 middleware/openai_encoding_format_test.go create mode 100644 openai/openai_encoding_format_test.go diff --git a/middleware/openai.go b/middleware/openai.go index 826a2111bd..b2e43f165c 100644 --- a/middleware/openai.go +++ b/middleware/openai.go @@ -7,6 +7,7 @@ import ( "io" "math/rand" "net/http" + "strings" "github.com/gin-gonic/gin" @@ -44,7 +45,8 @@ type RetrieveWriter struct { type EmbedWriter struct { BaseWriter - model string + model string + encodingFormat string } func (w *BaseWriter) writeError(data []byte) (int, error) { @@ -254,7 +256,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) { } w.ResponseWriter.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse)) + err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse, w.encodingFormat)) if err != nil { return 0, err } @@ -348,6 +350,14 @@ func EmbeddingsMiddleware() gin.HandlerFunc { return } + // Validate encoding_format parameter + if req.EncodingFormat != "" { + if !strings.EqualFold(req.EncodingFormat, "float") && !strings.EqualFold(req.EncodingFormat, "base64") { + c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, fmt.Sprintf("Invalid value for 'encoding_format' = %s. Supported values: ['float', 'base64'].", req.EncodingFormat))) + return + } + } + if req.Input == "" { req.Input = []string{""} } @@ -371,8 +381,9 @@ func EmbeddingsMiddleware() gin.HandlerFunc { c.Request.Body = io.NopCloser(&b) w := &EmbedWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - model: req.Model, + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + model: req.Model, + encodingFormat: req.EncodingFormat, } c.Writer = w diff --git a/middleware/openai_encoding_format_test.go b/middleware/openai_encoding_format_test.go new file mode 100644 index 0000000000..52107d6ef1 --- /dev/null +++ b/middleware/openai_encoding_format_test.go @@ -0,0 +1,220 @@ +package middleware + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/openai" +) + +func TestEmbeddingsMiddleware_EncodingFormats(t *testing.T) { + testCases := []struct { + name string + encodingFormat string + expectType string // "array" or "string" + verifyBase64 bool + }{ + {"float format", "float", "array", false}, + {"base64 format", "base64", "string", true}, + {"default format", "", "array", false}, + } + + gin.SetMode(gin.TestMode) + + endpoint := func(c *gin.Context) { + resp := api.EmbedResponse{ + Embeddings: [][]float32{{0.1, -0.2, 0.3}}, + PromptEvalCount: 5, + } + c.JSON(http.StatusOK, resp) + } + + router := gin.New() + router.Use(EmbeddingsMiddleware()) + router.Handle(http.MethodPost, "/api/embed", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + body := `{"input": "test", "model": "test-model"` + if tc.encodingFormat != "" { + body += `, "encoding_format": "` + tc.encodingFormat + `"` + } + body += `}` + + req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.Code) + } + + var result openai.EmbeddingList + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if len(result.Data) != 1 { + t.Fatalf("expected 1 embedding, got %d", len(result.Data)) + } + + switch tc.expectType { + case "array": + if _, ok := result.Data[0].Embedding.([]interface{}); !ok { + t.Errorf("expected array, got %T", result.Data[0].Embedding) + } + case "string": + embStr, ok := result.Data[0].Embedding.(string) + if !ok { + t.Errorf("expected string, got %T", result.Data[0].Embedding) + } else if tc.verifyBase64 { + decoded, err := base64.StdEncoding.DecodeString(embStr) + if err != nil { + t.Errorf("invalid base64: %v", err) + } else if len(decoded) != 12 { + t.Errorf("expected 12 bytes, got %d", len(decoded)) + } + } + } + }) + } +} + +func TestEmbeddingsMiddleware_BatchWithBase64(t *testing.T) { + gin.SetMode(gin.TestMode) + + endpoint := func(c *gin.Context) { + resp := api.EmbedResponse{ + Embeddings: [][]float32{ + {0.1, 0.2}, + {0.3, 0.4}, + {0.5, 0.6}, + }, + PromptEvalCount: 10, + } + c.JSON(http.StatusOK, resp) + } + + router := gin.New() + router.Use(EmbeddingsMiddleware()) + router.Handle(http.MethodPost, "/api/embed", endpoint) + + body := `{ + "input": ["hello", "world", "test"], + "model": "test-model", + "encoding_format": "base64" + }` + + req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.Code) + } + + var result openai.EmbeddingList + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if len(result.Data) != 3 { + t.Fatalf("expected 3 embeddings, got %d", len(result.Data)) + } + + // All should be base64 strings + for i := range 3 { + embeddingStr, ok := result.Data[i].Embedding.(string) + if !ok { + t.Errorf("embedding %d: expected string, got %T", i, result.Data[i].Embedding) + continue + } + + // Verify it's valid base64 + if _, err := base64.StdEncoding.DecodeString(embeddingStr); err != nil { + t.Errorf("embedding %d: invalid base64: %v", i, err) + } + + // Check index + if result.Data[i].Index != i { + t.Errorf("embedding %d: expected index %d, got %d", i, i, result.Data[i].Index) + } + } +} + +func TestEmbeddingsMiddleware_InvalidEncodingFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + router := gin.New() + router.Use(EmbeddingsMiddleware()) + router.Handle(http.MethodPost, "/api/embed", endpoint) + + testCases := []struct { + name string + encodingFormat string + shouldFail bool + }{ + {"valid: float", "float", false}, + {"valid: base64", "base64", false}, + {"valid: FLOAT (uppercase)", "FLOAT", false}, + {"valid: BASE64 (uppercase)", "BASE64", false}, + {"valid: Float (mixed)", "Float", false}, + {"valid: Base64 (mixed)", "Base64", false}, + {"invalid: json", "json", true}, + {"invalid: hex", "hex", true}, + {"invalid: invalid_format", "invalid_format", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + body := `{ + "input": "test", + "model": "test-model", + "encoding_format": "` + tc.encodingFormat + `" + }` + + req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if tc.shouldFail { + if resp.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", resp.Code) + } + + var errResp openai.ErrorResponse + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatalf("failed to unmarshal error response: %v", err) + } + + if errResp.Error.Type != "invalid_request_error" { + t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type) + } + + if !strings.Contains(errResp.Error.Message, "encoding_format") { + t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message) + } + } else { + if resp.Code != http.StatusOK { + t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String()) + } + } + }) + } +} diff --git a/openai/openai.go b/openai/openai.go index 23e9522f03..650514cf19 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -2,7 +2,9 @@ package openai import ( + "bytes" "encoding/base64" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -73,9 +75,10 @@ type JsonSchema struct { } type EmbedRequest struct { - Input any `json:"input"` - Model string `json:"model"` - Dimensions int `json:"dimensions,omitempty"` + Input any `json:"input"` + Model string `json:"model"` + Dimensions int `json:"dimensions,omitempty"` + EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64" } type StreamOptions struct { @@ -181,9 +184,9 @@ type Model struct { } type Embedding struct { - Object string `json:"object"` - Embedding []float32 `json:"embedding"` - Index int `json:"index"` + Object string `json:"object"` + Embedding any `json:"embedding"` // Can be []float32 (float format) or string (base64 format) + Index int `json:"index"` } type ListCompletion struct { @@ -377,13 +380,21 @@ func ToListCompletion(r api.ListResponse) ListCompletion { } // ToEmbeddingList converts an api.EmbedResponse to EmbeddingList -func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { +// encodingFormat can be "float", "base64", or empty (defaults to "float") +func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList { if r.Embeddings != nil { var data []Embedding for i, e := range r.Embeddings { + var embedding any + if strings.EqualFold(encodingFormat, "base64") { + embedding = floatsToBase64(e) + } else { + embedding = e + } + data = append(data, Embedding{ Object: "embedding", - Embedding: e, + Embedding: embedding, Index: i, }) } @@ -402,6 +413,13 @@ func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList { return EmbeddingList{} } +// floatsToBase64 encodes a []float32 to a base64 string +func floatsToBase64(floats []float32) string { + var buf bytes.Buffer + binary.Write(&buf, binary.LittleEndian, floats) + return base64.StdEncoding.EncodeToString(buf.Bytes()) +} + // ToModel converts an api.ShowResponse to Model func ToModel(r api.ShowResponse, m string) Model { return Model{ diff --git a/openai/openai_encoding_format_test.go b/openai/openai_encoding_format_test.go new file mode 100644 index 0000000000..d8ccb1f8a6 --- /dev/null +++ b/openai/openai_encoding_format_test.go @@ -0,0 +1,125 @@ +package openai + +import ( + "encoding/base64" + "math" + "testing" + + "github.com/ollama/ollama/api" +) + +func TestToEmbeddingList(t *testing.T) { + testCases := []struct { + name string + embeddings [][]float32 + format string + expectType string // "float" or "base64" + expectCount int + promptEval int + }{ + {"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", 1, 10}, + {"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", 1, 5}, + {"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", 1, 0}, + {"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", 1, 0}, + {"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", 3, 0}, + {"empty embeddings", nil, "float", "", 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + resp := api.EmbedResponse{ + Embeddings: tc.embeddings, + PromptEvalCount: tc.promptEval, + } + + result := ToEmbeddingList("test-model", resp, tc.format) + + if tc.expectCount == 0 { + if len(result.Data) != 0 { + t.Errorf("expected 0 embeddings, got %d", len(result.Data)) + } + return + } + + if len(result.Data) != tc.expectCount { + t.Fatalf("expected %d embeddings, got %d", tc.expectCount, len(result.Data)) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) + } + + // Check type of first embedding + switch tc.expectType { + case "float": + if _, ok := result.Data[0].Embedding.([]float32); !ok { + t.Errorf("expected []float32, got %T", result.Data[0].Embedding) + } + case "base64": + embStr, ok := result.Data[0].Embedding.(string) + if !ok { + t.Errorf("expected string, got %T", result.Data[0].Embedding) + } else if _, err := base64.StdEncoding.DecodeString(embStr); err != nil { + t.Errorf("invalid base64: %v", err) + } + } + + // Check indices + for i := range result.Data { + if result.Data[i].Index != i { + t.Errorf("embedding %d: expected index %d, got %d", i, i, result.Data[i].Index) + } + } + + if tc.promptEval > 0 && result.Usage.PromptTokens != tc.promptEval { + t.Errorf("expected %d prompt tokens, got %d", tc.promptEval, result.Usage.PromptTokens) + } + }) + } +} + +func TestFloatsToBase64(t *testing.T) { + floats := []float32{0.1, -0.2, 0.3, -0.4, 0.5} + + result := floatsToBase64(floats) + + // Verify it's valid base64 + decoded, err := base64.StdEncoding.DecodeString(result) + if err != nil { + t.Fatalf("failed to decode base64: %v", err) + } + + // Check length + expectedBytes := len(floats) * 4 + if len(decoded) != expectedBytes { + t.Errorf("expected %d bytes, got %d", expectedBytes, len(decoded)) + } + + // Decode and verify values + for i, expected := range floats { + offset := i * 4 + bits := uint32(decoded[offset]) | + uint32(decoded[offset+1])<<8 | + uint32(decoded[offset+2])<<16 | + uint32(decoded[offset+3])<<24 + decodedFloat := math.Float32frombits(bits) + + if math.Abs(float64(decodedFloat-expected)) > 1e-6 { + t.Errorf("float[%d]: expected %f, got %f", i, expected, decodedFloat) + } + } +} + +func TestFloatsToBase64_EmptySlice(t *testing.T) { + result := floatsToBase64([]float32{}) + + // Should return valid base64 for empty slice + decoded, err := base64.StdEncoding.DecodeString(result) + if err != nil { + t.Fatalf("failed to decode base64: %v", err) + } + + if len(decoded) != 0 { + t.Errorf("expected 0 bytes, got %d", len(decoded)) + } +}