Switch back to subprocessing for llama.cpp

This should resolve a number of memory leak and stability defects by allowing
us to isolate llama.cpp in a separate process, shut it down when idle, and
gracefully restart it if it has problems. It also serves as a first step toward
running multiple copies to support multiple models concurrently. (A rough
sketch of this subprocess lifecycle follows the commit metadata below.)
Daniel Hiltgen
2024-03-14 10:24:13 -07:00
parent 3b6a9154dd
commit 58d95cc9bd
35 changed files with 1416 additions and 1910 deletions
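
The diff below replaces the in-process llm.LLM runner with an llm.LlamaServer handle that talks to llama.cpp running as a child process. The following is a minimal, illustrative sketch of that lifecycle only — the runnerProcess type, its flags, and the /health endpoint are assumptions for the sketch, not the actual LlamaServer implementation introduced by this commit:

package runner

import (
	"context"
	"fmt"
	"net/http"
	"os/exec"
)

// runnerProcess wraps a llama.cpp server running as a child process so that
// crashes or leaks in native code cannot take down the main server.
type runnerProcess struct {
	cmd  *exec.Cmd
	port int
}

// start launches the server binary in its own process. The flags used here
// are hypothetical placeholders, not the real CLI.
func start(ctx context.Context, binary, modelPath string, port int) (*runnerProcess, error) {
	cmd := exec.CommandContext(ctx, binary, "--model", modelPath, "--port", fmt.Sprint(port))
	if err := cmd.Start(); err != nil {
		return nil, err
	}
	return &runnerProcess{cmd: cmd, port: port}, nil
}

// ping reports whether the child is still serving requests; a failed ping is
// the signal to close it and start a fresh one.
func (r *runnerProcess) ping(ctx context.Context) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet,
		fmt.Sprintf("http://127.0.0.1:%d/health", r.port), nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("runner unhealthy: %s", resp.Status)
	}
	return nil
}

// close stops the child process, e.g. when the idle timer fires.
func (r *runnerProcess) close() error {
	if r.cmd.Process != nil {
		return r.cmd.Process.Kill()
	}
	return nil
}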


@@ -56,12 +56,13 @@ func init() {
 var loaded struct {
 	mu sync.Mutex
 
-	runner llm.LLM
+	llama *llm.LlamaServer
 
-	expireAt    time.Time
 	expireTimer *time.Timer
 
-	*Model
+	model      string
+	adapters   []string
+	projectors []string
 	*api.Options
 }
@@ -69,21 +70,28 @@ var defaultSessionDuration = 5 * time.Minute
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
 func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.Duration) error {
-	needLoad := loaded.runner == nil || // is there a model loaded?
-		loaded.ModelPath != model.ModelPath || // has the base model changed?
-		!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
-		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
+	ctx, cancel := context.WithTimeout(c, 10*time.Second)
+	defer cancel()
+
+	needLoad := loaded.llama == nil || // is there a model loaded?
+		loaded.model != model.ModelPath || // has the base model changed?
+		!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
+		!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the projectors changed?
+		!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
+		loaded.llama.Ping(ctx) != nil // is the runner still responding?
 
 	if needLoad {
-		if loaded.runner != nil {
+		if loaded.llama != nil {
 			slog.Info("changing loaded model")
-			loaded.runner.Close()
-			loaded.runner = nil
-			loaded.Model = nil
+			loaded.llama.Close()
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
 			loaded.Options = nil
 		}
 
-		llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
+		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
 		if err != nil {
 			// some older models are not compatible with newer versions of llama.cpp
 			// show a generalized compatibility error until there is a better way to
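One detail worth noting in the hunk above: needLoad now also trips on loaded.llama.Ping(ctx) != nil, so a subprocess that has crashed or stopped responding is treated the same as having no model loaded and is relaunched on the next request; because Go's || short-circuits, Ping is only reached when loaded.llama is non-nil. A simplified, self-contained sketch of that decision, with a hypothetical llmRunner type rather than the real server handle:

package runner

import (
	"context"
	"errors"
)

// llmRunner is a stand-in for the loaded subprocess handle.
type llmRunner struct {
	model string
	alive bool
}

// Ping stands in for a real health check against the child process.
func (r *llmRunner) Ping(ctx context.Context) error {
	if !r.alive {
		return errors.New("runner not responding")
	}
	return ctx.Err()
}

// needsReload mirrors the needLoad decision in simplified form: reload when
// nothing is loaded, a different model is requested, or the existing runner
// fails its health check. Short-circuiting keeps Ping off a nil runner.
func needsReload(ctx context.Context, current *llmRunner, requested string) bool {
	return current == nil ||
		current.model != requested ||
		current.Ping(ctx) != nil
}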
@@ -95,28 +103,26 @@ func load(c *gin.Context, model *Model, opts *api.Options, sessionDuration time.Duration) error {
 			return err
 		}
 
-		loaded.Model = model
-		loaded.runner = llmRunner
+		loaded.model = model.ModelPath
+		loaded.adapters = model.AdapterPaths
+		loaded.projectors = model.ProjectorPaths
+		loaded.llama = llama
 		loaded.Options = opts
 	}
 
-	loaded.expireAt = time.Now().Add(sessionDuration)
-
 	if loaded.expireTimer == nil {
 		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
 			loaded.mu.Lock()
 			defer loaded.mu.Unlock()
 
-			if time.Now().Before(loaded.expireAt) {
-				return
+			if loaded.llama != nil {
+				loaded.llama.Close()
 			}
 
-			if loaded.runner != nil {
-				loaded.runner.Close()
-			}
-
-			loaded.runner = nil
-			loaded.Model = nil
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
 			loaded.Options = nil
 		})
 	}
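The expiry handling above also changes shape: the expireAt bookkeeping goes away, and a single time.AfterFunc timer is created once, reset on activity by the handlers further down, and allowed to fire only when the server has truly been idle, at which point the subprocess is closed. The following is a minimal sketch of that keep-alive pattern with hypothetical names (idleCloser, activity), separate from the actual routes.go code:

package runner

import (
	"io"
	"sync"
	"time"
)

type idleCloser struct {
	mu      sync.Mutex
	runner  io.Closer // stands in for the llama.cpp subprocess handle
	timer   *time.Timer
	timeout time.Duration
}

// activity marks the runner as in use and pushes the shutdown deadline out.
func (c *idleCloser) activity() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.timer == nil {
		c.timer = time.AfterFunc(c.timeout, c.expire)
		return
	}
	c.timer.Reset(c.timeout)
}

// expire runs only when the timer fires with no intervening Reset: the
// subprocess is shut down and the slot cleared so the next request reloads.
func (c *idleCloser) expire() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.runner != nil {
		c.runner.Close()
		c.runner = nil
	}
}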
@@ -265,7 +271,7 @@ func GenerateHandler(c *gin.Context) {
 		sb.Reset()
 		if req.Context != nil {
-			prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
+			prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
 			if err != nil {
 				c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 				return
@@ -286,9 +292,8 @@ func GenerateHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 
-		fn := func(r llm.PredictResult) {
+		fn := func(r llm.CompletionResponse) {
 			// Update model expiration
-			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
 
 			// Build up the full response
@@ -322,7 +327,7 @@ func GenerateHandler(c *gin.Context) {
 			}
 
 			// TODO (jmorganca): encode() should not strip special tokens
-			tokens, err := loaded.runner.Encode(c.Request.Context(), p)
+			tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
 			if err != nil {
 				ch <- gin.H{"error": err.Error()}
 				return
@@ -344,13 +349,13 @@ func GenerateHandler(c *gin.Context) {
 		}
 
 		// Start prediction
-		predictReq := llm.PredictOpts{
+		req := llm.CompletionRequest{
 			Prompt:  prompt,
 			Format:  req.Format,
 			Images:  images,
 			Options: opts,
 		}
-		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+		if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
			ch <- gin.H{"error": err.Error()}
 		}
 	}()
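The handler keeps the same streaming shape it had with Predict: a goroutine owns the response channel, a callback runs once per generated chunk, and any error is pushed into the same channel. Below is a small self-contained sketch of that callback-to-channel adaptation; the completionFunc and chunk names are made up for illustration and are not the llm package API:

package runner

import "context"

type chunk struct {
	Content string
	Done    bool
}

// completionFunc stands in for something like a completion call on the
// runner: it invokes fn for every chunk until generation finishes or ctx
// is cancelled.
type completionFunc func(ctx context.Context, prompt string, fn func(chunk)) error

// streamCompletion adapts the callback API to a channel the caller can range
// over, mirroring the `go func() { ... }()` block in the handler.
func streamCompletion(ctx context.Context, complete completionFunc, prompt string) <-chan any {
	ch := make(chan any)
	go func() {
		defer close(ch)
		fn := func(c chunk) {
			ch <- c
		}
		if err := complete(ctx, prompt, fn); err != nil {
			ch <- map[string]any{"error": err.Error()}
		}
	}()
	return ch
}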
@@ -471,7 +476,7 @@ func EmbeddingsHandler(c *gin.Context) {
 		return
 	}
 
-	embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
+	embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
 	if err != nil {
 		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
 		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -1123,8 +1128,8 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		if loaded.runner != nil {
-			loaded.runner.Close()
+		if loaded.llama != nil {
+			loaded.llama.Close()
 		}
 		gpu.Cleanup()
 		os.Exit(0)
@@ -1196,7 +1201,7 @@ func streamResponse(c *gin.Context, ch chan any) {
 // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
 func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
 	encode := func(s string) ([]int, error) {
-		return loaded.runner.Encode(ctx, s)
+		return loaded.llama.Tokenize(ctx, s)
 	}
 
 	prompt, err := ChatPrompt(template, messages, numCtx, encode)
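chatPrompt still hands ChatPrompt a tokenize closure so that history trimming is based on real token counts from the loaded runner rather than character counts; only the underlying call changes from Encode to Tokenize. A simplified stand-in for that kind of budget check, not the actual ChatPrompt logic:

package runner

import (
	"context"
	"strings"
)

// fitHistory drops the oldest messages until the joined prompt tokenizes to
// at most numCtx tokens, using whatever tokenizer the caller supplies.
func fitHistory(ctx context.Context, messages []string, numCtx int,
	tokenize func(ctx context.Context, s string) ([]int, error)) (string, error) {
	for start := 0; start < len(messages); start++ {
		prompt := strings.Join(messages[start:], "\n")
		tokens, err := tokenize(ctx, prompt)
		if err != nil {
			return "", err
		}
		if len(tokens) <= numCtx {
			return prompt, nil
		}
	}
	// nothing fits (or no messages); the caller decides how to handle it
	return "", nil
}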
@@ -1326,9 +1331,8 @@ func ChatHandler(c *gin.Context) {
 	go func() {
 		defer close(ch)
 
-		fn := func(r llm.PredictResult) {
+		fn := func(r llm.CompletionResponse) {
 			// Update model expiration
-			loaded.expireAt = time.Now().Add(sessionDuration)
 			loaded.expireTimer.Reset(sessionDuration)
 
 			resp := api.ChatResponse{
@@ -1352,14 +1356,12 @@ func ChatHandler(c *gin.Context) {
 			ch <- resp
 		}
 
-		// Start prediction
-		predictReq := llm.PredictOpts{
+		if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Format:  req.Format,
 			Images:  images,
 			Options: opts,
-		}
-		if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
+		}, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()