simplified logic

This commit is contained in:
nicole pardal
2025-11-05 11:06:17 -08:00
parent 61b87b1911
commit 010d718835

View File

@@ -697,47 +697,23 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
isTooLong := func(err error) bool {
var serr api.StatusError
if !errors.As(err, &serr) {
return false
}
if serr.StatusCode != http.StatusBadRequest {
return false
}
msg := strings.TrimSpace(serr.ErrorMessage)
if msg == "embedding_input_too_long" {
return true
}
if strings.HasPrefix(msg, "{") {
var m map[string]any
if json.Unmarshal([]byte(msg), &m) == nil {
if v, ok := m["error"].(string); ok && v == "embedding_input_too_long" {
return true
}
}
}
return strings.Contains(msg, "embedding input length exceeds the context length") ||
strings.Contains(msg, "the embedding input length exceeds the context length") ||
strings.Contains(msg, "input length exceeds the context length") ||
strings.Contains(msg, "exceeds maximum context length")
}
embedWithRetry := func(text string) ([]float32, int, error) { embedWithRetry := func(text string) ([]float32, int, error) {
emb, tokCount, err := r.Embedding(c.Request.Context(), text, false) emb, tokCount, err := r.Embedding(c.Request.Context(), text, false)
if err == nil { if err == nil {
return emb, tokCount, nil return emb, tokCount, nil
} }
if !isTooLong(err) {
var serr api.StatusError
if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest {
return nil, 0, err return nil, 0, err
} }
// If the caller disabled truncation, bubble the original error as-is
if req.Truncate != nil && !*req.Truncate { if req.Truncate != nil && !*req.Truncate {
return nil, 0, err return nil, 0, err
} }
// Determine if the error is due to input exceeding context length by checking tokenized length against effective context length
tokens, tokErr := r.Tokenize(c.Request.Context(), text) tokens, tokErr := r.Tokenize(c.Request.Context(), text)
if tokErr != nil { if tokErr != nil {
return nil, 0, tokErr return nil, 0, tokErr
@@ -752,13 +728,15 @@ func (s *Server) EmbedHandler(c *gin.Context) {
ctxLen-- ctxLen--
} }
if len(tokens) <= ctxLen {
return nil, 0, err
}
if ctxLen <= 0 { if ctxLen <= 0 {
return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length") return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length")
} }
if len(tokens) > ctxLen { tokens = tokens[:ctxLen]
tokens = tokens[:ctxLen]
}
truncated, detErr := r.Detokenize(c.Request.Context(), tokens) truncated, detErr := r.Detokenize(c.Request.Context(), tokens)
if detErr != nil { if detErr != nil {