diff --git a/openai/openai.go b/openai/openai.go index b25ba0949..daafabcee 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -67,14 +67,14 @@ type Usage struct { // is requested to follow OpenAI's behavior. type ChunkUsage = Usage -var nullChunkUsage = ChunkUsage{} +// var nullChunkUsage = ChunkUsage{} -func (u *ChunkUsage) MarshalJSON() ([]byte, error) { - if u == &nullChunkUsage { - return []byte("null"), nil - } - return json.Marshal(*u) -} +// func (u *ChunkUsage) MarshalJSON() ([]byte, error) { +// if u == &nullChunkUsage { +// return []byte("null"), nil +// } +// return json.Marshal(*u) +// } type ResponseFormat struct { Type string `json:"type"` @@ -601,16 +601,16 @@ type BaseWriter struct { } type ChatWriter struct { - stream bool - streamUsage bool - id string + stream bool + streamOptions *StreamOptions + id string BaseWriter } type CompleteWriter struct { - stream bool - streamUsage bool - id string + stream bool + streamOptions *StreamOptions + id string BaseWriter } @@ -654,8 +654,8 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { // chat chunk if w.stream { c := toChunk(w.id, chatResponse) - if w.streamUsage { - c.Usage = &nullChunkUsage + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + c.Usage = &ChunkUsage{} } d, err := json.Marshal(c) if err != nil { @@ -669,7 +669,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { } if chatResponse.Done { - if w.streamUsage { + if w.streamOptions != nil && w.streamOptions.IncludeUsage { u := toUsage(chatResponse) d, err := json.Marshal(ChatCompletionChunk{Choices: []ChunkChoice{}, Usage: &u}) if err != nil { @@ -718,8 +718,8 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) { // completion chunk if w.stream { c := toCompleteChunk(w.id, generateResponse) - if w.streamUsage { - c.Usage = &nullChunkUsage + if w.streamOptions != nil && w.streamOptions.IncludeUsage { + c.Usage = &ChunkUsage{} } d, err := json.Marshal(c) if err != nil { @@ -733,7 +733,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) { } if generateResponse.Done { - if w.streamUsage { + if w.streamOptions != nil && w.streamOptions.IncludeUsage { u := toUsageGenerate(generateResponse) d, err := json.Marshal(CompletionChunk{Choices: []CompleteChunkChoice{}, Usage: &u}) if err != nil { @@ -906,10 +906,10 @@ func CompletionsMiddleware() gin.HandlerFunc { c.Request.Body = io.NopCloser(&b) w := &CompleteWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - stream: req.Stream, - id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), - streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage, + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), + streamOptions: req.StreamOptions, } c.Writer = w @@ -989,10 +989,10 @@ func ChatMiddleware() gin.HandlerFunc { c.Request.Body = io.NopCloser(&b) w := &ChatWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - stream: req.Stream, - id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), - streamUsage: req.StreamOptions != nil && req.StreamOptions.IncludeUsage, + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), + streamOptions: req.StreamOptions, } c.Writer = w