mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 23:37:17 +01:00
* bf16 * tests * gpt-oss * enable gptoss for engine * rough estimate * convert to mxfp4 * handle safetensors U8 * clamp glu/linear * update tokenizer * MXFP4 support This implements the Open Compute Microscaling (MX) FP4 format as a tensor type with backend implementations focusing on mulmat and mulmatid on CPU, CUDA, and Metal. * Unit tests for MXFP4 support This exercises various operations and shapes on both CPU and GPU (if detected on the system) * cuda graph * unit test adjustments * cuda: optimize memory access Read 4 bytes at a time (8 elements) when performing mul_mat_vec_mxfp4 * mac: fix crash on old macos versions cblas_sgemm is only supported on v13.3 and up, however bf16 is only supported on v14+ so we were falling back to ggml-blas and crashing on bf16 tensors. Checking for the function being null seems to be the simplest way to conditionally avoid registering the backend. * server: Minimum context length for gptoss This model requires a minimum context length of 8192 to function effectively. Users can set higher values through all normal mechanisms but lower values will be silently reset. * ggml: Multiply by numParallel for gptoss sliding window When computing the graph size estimate, the context size is already multiplied by numParallel so estimates reflect that. However, since sliding window models use a smaller, fixed context size, they need to manually take numParallel into account. * gpt-oss integration includes harmony parser and thinking levels, etc. * fix sync * fix tests * fix lint --------- Co-authored-by: Daniel Hiltgen <daniel@ollama.com> Co-authored-by: Jesse Gross <jesse@ollama.com> Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
1012 lines
24 KiB
Go
1012 lines
24 KiB
Go
// Package openai provides middleware for partial compatibility with the OpenAI REST API.
|
|
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/types/model"
|
|
)
|
|
|
|
// finishReasonToolCalls is the finish_reason value reported when a streamed
// response involved tool calls. It is a variable rather than a constant
// because toChunk returns its address.
var finishReasonToolCalls = "tool_calls"
|
|
|
|
// Error is the error object carried inside an ErrorResponse.
type Error struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   any     `json:"param"`
	Code    *string `json:"code"`
}
|
|
|
|
type ErrorResponse struct {
|
|
Error Error `json:"error"`
|
|
}
|
|
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content any `json:"content"`
|
|
Reasoning string `json:"reasoning,omitempty"`
|
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
}
|
|
|
|
type Choice struct {
|
|
Index int `json:"index"`
|
|
Message Message `json:"message"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
type ChunkChoice struct {
|
|
Index int `json:"index"`
|
|
Delta Message `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
// CompleteChunkChoice is one entry of a text completion, used for both the
// streamed and the non-streamed form.
type CompleteChunkChoice struct {
	Text         string  `json:"text"`
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason"`
}
|
|
|
|
// Usage reports token counts for a chat or text completion.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
|
|
type ResponseFormat struct {
|
|
Type string `json:"type"`
|
|
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
|
|
}
|
|
|
|
// JsonSchema wraps a JSON schema. The schema is kept as raw JSON so it can be
// forwarded without being decoded.
type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}
|
|
|
|
// EmbedRequest is the body of an embeddings request. Input is typed as any;
// EmbeddingsMiddleware handles a string or a list value.
type EmbedRequest struct {
	Input any    `json:"input"`
	Model string `json:"model"`
}
|
|
|
|
// StreamOptions holds options for streamed responses. When IncludeUsage is
// set, the writers emit an extra usage chunk before the final [DONE] event.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}
|
|
|
|
// Reasoning carries the reasoning settings of a chat request. Effort is a
// pointer, so it is nil when the client omits the field.
type Reasoning struct {
	Effort *string `json:"effort,omitempty"`
}
|
|
|
|
type ChatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
StreamOptions *StreamOptions `json:"stream_options"`
|
|
MaxTokens *int `json:"max_tokens"`
|
|
Seed *int `json:"seed"`
|
|
Stop any `json:"stop"`
|
|
Temperature *float64 `json:"temperature"`
|
|
FrequencyPenalty *float64 `json:"frequency_penalty"`
|
|
PresencePenalty *float64 `json:"presence_penalty"`
|
|
TopP *float64 `json:"top_p"`
|
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
|
Tools []api.Tool `json:"tools"`
|
|
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
|
}
|
|
|
|
type ChatCompletion struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []Choice `json:"choices"`
|
|
Usage Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
type ChatCompletionChunk struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []ChunkChoice `json:"choices"`
|
|
Usage *Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
|
type CompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Prompt string `json:"prompt"`
|
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
|
MaxTokens *int `json:"max_tokens"`
|
|
PresencePenalty float32 `json:"presence_penalty"`
|
|
Seed *int `json:"seed"`
|
|
Stop any `json:"stop"`
|
|
Stream bool `json:"stream"`
|
|
StreamOptions *StreamOptions `json:"stream_options"`
|
|
Temperature *float32 `json:"temperature"`
|
|
TopP float32 `json:"top_p"`
|
|
Suffix string `json:"suffix"`
|
|
}
|
|
|
|
type Completion struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []CompleteChunkChoice `json:"choices"`
|
|
Usage Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
type CompletionChunk struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Choices []CompleteChunkChoice `json:"choices"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Usage *Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
// ToolCall describes a function call requested by the model. The function
// arguments are carried as a JSON-encoded string.
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}
|
|
|
|
// Model is a model entry as returned by the model list and retrieve
// endpoints.
type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}
|
|
|
|
// Embedding is a single embedding vector together with its position in the
// request input.
type Embedding struct {
	Object    string    `json:"object"`
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
}
|
|
|
|
type ListCompletion struct {
|
|
Object string `json:"object"`
|
|
Data []Model `json:"data"`
|
|
}
|
|
|
|
type EmbeddingList struct {
|
|
Object string `json:"object"`
|
|
Data []Embedding `json:"data"`
|
|
Model string `json:"model"`
|
|
Usage EmbeddingUsage `json:"usage,omitempty"`
|
|
}
|
|
|
|
// EmbeddingUsage reports token counts for an embeddings request.
type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}
|
|
|
|
func NewError(code int, message string) ErrorResponse {
|
|
var etype string
|
|
switch code {
|
|
case http.StatusBadRequest:
|
|
etype = "invalid_request_error"
|
|
case http.StatusNotFound:
|
|
etype = "not_found_error"
|
|
default:
|
|
etype = "api_error"
|
|
}
|
|
|
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
|
}
|
|
|
|
func toUsage(r api.ChatResponse) Usage {
|
|
return Usage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
CompletionTokens: r.EvalCount,
|
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
}
|
|
}
|
|
|
|
// toolCallId returns a random tool call identifier of the form "call_"
// followed by eight characters drawn from lowercase letters and digits.
func toolCallId() string {
	// The alphabet is already lowercase, so no case folding is needed.
	const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, 8)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	return "call_" + string(b)
}
|
|
|
|
func toToolCalls(tc []api.ToolCall) []ToolCall {
|
|
toolCalls := make([]ToolCall, len(tc))
|
|
for i, tc := range tc {
|
|
toolCalls[i].ID = toolCallId()
|
|
toolCalls[i].Type = "function"
|
|
toolCalls[i].Function.Name = tc.Function.Name
|
|
toolCalls[i].Index = tc.Function.Index
|
|
|
|
args, err := json.Marshal(tc.Function.Arguments)
|
|
if err != nil {
|
|
slog.Error("could not marshall function arguments to json", "error", err)
|
|
continue
|
|
}
|
|
|
|
toolCalls[i].Function.Arguments = string(args)
|
|
}
|
|
return toolCalls
|
|
}
|
|
|
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
|
return ChatCompletion{
|
|
Id: id,
|
|
Object: "chat.completion",
|
|
Created: r.CreatedAt.Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []Choice{{
|
|
Index: 0,
|
|
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
|
|
FinishReason: func(reason string) *string {
|
|
if len(toolCalls) > 0 {
|
|
reason = "tool_calls"
|
|
}
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
Usage: toUsage(r),
|
|
}
|
|
}
|
|
|
|
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
|
return ChatCompletionChunk{
|
|
Id: id,
|
|
Object: "chat.completion.chunk",
|
|
Created: time.Now().Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []ChunkChoice{{
|
|
Index: 0,
|
|
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
if toolCallSent || len(toolCalls) > 0 {
|
|
return &finishReasonToolCalls
|
|
}
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
}
|
|
}
|
|
|
|
func toUsageGenerate(r api.GenerateResponse) Usage {
|
|
return Usage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
CompletionTokens: r.EvalCount,
|
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
}
|
|
}
|
|
|
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
|
return Completion{
|
|
Id: id,
|
|
Object: "text_completion",
|
|
Created: r.CreatedAt.Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []CompleteChunkChoice{{
|
|
Text: r.Response,
|
|
Index: 0,
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
Usage: toUsageGenerate(r),
|
|
}
|
|
}
|
|
|
|
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
|
return CompletionChunk{
|
|
Id: id,
|
|
Object: "text_completion",
|
|
Created: time.Now().Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []CompleteChunkChoice{{
|
|
Text: r.Response,
|
|
Index: 0,
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
}
|
|
}
|
|
|
|
func toListCompletion(r api.ListResponse) ListCompletion {
|
|
var data []Model
|
|
for _, m := range r.Models {
|
|
data = append(data, Model{
|
|
Id: m.Name,
|
|
Object: "model",
|
|
Created: m.ModifiedAt.Unix(),
|
|
OwnedBy: model.ParseName(m.Name).Namespace,
|
|
})
|
|
}
|
|
|
|
return ListCompletion{
|
|
Object: "list",
|
|
Data: data,
|
|
}
|
|
}
|
|
|
|
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
|
if r.Embeddings != nil {
|
|
var data []Embedding
|
|
for i, e := range r.Embeddings {
|
|
data = append(data, Embedding{
|
|
Object: "embedding",
|
|
Embedding: e,
|
|
Index: i,
|
|
})
|
|
}
|
|
|
|
return EmbeddingList{
|
|
Object: "list",
|
|
Data: data,
|
|
Model: model,
|
|
Usage: EmbeddingUsage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
TotalTokens: r.PromptEvalCount,
|
|
},
|
|
}
|
|
}
|
|
|
|
return EmbeddingList{}
|
|
}
|
|
|
|
func toModel(r api.ShowResponse, m string) Model {
|
|
return Model{
|
|
Id: m,
|
|
Object: "model",
|
|
Created: r.ModifiedAt.Unix(),
|
|
OwnedBy: model.ParseName(m).Namespace,
|
|
}
|
|
}
|
|
|
|
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
var messages []api.Message
|
|
for _, msg := range r.Messages {
|
|
switch content := msg.Content.(type) {
|
|
case string:
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning})
|
|
case []any:
|
|
for _, c := range content {
|
|
data, ok := c.(map[string]any)
|
|
if !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
switch data["type"] {
|
|
case "text":
|
|
text, ok := data["text"].(string)
|
|
if !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
|
case "image_url":
|
|
var url string
|
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
|
if url, ok = urlMap["url"].(string); !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
} else {
|
|
if url, ok = data["image_url"].(string); !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
}
|
|
|
|
types := []string{"jpeg", "jpg", "png", "webp"}
|
|
valid := false
|
|
for _, t := range types {
|
|
prefix := "data:image/" + t + ";base64,"
|
|
if strings.HasPrefix(url, prefix) {
|
|
url = strings.TrimPrefix(url, prefix)
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !valid {
|
|
return nil, errors.New("invalid image input")
|
|
}
|
|
|
|
img, err := base64.StdEncoding.DecodeString(url)
|
|
if err != nil {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
|
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
|
default:
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
}
|
|
default:
|
|
if msg.ToolCalls == nil {
|
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
|
}
|
|
|
|
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
|
for i, tc := range msg.ToolCalls {
|
|
toolCalls[i].Function.Name = tc.Function.Name
|
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
|
if err != nil {
|
|
return nil, errors.New("invalid tool call arguments")
|
|
}
|
|
}
|
|
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
|
|
}
|
|
}
|
|
|
|
options := make(map[string]any)
|
|
|
|
switch stop := r.Stop.(type) {
|
|
case string:
|
|
options["stop"] = []string{stop}
|
|
case []any:
|
|
var stops []string
|
|
for _, s := range stop {
|
|
if str, ok := s.(string); ok {
|
|
stops = append(stops, str)
|
|
}
|
|
}
|
|
options["stop"] = stops
|
|
}
|
|
|
|
if r.MaxTokens != nil {
|
|
options["num_predict"] = *r.MaxTokens
|
|
}
|
|
|
|
if r.Temperature != nil {
|
|
options["temperature"] = *r.Temperature
|
|
} else {
|
|
options["temperature"] = 1.0
|
|
}
|
|
|
|
if r.Seed != nil {
|
|
options["seed"] = *r.Seed
|
|
}
|
|
|
|
if r.FrequencyPenalty != nil {
|
|
options["frequency_penalty"] = *r.FrequencyPenalty
|
|
}
|
|
|
|
if r.PresencePenalty != nil {
|
|
options["presence_penalty"] = *r.PresencePenalty
|
|
}
|
|
|
|
if r.TopP != nil {
|
|
options["top_p"] = *r.TopP
|
|
} else {
|
|
options["top_p"] = 1.0
|
|
}
|
|
|
|
if r.Reasoning != nil {
|
|
options["reasoning"] = *r.Reasoning.Effort
|
|
}
|
|
|
|
var format json.RawMessage
|
|
if r.ResponseFormat != nil {
|
|
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
|
|
// Support the old "json_object" type for OpenAI compatibility
|
|
case "json_object":
|
|
format = json.RawMessage(`"json"`)
|
|
case "json_schema":
|
|
if r.ResponseFormat.JsonSchema != nil {
|
|
format = r.ResponseFormat.JsonSchema.Schema
|
|
}
|
|
}
|
|
}
|
|
|
|
var think *api.ThinkValue
|
|
if r.Reasoning != nil {
|
|
think = &api.ThinkValue{
|
|
Value: *r.Reasoning.Effort,
|
|
}
|
|
}
|
|
|
|
return &api.ChatRequest{
|
|
Model: r.Model,
|
|
Messages: messages,
|
|
Format: format,
|
|
Options: options,
|
|
Stream: &r.Stream,
|
|
Tools: r.Tools,
|
|
Think: think,
|
|
}, nil
|
|
}
|
|
|
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|
options := make(map[string]any)
|
|
|
|
switch stop := r.Stop.(type) {
|
|
case string:
|
|
options["stop"] = []string{stop}
|
|
case []any:
|
|
var stops []string
|
|
for _, s := range stop {
|
|
if str, ok := s.(string); ok {
|
|
stops = append(stops, str)
|
|
} else {
|
|
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
|
|
}
|
|
}
|
|
options["stop"] = stops
|
|
}
|
|
|
|
if r.MaxTokens != nil {
|
|
options["num_predict"] = *r.MaxTokens
|
|
}
|
|
|
|
if r.Temperature != nil {
|
|
options["temperature"] = *r.Temperature
|
|
} else {
|
|
options["temperature"] = 1.0
|
|
}
|
|
|
|
if r.Seed != nil {
|
|
options["seed"] = *r.Seed
|
|
}
|
|
|
|
options["frequency_penalty"] = r.FrequencyPenalty
|
|
|
|
options["presence_penalty"] = r.PresencePenalty
|
|
|
|
if r.TopP != 0.0 {
|
|
options["top_p"] = r.TopP
|
|
} else {
|
|
options["top_p"] = 1.0
|
|
}
|
|
|
|
return api.GenerateRequest{
|
|
Model: r.Model,
|
|
Prompt: r.Prompt,
|
|
Options: options,
|
|
Stream: &r.Stream,
|
|
Suffix: r.Suffix,
|
|
}, nil
|
|
}
|
|
|
|
type BaseWriter struct {
|
|
gin.ResponseWriter
|
|
}
|
|
|
|
type ChatWriter struct {
|
|
stream bool
|
|
streamOptions *StreamOptions
|
|
id string
|
|
toolCallSent bool
|
|
BaseWriter
|
|
}
|
|
|
|
type CompleteWriter struct {
|
|
stream bool
|
|
streamOptions *StreamOptions
|
|
id string
|
|
BaseWriter
|
|
}
|
|
|
|
type ListWriter struct {
|
|
BaseWriter
|
|
}
|
|
|
|
type RetrieveWriter struct {
|
|
BaseWriter
|
|
model string
|
|
}
|
|
|
|
type EmbedWriter struct {
|
|
BaseWriter
|
|
model string
|
|
}
|
|
|
|
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
|
var serr api.StatusError
|
|
err := json.Unmarshal(data, &serr)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|
var chatResponse api.ChatResponse
|
|
err := json.Unmarshal(data, &chatResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// chat chunk
|
|
if w.stream {
|
|
c := toChunk(w.id, chatResponse, w.toolCallSent)
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
|
w.toolCallSent = true
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if chatResponse.Done {
|
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
u := toUsage(chatResponse)
|
|
c.Usage = &u
|
|
c.Choices = []ChunkChoice{}
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
// chat completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ChatWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|
var generateResponse api.GenerateResponse
|
|
err := json.Unmarshal(data, &generateResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// completion chunk
|
|
if w.stream {
|
|
c := toCompleteChunk(w.id, generateResponse)
|
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
c.Usage = &Usage{}
|
|
}
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if generateResponse.Done {
|
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
u := toUsageGenerate(generateResponse)
|
|
c.Usage = &u
|
|
c.Choices = []CompleteChunkChoice{}
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
// completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
|
var listResponse api.ListResponse
|
|
err := json.Unmarshal(data, &listResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ListWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
|
var showResponse api.ShowResponse
|
|
err := json.Unmarshal(data, &showResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// retrieve completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
|
var embedResponse api.EmbedResponse
|
|
err := json.Unmarshal(data, &embedResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func ListMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
w := &ListWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func RetrieveMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
// response writer
|
|
w := &RetrieveWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
model: c.Param("model"),
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func CompletionsMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req CompletionRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
genReq, err := fromCompleteRequest(req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &CompleteWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
stream: req.Stream,
|
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
|
streamOptions: req.StreamOptions,
|
|
}
|
|
|
|
c.Writer = w
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func EmbeddingsMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req EmbedRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if req.Input == "" {
|
|
req.Input = []string{""}
|
|
}
|
|
|
|
if req.Input == nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
return
|
|
}
|
|
|
|
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &EmbedWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
model: req.Model,
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func ChatMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req ChatCompletionRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if len(req.Messages) == 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
|
|
chatReq, err := fromChatRequest(req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &ChatWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
stream: req.Stream,
|
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
streamOptions: req.StreamOptions,
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|