diff --git a/server/routes.go b/server/routes.go index d2272992aa..818c2e21bf 100644 --- a/server/routes.go +++ b/server/routes.go @@ -697,8 +697,10 @@ func (s *Server) EmbedHandler(c *gin.Context) { return } + ctx := c.Request.Context() + embedWithRetry := func(text string) ([]float32, int, error) { - emb, tokCount, err := r.Embedding(c.Request.Context(), text, false) + emb, tokCount, err := r.Embedding(ctx, text, false) if err == nil { return emb, tokCount, nil } @@ -707,20 +709,16 @@ func (s *Server) EmbedHandler(c *gin.Context) { if !errors.As(err, &serr) || serr.StatusCode != http.StatusBadRequest { return nil, 0, err } - - // If the caller disabled truncation, bubble the original error as-is if req.Truncate != nil && !*req.Truncate { return nil, 0, err } - // Determine if the error is due to input exceeding context length by checking tokenized length against effective context length - tokens, tokErr := r.Tokenize(c.Request.Context(), text) - if tokErr != nil { - return nil, 0, tokErr + tokens, err := r.Tokenize(ctx, text) + if err != nil { + return nil, 0, err } ctxLen := min(opts.NumCtx, int(kvData.ContextLength())) - if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); len(tokens) > 0 && tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) { ctxLen-- } @@ -731,18 +729,16 @@ func (s *Server) EmbedHandler(c *gin.Context) { if len(tokens) <= ctxLen { return nil, 0, err } - if ctxLen <= 0 { return nil, 0, fmt.Errorf("input after truncation exceeds maximum context length") } - tokens = tokens[:ctxLen] - - truncated, detErr := r.Detokenize(c.Request.Context(), tokens) - if detErr != nil { - return nil, 0, detErr + truncatedTokens := tokens[:ctxLen] + truncated, err := r.Detokenize(ctx, truncatedTokens) + if err != nil { + return nil, 0, err } - return r.Embedding(c.Request.Context(), truncated, true) + return r.Embedding(ctx, truncated, true) } var g errgroup.Group