diff --git a/api/types.go b/api/types.go
index d3f6fc5a47..a7ddbc373e 100644
--- a/api/types.go
+++ b/api/types.go
@@ -388,8 +388,12 @@ type EmbedRequest struct {
 	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
+	// Truncate truncates the input to fit the model's max sequence length.
 	Truncate *bool `json:"truncate,omitempty"`
 
+	// Dimensions truncates the output embedding to the specified dimension.
+	Dimensions int `json:"dimensions,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]any `json:"options"`
 }
diff --git a/docs/api.md b/docs/api.md
index f11d59ed1c..f47af63c62 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -1708,6 +1708,7 @@ Advanced parameters:
 - `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
 - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
+- `dimensions`: number of dimensions for the embedding
 
 ### Examples
 
diff --git a/openai/openai.go b/openai/openai.go
index 9c7c41cb42..b6a8a95e23 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -76,8 +76,9 @@ type JsonSchema struct {
 }
 
 type EmbedRequest struct {
-	Input any    `json:"input"`
-	Model string `json:"model"`
+	Input      any    `json:"input"`
+	Model      string `json:"model"`
+	Dimensions int    `json:"dimensions,omitempty"`
 }
 
 type StreamOptions struct {
@@ -1005,7 +1006,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
 		}
 
 		var b bytes.Buffer
-		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
+		if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
 			c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
 			return
 		}
diff --git a/server/routes.go b/server/routes.go
index ac4df4a469..8dd1b217ae 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -558,7 +558,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			if err != nil {
 				return err
 			}
-			embeddings[i] = normalize(embedding)
+			// TODO: this first normalization should be done by the model
+			embedding = normalize(embedding)
+			if req.Dimensions > 0 && req.Dimensions < len(embedding) {
+				embedding = normalize(embedding[:req.Dimensions])
+			}
+			embeddings[i] = embedding
 			return nil
 		})
 	}
@@ -584,11 +589,7 @@ func normalize(vec []float32) []float32 {
 		sum += v * v
 	}
 
-	norm := float32(0.0)
-	if sum > 0 {
-		norm = float32(1.0 / math.Sqrt(float64(sum)))
-	}
-
+	norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12))
 	for i := range vec {
 		vec[i] *= norm
 	}
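
For context, a minimal usage sketch (not part of the patch) of the new `Dimensions` field through the Go API client. It assumes the existing `api.ClientFromEnvironment` and `Client.Embed` helpers; the model name is only a placeholder.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model:      "all-minilm",            // placeholder model name
		Input:      "why is the sky blue?",
		Dimensions: 256, // truncate (and renormalize) each embedding to 256 dims
	})
	if err != nil {
		log.Fatal(err)
	}

	// Expect min(256, the model's native embedding size) dimensions per embedding.
	fmt.Println(len(resp.Embeddings[0]))
}
```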