do not reload the running llm when runtime params change (#840)

- only reload the running llm if the model has changed or the options used to load the running model have changed
- rename the loaded llm to runner to differentiate it from the loaded model image
- remove the logic that keeps the first system prompt in the generation context
This commit is contained in:
Bruce MacDonald
2023-10-19 10:39:58 -04:00
committed by GitHub
parent 235e43d7f6
commit fe6f3b48f7
3 changed files with 66 additions and 86 deletions

View File

@@ -46,13 +46,13 @@ func init() {
var loaded struct {
mu sync.Mutex
llm llm.LLM
runner llm.LLM
expireAt time.Time
expireTimer *time.Timer
digest string
options api.Options
*Model
*api.Options
}
var defaultSessionDuration = 5 * time.Minute
@@ -70,59 +70,39 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
}
// check if the loaded model is still running in a subprocess, in case something unexpected happened
if loaded.llm != nil {
if err := loaded.llm.Ping(ctx); err != nil {
if loaded.runner != nil {
if err := loaded.runner.Ping(ctx); err != nil {
log.Print("loaded llm process not responding, closing now")
// the subprocess is no longer running, so close it
loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
}
}
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, opts) {
if loaded.llm != nil {
needLoad := loaded.runner == nil || // is there a model loaded?
loaded.ModelPath != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
if needLoad {
if loaded.runner != nil {
log.Println("changing loaded model")
loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
}
llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
if err != nil {
return err
}
// set cache values before modifying opts
loaded.llm = llmModel
loaded.digest = model.Digest
loaded.options = opts
if opts.NumKeep < 0 {
promptWithSystem, err := model.Prompt(api.GenerateRequest{})
if err != nil {
return err
}
promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
if err != nil {
return err
}
tokensWithSystem, err := llmModel.Encode(ctx, promptWithSystem)
if err != nil {
return err
}
tokensNoSystem, err := llmModel.Encode(ctx, promptNoSystem)
if err != nil {
return err
}
opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem)
llmModel.SetOptions(opts)
}
loaded.Model = model
loaded.runner = llmRunner
loaded.Options = &opts
}
loaded.expireAt = time.Now().Add(sessionDuration)
@@ -136,13 +116,13 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
return
}
if loaded.llm == nil {
return
if loaded.runner != nil {
loaded.runner.Close()
}
loaded.llm.Close()
loaded.llm = nil
loaded.digest = ""
loaded.runner = nil
loaded.Model = nil
loaded.Options = nil
})
}
@@ -215,7 +195,7 @@ func GenerateHandler(c *gin.Context) {
if req.Prompt == "" && req.Template == "" && req.System == "" {
ch <- api.GenerateResponse{Model: req.Model, Done: true}
} else {
if err := loaded.llm.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
if err := loaded.runner.Predict(c.Request.Context(), req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}
@@ -263,12 +243,12 @@ func EmbeddingHandler(c *gin.Context) {
return
}
if !loaded.options.EmbeddingOnly {
if !loaded.Options.EmbeddingOnly {
c.JSON(http.StatusBadRequest, gin.H{"error": "embedding option must be set to true"})
return
}
embedding, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
log.Printf("embedding generation failed: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
@@ -599,8 +579,8 @@ func Serve(ln net.Listener, allowOrigins []string) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if loaded.llm != nil {
loaded.llm.Close()
if loaded.runner != nil {
loaded.runner.Close()
}
os.RemoveAll(workDir)
os.Exit(0)