mirror of https://github.com/ollama/ollama.git
subprocess llama.cpp server (#401)
* remove c code
* pack llama.cpp
* use request context for llama_cpp
* let llama_cpp decide the number of threads to use
* stop llama runner when app stops
* remove sample count and duration metrics
* use go generate to get libraries (see the sketch below)
* tmp dir for running llm
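One bullet above mentions fetching the native libraries with go generate. A minimal sketch of that pattern, assuming a hypothetical generate.sh script that downloads and builds llama.cpp; the package name, file layout, and script are illustrative, not the repository's actual structure:

// Package llm wraps the llama.cpp runner. The go:generate directive below
// is a hypothetical example: running `go generate ./...` executes the
// script that fetches and builds the native llama.cpp libraries, so they
// are in place before `go build` runs.
package llm

//go:generate sh ./generate.sh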
@@ -10,10 +10,12 @@ import (
 	"net"
 	"net/http"
 	"os"
+	"os/signal"
 	"path/filepath"
 	"reflect"
 	"strings"
 	"sync"
+	"syscall"
 	"time"
 
 	"github.com/gin-contrib/cors"
@@ -55,7 +57,7 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
+func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 		log.Printf("could not load model options: %v", err)
@@ -67,8 +69,20 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
 		return err
 	}
 
+	// check if the loaded model is still running in a subprocess, in case something unexpected happened
+	if loaded.llm != nil {
+		if err := loaded.llm.Ping(ctx); err != nil {
+			log.Print("loaded llm process not responding, closing now")
+			// the subprocess is no longer running, so close it
+			loaded.llm.Close()
+			loaded.llm = nil
+			loaded.digest = ""
+		}
+	}
+
 	if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
 		if loaded.llm != nil {
 			log.Println("changing loaded model")
 			loaded.llm.Close()
 			loaded.llm = nil
+			loaded.digest = ""
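For context, a health check like Ping above can be as small as a context-aware HTTP request to the runner subprocess: any transport error means the process is gone or wedged, so the caller closes and respawns it. A minimal sketch, assuming the runner is a local llama.cpp HTTP server; the address and timeout here are illustrative, not ollama's actual API:

package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

// ping reports whether the runner subprocess is still serving HTTP.
func ping(ctx context.Context, addr string) error {
	ctx, cancel := context.WithTimeout(ctx, time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr, nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		// connection refused or timeout: treat the subprocess as dead
		return fmt.Errorf("llm runner not responding: %w", err)
	}
	defer resp.Body.Close()
	return nil
}

func main() {
	if err := ping(context.Background(), "http://127.0.0.1:8080/"); err != nil {
		fmt.Println(err)
	}
}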
@@ -100,8 +114,14 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
 		return err
 	}
 
-	tokensWithSystem := llmModel.Encode(promptWithSystem)
-	tokensNoSystem := llmModel.Encode(promptNoSystem)
+	tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
+	if err != nil {
+		return err
+	}
+	tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
+	if err != nil {
+		return err
+	}
 
 	opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
 
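The NumKeep line encodes the prompt twice to measure how many tokens the system prompt occupies: if the prompt encodes to, say, 40 tokens with the system prompt and 28 without it, NumKeep = 40 - 28 + 1 = 13, so the system-prompt tokens (plus one) survive when the context window fills and the runner truncates. A toy illustration with assumed token counts:

package main

import "fmt"

func main() {
	// Assumed token counts, for illustration only.
	tokensWithSystem := 40 // prompt rendered with the system prompt
	tokensNoSystem := 28   // the same prompt rendered without it

	// Tokens to preserve at the head of the context during truncation.
	numKeep := tokensWithSystem - tokensNoSystem + 1
	fmt.Println(numKeep) // 13
}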
@@ -151,7 +171,7 @@ func GenerateHandler(c *gin.Context) {
 	}
 
 	sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified
-	if err := load(model, req.Options, sessionDuration); err != nil {
+	if err := load(c.Request.Context(), model, req.Options, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -160,7 +180,7 @@ func GenerateHandler(c *gin.Context) {
 
 	embedding := ""
 	if model.Embeddings != nil && len(model.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		promptEmbed, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -196,7 +216,7 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil {
+		if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
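Threading c.Request.Context() into Predict means a long generation can stop as soon as the HTTP client disconnects. A minimal sketch of the pattern (not ollama's actual Predict implementation): the loop checks ctx.Done() between tokens and returns early once the request is gone.

package main

import (
	"context"
	"fmt"
)

// predict stands in for a streaming generation loop: it emits tokens
// through fn until it runs out, or the request context is canceled.
func predict(ctx context.Context, tokens []string, fn func(string)) error {
	for _, tok := range tokens {
		select {
		case <-ctx.Done():
			// client disconnected or timed out; stop driving the subprocess
			return ctx.Err()
		default:
			fn(tok)
		}
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // simulate a client that has already gone away
	err := predict(ctx, []string{"hello", "world"}, func(t string) { fmt.Print(t, " ") })
	fmt.Println(err) // context canceled
}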
@@ -219,7 +239,7 @@ func EmbeddingHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	if err := load(model, req.Options, 5*time.Minute); err != nil {
+	if err := load(c.Request.Context(), model, req.Options, 5*time.Minute); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
@@ -229,7 +249,7 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := loaded.llm.Embedding(req.Prompt)
+	embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		log.Printf("embedding generation failed: %v", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -455,6 +475,17 @@ func Serve(ln net.Listener, origins []string) error {
 		Handler: r,
 	}
 
+	// listen for a ctrl+c and stop any loaded llm
+	signals := make(chan os.Signal, 1)
+	signal.Notify(signals, syscall.SIGINT)
+	go func() {
+		<-signals
+		if loaded.llm != nil {
+			loaded.llm.Close()
+		}
+		os.Exit(0)
+	}()
+
 	return s.Serve(ln)
 }
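The SIGINT handler above can also be expressed with signal.NotifyContext (available since Go 1.16). A minimal sketch of that variant; the cleanup closure stands in for loaded.llm.Close() and is not ollama's code:

package main

import (
	"context"
	"log"
	"os"
	"os/signal"
	"syscall"
)

func main() {
	// ctx is canceled on the first ctrl+c (SIGINT).
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT)
	defer stop()

	cleanup := func() { log.Println("stopping loaded llm") } // placeholder for loaded.llm.Close()

	<-ctx.Done()
	cleanup()
	os.Exit(0)
}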