Mirror of https://github.com/ollama/ollama.git (synced 2025-08-03 18:12:32 +02:00)
trim chat prompt based on llm context size (#1963)
server/routes.go (107 lines changed)
@@ -1121,11 +1121,16 @@ func ChatHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()

-	prompt, images, err := model.ChatPrompt(req.Messages)
+	chat, err := model.ChatPrompts(req.Messages)
 	if err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
+	prompt, err := trimmedPrompt(c.Request.Context(), chat, model)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
+	}

 	slog.Debug(fmt.Sprintf("prompt: %s", prompt))
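Annotation: the handler now gets a ChatHistory back from ChatPrompts and defers prompt assembly to the new trimmedPrompt helper, which trims older turns to the model's context window. The type shapes below are a minimal sketch inferred purely from how this diff uses them; the real definitions live elsewhere in the ollama server package, and the exact field set (in particular Prompt and the element type of CurrentImages) is an assumption.

// Sketch only: inferred from usage in this diff, not the authoritative ollama definitions.
type PromptVars struct {
	System string // system message for this turn; read and written by the trimming code
	Prompt string // user message; field name assumed, not shown in this diff
	First  bool   // set on the first (oldest) prompt in the assembled string
}

type ChatHistory struct {
	Prompts       []PromptVars    // one entry per templated conversation turn
	LastSystem    string          // most recent system message in the conversation
	CurrentImages []api.ImageData // images for the latest message; element type assumed
}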
@@ -1164,7 +1169,7 @@ func ChatHandler(c *gin.Context) {
 	predictReq := llm.PredictOpts{
 		Prompt:  prompt,
 		Format:  req.Format,
-		Images:  images,
+		Images:  chat.CurrentImages,
 		Options: opts,
 	}
 	if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -1202,3 +1207,101 @@ func ChatHandler(c *gin.Context) {
 	streamResponse(c, ch)
 }
+
+// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
+type promptInfo struct {
+	vars     PromptVars
+	tokenLen int
+}
+
+// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
+// while preserving the most recent system message.
+func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, error) {
+	if len(chat.Prompts) == 0 {
+		return "", nil
+	}
+
+	var promptsToAdd []promptInfo
+	var totalTokenLength int
+	var systemPromptIncluded bool
+
+	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
+	for i := len(chat.Prompts) - 1; i >= 0; i-- {
+		promptText, err := promptString(model, chat.Prompts[i], i == len(chat.Prompts)-1)
+		if err != nil {
+			return "", err
+		}
+
+		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
+		if err != nil {
+			return "", err
+		}
+
+		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
+			break // reached max context length, stop adding more prompts
+		}
+
+		totalTokenLength += len(encodedTokens)
+		systemPromptIncluded = systemPromptIncluded || chat.Prompts[i].System != ""
+		promptsToAdd = append(promptsToAdd, promptInfo{vars: chat.Prompts[i], tokenLen: len(encodedTokens)})
+	}
+
+	// ensure the system prompt is included, if not already
+	if chat.LastSystem != "" && !systemPromptIncluded {
+		var err error
+		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
+		if err != nil {
+			return "", err
+		}
+	}
+
+	promptsToAdd[len(promptsToAdd)-1].vars.First = true
+
+	// construct the final prompt string from the prompts which fit within the context window
+	var result string
+	for i, prompt := range promptsToAdd {
+		promptText, err := promptString(model, prompt.vars, i == 0)
+		if err != nil {
+			return "", err
+		}
+		result = promptText + result
+	}
+	return result, nil
+}
+
+// promptString applies the model template to the prompt
+func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
+	if isMostRecent {
+		p, err := model.PreResponsePrompt(vars)
+		if err != nil {
+			return "", fmt.Errorf("pre-response template: %w", err)
+		}
+		return p, nil
+	}
+	p, err := Prompt(model.Template, vars)
+	if err != nil {
+		return "", err
+	}
+	return p, nil
+}
+
+// includeSystemPrompt adjusts the prompts to include the system prompt.
+func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
+	systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
+	if err != nil {
+		return nil, err
+	}
+
+	for i := len(promptsToAdd) - 1; i >= 0; i-- {
+		if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
+			promptsToAdd[i].vars.System = systemPrompt
+			return promptsToAdd[:i+1], nil
+		}
+		totalTokenLength -= promptsToAdd[i].tokenLen
+	}
+
+	// if we get here, the system prompt did not fit anywhere, so return the most recent prompt with the system message set
+	recent := promptsToAdd[len(promptsToAdd)-1]
+	recent.vars.System = systemPrompt
+	return []promptInfo{recent}, nil
+}
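To see the trimming strategy in isolation: the sketch below is not ollama code. It reimplements the reverse-iteration idea with stand-ins, countTokens (a crude whitespace tokenizer) playing the role of loaded.runner.Encode, and numCtx the role of loaded.NumCtx. As in the diff, the newest turn is always kept even when it alone exceeds the budget, mirroring the i != len(chat.Prompts)-1 guard.

package main

import (
	"fmt"
	"strings"
)

// turn is a stand-in for PromptVars: one conversation turn plus its system message.
type turn struct {
	system string
	text   string
}

// countTokens is a stand-in for loaded.runner.Encode: a crude whitespace tokenizer.
func countTokens(s string) int { return len(strings.Fields(s)) }

// trim keeps the newest turns whose combined token count fits numCtx,
// mirroring the reverse iteration in trimmedPrompt. The newest turn is
// always kept, even if it alone exceeds the budget.
func trim(turns []turn, numCtx int) []turn {
	var kept []turn
	total := 0
	for i := len(turns) - 1; i >= 0; i-- {
		n := countTokens(turns[i].text)
		if total+n > numCtx && i != len(turns)-1 {
			break // budget exhausted; this turn and everything older is dropped
		}
		total += n
		kept = append([]turn{turns[i]}, kept...) // prepend to restore chronological order
	}
	return kept
}

func main() {
	history := []turn{
		{system: "You are terse.", text: "first question about something old"},
		{text: "second question with a few more words in it"},
		{text: "newest question"},
	}
	for _, t := range trim(history, 10) {
		fmt.Printf("kept: %q\n", t.text)
	}
}

Running this with numCtx = 10 keeps only "newest question": the middle turn's 9 tokens would push the total to 11, so it and everything older are dropped.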
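The system-message fallback can be pictured with the same stand-ins. includeSystemPrompt walks the kept prompts from oldest to newest, dropping turns until the system message fits; the hedged sketch below implements that documented intent (when nothing else fits, only the newest turn survives, with the system message attached) over the chronologically ordered slice returned by trim above.

// attachSystem mirrors the intent of includeSystemPrompt under the stand-in
// types above: attach system to the oldest kept turn such that everything
// still fits numCtx, dropping older turns as needed. total is the token
// count of all kept turns; kept is assumed non-empty (trimmedPrompt returns
// early on an empty history).
func attachSystem(kept []turn, system string, total, numCtx int) []turn {
	sysTokens := countTokens(system)
	for i := 0; i < len(kept); i++ { // kept is oldest-first
		if total+sysTokens <= numCtx {
			kept[i].system = system
			return kept[i:]
		}
		total -= countTokens(kept[i].text) // drop the oldest remaining turn and retry
	}
	// nothing fit alongside the system message: keep only the newest turn
	newest := kept[len(kept)-1]
	newest.system = system
	return []turn{newest}
}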