subprocess llama.cpp server (#401)

* remove c code
* pack llama.cpp
* use request context for llama_cpp
* let llama_cpp decide the number of threads to use
* stop llama runner when app stops
* remove sample count and duration metrics
* use go generate to get libraries
* tmp dir for running llm
Author: Bruce MacDonald
Date: 2023-08-30 16:35:03 -04:00
Committed by: GitHub
Parent: f4432e1dba
Commit: 42998d797d
37 changed files with 958 additions and 43928 deletions


@@ -10,10 +10,12 @@ import (
 	"net"
 	"net/http"
 	"os"
+	"os/signal"
 	"path/filepath"
 	"reflect"
 	"strings"
 	"sync"
+	"syscall"
 	"time"
 
 	"github.com/gin-contrib/cors"
@@ -55,7 +57,7 @@ var loaded struct {
 var defaultSessionDuration = 5 * time.Minute
 
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
-func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
+func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error {
 	opts := api.DefaultOptions()
 	if err := opts.FromMap(model.Options); err != nil {
 		log.Printf("could not load model options: %v", err)
@@ -67,8 +69,20 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
 		return err
 	}
 
+	// check if the loaded model is still running in a subprocess, in case something unexpected happened
+	if loaded.llm != nil {
+		if err := loaded.llm.Ping(ctx); err != nil {
+			log.Print("loaded llm process not responding, closing now")
+			// the subprocess is no longer running, so close it
+			loaded.llm.Close()
+			loaded.llm = nil
+			loaded.digest = ""
+		}
+	}
+
 	if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
 		if loaded.llm != nil {
+			log.Println("changing loaded model")
 			loaded.llm.Close()
 			loaded.llm = nil
 			loaded.digest = ""
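The new Ping check is what lets load recover when the llama.cpp subprocess dies unexpectedly: a failed probe tears down the stale handle so the next request relaunches the model. A minimal sketch of what such a probe could look like, assuming the runner is reached over local HTTP; the llama struct shape and the /health path are illustrative assumptions, not taken from this commit:

```go
package llm

import (
	"context"
	"fmt"
	"net/http"
)

// llama wraps a llama.cpp server subprocess.
// Port is illustrative; the real struct in this commit may differ.
type llama struct {
	Port int
}

// Ping reports whether the subprocess is still answering HTTP requests.
func (llm *llama) Ping(ctx context.Context) error {
	url := fmt.Sprintf("http://127.0.0.1:%d/health", llm.Port) // hypothetical endpoint
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("llm ping failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("llm ping returned status %d", resp.StatusCode)
	}
	return nil
}
```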
@@ -100,8 +114,14 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
 		return err
 	}
 
-	tokensWithSystem := llmModel.Encode(promptWithSystem)
-	tokensNoSystem := llmModel.Encode(promptNoSystem)
+	tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
+	if err != nil {
+		return err
+	}
+
+	tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
+	if err != nil {
+		return err
+	}
 
 	opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
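Encode gains a context and an error return because tokenization is no longer an in-process C call: with the C code removed, it has to round-trip to the subprocess, which can fail or be cancelled. The NumKeep arithmetic is unchanged: the difference between the two token counts is the length of the system prompt, plus one token, presumably to also keep the leading BOS token. A sketch of a context-aware Encode, assuming the upstream llama.cpp server's POST /tokenize route and JSON shape (not confirmed by this diff); Embedding and Predict would round-trip the same way:

```go
package llm

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
)

type llama struct{ Port int } // as in the previous sketch; illustrative

// Encode asks the llama.cpp subprocess to tokenize a prompt.
func (llm *llama) Encode(ctx context.Context, prompt string) ([]int, error) {
	body, err := json.Marshal(map[string]string{"content": prompt})
	if err != nil {
		return nil, err
	}
	url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", llm.Port)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err // includes cancellation propagated via ctx
	}
	defer resp.Body.Close()
	var data struct {
		Tokens []int `json:"tokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
		return nil, err
	}
	return data.Tokens, nil
}
```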
@@ -151,7 +171,7 @@ func GenerateHandler(c *gin.Context) {
 	}
 
 	sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified
-	if err := load(model, req.Options, sessionDuration); err != nil {
+	if err := load(c.Request.Context(), model, req.Options, sessionDuration); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
@@ -160,7 +180,7 @@ func GenerateHandler(c *gin.Context) {
 
 	embedding := ""
 	if model.Embeddings != nil && len(model.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(req.Prompt)
+		promptEmbed, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
@@ -196,7 +216,7 @@ func GenerateHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil {
+		if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()
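Passing c.Request.Context() into Predict means generation stops when the client disconnects or the request times out, instead of running to completion into a dead stream. A minimal sketch of the cancellation pattern a streaming Predict loop can use; next is a placeholder for reading one decoded token from the subprocess and is not part of this commit:

```go
package llm

import "context"

// predictLoop forwards streamed tokens to fn until the stream ends or the
// request context is cancelled. Sketch of the pattern only.
func predictLoop(ctx context.Context, next func() (string, bool), fn func(string)) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err() // client went away or request timed out
		default:
			token, ok := next()
			if !ok {
				return nil // end of stream
			}
			fn(token)
		}
	}
}
```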
@@ -219,7 +239,7 @@ func EmbeddingHandler(c *gin.Context) {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-	if err := load(model, req.Options, 5*time.Minute); err != nil {
+	if err := load(c.Request.Context(), model, req.Options, 5*time.Minute); err != nil {
 		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
@@ -229,7 +249,7 @@ func EmbeddingHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := loaded.llm.Embedding(req.Prompt)
+	embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		log.Printf("embedding generation failed: %v", err)
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -455,6 +475,17 @@ func Serve(ln net.Listener, origins []string) error {
 		Handler: r,
 	}
 
+	// listen for a ctrl+c and stop any loaded llm
+	signals := make(chan os.Signal, 1)
+	signal.Notify(signals, syscall.SIGINT)
+	go func() {
+		<-signals
+		if loaded.llm != nil {
+			loaded.llm.Close()
+		}
+		os.Exit(0)
+	}()
+
 	return s.Serve(ln)
 }
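This handler catches only SIGINT and exits immediately, which matches the "stop llama runner when app stops" bullet. For reference, a variant that also catches SIGTERM and drains in-flight requests could use signal.NotifyContext (Go 1.16+) and http.Server.Shutdown; this is an alternative sketch that slots into Serve in place of the goroutine above, not what the commit ships:

```go
// inside Serve; s is the *http.Server built above
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()

go func() {
	<-ctx.Done() // a shutdown signal arrived
	if loaded.llm != nil {
		loaded.llm.Close() // stop the llama.cpp subprocess first
	}
	shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	s.Shutdown(shutdownCtx) // drain in-flight requests, then Serve returns
}()

return s.Serve(ln) // returns http.ErrServerClosed after Shutdown
```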