cleaned up

This commit is contained in:
nicole pardal
2025-11-05 11:21:31 -08:00
parent 010d718835
commit 1f54127f9a

View File

@@ -697,8 +697,10 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return return
} }
ctx := c.Request.Context()
embedWithRetry := func(text string) ([]float32, int, error) { embedWithRetry := func(text string) ([]float32, int, error) {
emb, tokCount, err := r.Embedding(c.Request.Context(), text, false) emb, tokCount, err := r.Embedding(ctx, text, false)
if err == nil { if err == nil {
return emb, tokCount, nil return emb, tokCount, nil
} }
@@ -707,20 +709,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest { if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest {
return nil, 0, err return nil, 0, err
} }
// If the caller disabled truncation, bubble the original error as-is
if req.Truncate != nil && !*req.Truncate { if req.Truncate != nil && !*req.Truncate {
return nil, 0, err return nil, 0, err
} }
// Determine if the error is due to input exceeding context length by checking tokenized length against effective context length tokens, err := r.Tokenize(ctx, text)
tokens, tokErr := r.Tokenize(c.Request.Context(), text) if err != nil {
if tokErr != nil { return nil, 0, err
return nil, 0, tokErr
} }
ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen-- ctxLen--
} }
@@ -731,18 +729,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
if len(tokens) <= ctxLen { if len(tokens) <= ctxLen {
return nil, 0, err return nil, 0, err
} }
if ctxLen <= 0 { if ctxLen <= 0 {
return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length") return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length")
} }
tokens = tokens[:ctxLen] truncatedTokens := tokens[:ctxLen]
truncated, err := r.Detokenize(ctx, truncatedTokens)
truncated, detErr := r.Detokenize(c.Request.Context(), tokens) if err != nil {
if detErr != nil { return nil, 0, err
return nil, 0, detErr
} }
return r.Embedding(c.Request.Context(), truncated, true) return r.Embedding(ctx, truncated, true)
} }
var g errgroup.Group var g errgroup.Group