mirror of
https://github.com/ollama/ollama.git
synced 2025-11-11 18:57:00 +01:00
session id
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"dario.cat/mergo"
|
||||
@@ -21,7 +22,17 @@ import (
|
||||
"github.com/jmorganca/ollama/llama"
|
||||
)
|
||||
|
||||
// mu serializes requests: GenerateHandler locks it on entry and releases it
// on return, so only one generate request runs at a time.
var mu sync.Mutex
|
||||
|
||||
// activeSession is the package-level record of the currently loaded model.
// GenerateHandler reuses the LLM when a request's SessionID matches ID;
// otherwise it closes the existing LLM and creates a new one.
var activeSession struct {
|
||||
// ID is set from time.Now().UnixNano() when a new LLM is created and is
// copied into each response's SessionID.
ID int64
|
||||
// Embedded, so the LLM's methods (e.g. Close) are callable on
// activeSession directly. nil when no model is loaded.
*llama.LLM
|
||||
}
|
||||
|
||||
func GenerateHandler(c *gin.Context) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
start := time.Now()
|
||||
|
||||
var req api.GenerateRequest
|
||||
@@ -36,15 +47,31 @@ func GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.SessionID == 0 || req.SessionID != activeSession.ID {
|
||||
if activeSession.LLM != nil {
|
||||
activeSession.Close()
|
||||
activeSession.LLM = nil
|
||||
}
|
||||
|
||||
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
opts := api.DefaultOptions()
|
||||
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
llm, err := llama.New(model.ModelPath, opts)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
activeSession.ID = time.Now().UnixNano()
|
||||
activeSession.LLM = llm
|
||||
}
|
||||
|
||||
prompt, err := model.Prompt(req)
|
||||
@@ -53,19 +80,13 @@ func GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
llm, err := llama.New(model.ModelPath, opts)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer llm.Close()
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
fn := func(r api.GenerateResponse) {
|
||||
r.Model = req.Model
|
||||
r.CreatedAt = time.Now().UTC()
|
||||
r.SessionID = activeSession.ID
|
||||
if r.Done {
|
||||
r.TotalDuration = time.Since(start)
|
||||
}
|
||||
@@ -73,7 +94,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
ch <- r
|
||||
}
|
||||
|
||||
if err := llm.Predict(req.Context, prompt, fn); err != nil {
|
||||
if err := activeSession.LLM.Predict(req.Context, prompt, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
Reference in New Issue
Block a user