Terminate subprocess if receiving SIGINT or SIGTERM signals while model is loading (#3653)

* terminate subprocess if receiving `SIGINT` or `SIGTERM` signals while model is loading

* use `unload` in signal handler
This commit is contained in:
Jeffrey Morgan
2024-04-15 12:09:32 -04:00
committed by GitHub
parent 7027f264fb
commit a0b8a32eb4
2 changed files with 23 additions and 32 deletions

View File

@@ -68,6 +68,18 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
func unload() {
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
}
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
@@ -83,12 +95,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
if needLoad {
if loaded.llama != nil {
slog.Info("changing loaded model")
loaded.llama.Close()
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
unload()
}
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
@@ -108,22 +115,19 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
loaded.projectors = model.ProjectorPaths
loaded.llama = llama
loaded.Options = &opts
if err = llama.WaitUntilRunning(); err != nil {
slog.Error("error loading llama server", "error", err)
unload()
return err
}
}
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
if loaded.llama != nil {
loaded.llama.Close()
}
loaded.llama = nil
loaded.model = ""
loaded.adapters = nil
loaded.projectors = nil
loaded.Options = nil
unload()
})
}
@@ -1146,9 +1150,7 @@ func Serve(ln net.Listener) error {
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-signals
if loaded.llama != nil {
loaded.llama.Close()
}
unload()
gpu.Cleanup()
os.Exit(0)
}()