diff --git a/api/types.go b/api/types.go
index 95ed5d37e..5f7ced143 100644
--- a/api/types.go
+++ b/api/types.go
@@ -94,11 +94,14 @@ type ChatRequest struct {
 	Format string `json:"format"`
 
 	// KeepAlive controls how long the model will stay loaded into memory
-	// followin the request.
+	// following the request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
+
+	// OpenAI indicates redirection from the compatibility endpoint.
+	OpenAI bool `json:"openai,omitempty"`
 }
 
 // Message is a single message in a chat sequence. The message contains the
diff --git a/llm/server.go b/llm/server.go
index ad67138b5..ad9578a08 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -656,6 +656,7 @@ type completion struct {
 	Prompt       string `json:"prompt"`
 	Stop         bool   `json:"stop"`
 	StoppedLimit bool   `json:"stopped_limit"`
+	TokensEval   int    `json:"tokens_evaluated"`
 
 	Timings struct {
 		PredictedN  int     `json:"predicted_n"`
@@ -680,6 +681,7 @@ type CompletionResponse struct {
 	PromptEvalDuration time.Duration
 	EvalCount          int
 	EvalDuration       time.Duration
+	TokensEval         int
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -828,6 +830,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
 				PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 				EvalCount:          c.Timings.PredictedN,
 				EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+				TokensEval:         c.TokensEval,
 			})
 			return nil
 		}
diff --git a/openai/openai.go b/openai/openai.go
index 706d31aa2..15f76b7cd 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -117,7 +117,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			}(r.DoneReason),
 		}},
 		Usage: Usage{
-			// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
 			PromptTokens:     r.PromptEvalCount,
 			CompletionTokens: r.EvalCount,
 			TotalTokens:      r.PromptEvalCount + r.EvalCount,
@@ -205,6 +204,7 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
 		Format:   format,
 		Options:  options,
 		Stream:   &r.Stream,
+		OpenAI:   true,
 	}
 }
 
diff --git a/server/routes.go b/server/routes.go
index ff66663c0..a6ecbc77d 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1372,6 +1372,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)
 
 		fn := func(r llm.CompletionResponse) {
+			if req.OpenAI {
+				r.PromptEvalCount = r.TokensEval
+			}
 			resp := api.ChatResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
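
For reference, here is a minimal standalone sketch (not part of the patch) of the usage accounting this diff enables: when a request is flagged as coming through the OpenAI compatibility endpoint, the prompt token count is taken from the runner's tokens_evaluated value instead of PromptEvalCount, which is 0 for cached prompts. The struct and function names below are illustrative stand-ins for the real api/llm/openai types, not code from the repository.

```go
package main

import "fmt"

// completionResponse stands in for the fields of llm.CompletionResponse
// that matter for usage accounting.
type completionResponse struct {
	PromptEvalCount int // 0 when the prompt was served from the cache
	EvalCount       int
	TokensEval      int // the runner's tokens_evaluated, now carried through
}

// usage mirrors the OpenAI-style usage block assembled in toChatCompletion.
type usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// usageFor applies the same substitution ChatHandler performs when the
// request was redirected from the compatibility endpoint (req.OpenAI == true).
func usageFor(r completionResponse, openAI bool) usage {
	prompt := r.PromptEvalCount
	if openAI {
		prompt = r.TokensEval
	}
	return usage{
		PromptTokens:     prompt,
		CompletionTokens: r.EvalCount,
		TotalTokens:      prompt + r.EvalCount,
	}
}

func main() {
	// Fully cached prompt: PromptEvalCount is 0, but tokens_evaluated still
	// reports the real prompt length.
	r := completionResponse{PromptEvalCount: 0, EvalCount: 42, TokensEval: 128}
	fmt.Println(usageFor(r, true))  // {128 42 170}
	fmt.Println(usageFor(r, false)) // {0 42 42}
}
```

The substitution is scoped to OpenAI-redirected requests, so the native /api/chat response keeps reporting PromptEvalCount unchanged.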