session id

This commit is contained in:
Michael Yang
2023-07-18 11:59:42 -07:00
parent dbb3174cbc
commit 35af37a2cb
4 changed files with 67 additions and 36 deletions

View File

@@ -11,6 +11,7 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"time"
"dario.cat/mergo"
@@ -21,7 +22,17 @@ import (
"github.com/jmorganca/ollama/llama"
)
// Package-level state shared by every invocation of GenerateHandler.
var (
	// mu is locked at the top of GenerateHandler and released by a deferred
	// Unlock, so only one generate request runs at a time and reads/writes of
	// activeSession from the handler are serialized.
	mu sync.Mutex

	// activeSession is the currently loaded model together with the
	// identifier handed back to clients in GenerateResponse.SessionID.
	// The handler closes and reloads the embedded LLM whenever a request's
	// SessionID is zero or differs from ID; a fresh ID is then taken from
	// time.Now().UnixNano().
	activeSession struct {
		// ID identifies the loaded session; zero until a model is loaded.
		ID int64

		// LLM is embedded so its methods (e.g. Close, Predict) can be
		// called directly on activeSession. It is nil when nothing is loaded.
		*llama.LLM
	}
)
func GenerateHandler(c *gin.Context) {
mu.Lock()
defer mu.Unlock()
start := time.Now()
var req api.GenerateRequest
@@ -36,15 +47,31 @@ func GenerateHandler(c *gin.Context) {
return
}
opts := api.DefaultOptions()
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if req.SessionID == 0 || req.SessionID != activeSession.ID {
if activeSession.LLM != nil {
activeSession.Close()
activeSession.LLM = nil
}
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
opts := api.DefaultOptions()
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
llm, err := llama.New(model.ModelPath, opts)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
activeSession.ID = time.Now().UnixNano()
activeSession.LLM = llm
}
prompt, err := model.Prompt(req)
@@ -53,19 +80,13 @@ func GenerateHandler(c *gin.Context) {
return
}
llm, err := llama.New(model.ModelPath, opts)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer llm.Close()
ch := make(chan any)
go func() {
defer close(ch)
fn := func(r api.GenerateResponse) {
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
r.SessionID = activeSession.ID
if r.Done {
r.TotalDuration = time.Since(start)
}
@@ -73,7 +94,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r
}
if err := llm.Predict(req.Context, prompt, fn); err != nil {
if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()