From 4cba75efc589a4d630dd9182593046ae565fff86 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Thu, 21 Sep 2023 20:38:49 +0100 Subject: [PATCH] remove tmp directories created by previous servers (#559) * remove tmp directories created by previous servers * clean up on server stop * Update routes.go * Update server/routes.go Co-authored-by: Jeffrey Morgan * create top-level temp ollama dir * check file exists before creating --------- Co-authored-by: Jeffrey Morgan Co-authored-by: Michael Yang --- llm/ggml.go | 13 ------------- llm/gguf.go | 14 -------------- llm/llama.go | 34 +++++++++++++++++++--------------- llm/llm.go | 6 +++--- server/images.go | 8 ++++---- server/routes.go | 41 ++++++++++++++++++++++++++++++----------- 6 files changed, 56 insertions(+), 60 deletions(-) diff --git a/llm/ggml.go b/llm/ggml.go index 6d9bab558..a734b0384 100644 --- a/llm/ggml.go +++ b/llm/ggml.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "errors" "io" - "sync" ) type GGML struct { @@ -165,18 +164,6 @@ func (c *containerLORA) Decode(r io.Reader) (model, error) { return nil, nil } -var ( - ggmlInit sync.Once - ggmlRunners []ModelRunner // a slice of ModelRunners ordered by priority -) - -func ggmlRunner() []ModelRunner { - ggmlInit.Do(func() { - ggmlRunners = chooseRunners("ggml") - }) - return ggmlRunners -} - const ( // Magic constant for `ggml` files (unversioned). FILE_MAGIC_GGML = 0x67676d6c diff --git a/llm/gguf.go b/llm/gguf.go index ff2b4f714..27a647d19 100644 --- a/llm/gguf.go +++ b/llm/gguf.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "sync" ) type containerGGUF struct { @@ -368,16 +367,3 @@ func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) { return } - -var ( - ggufInit sync.Once - ggufRunners []ModelRunner // a slice of ModelRunners ordered by priority -) - -func ggufRunner() []ModelRunner { - ggufInit.Do(func() { - ggufRunners = chooseRunners("gguf") - }) - - return ggufRunners -} diff --git a/llm/llama.go b/llm/llama.go index 9118da2a2..d5f223684 100644 --- a/llm/llama.go +++ b/llm/llama.go @@ -32,7 +32,7 @@ type ModelRunner struct { Path string // path to the model runner executable } -func chooseRunners(runnerType string) []ModelRunner { +func chooseRunners(workDir, runnerType string) []ModelRunner { buildPath := path.Join("llama.cpp", runnerType, "build") var runners []string @@ -61,11 +61,6 @@ func chooseRunners(runnerType string) []ModelRunner { } } - // copy the files locally to run the llama.cpp server - tmpDir, err := os.MkdirTemp("", "llama-*") - if err != nil { - log.Fatalf("load llama runner: failed to create temp dir: %v", err) - } runnerAvailable := false // if no runner files are found in the embed, this flag will cause a fast fail for _, r := range runners { // find all the files in the runner's bin directory @@ -85,18 +80,27 @@ func chooseRunners(runnerType string) []ModelRunner { defer srcFile.Close() // create the directory in case it does not exist - destPath := filepath.Join(tmpDir, filepath.Dir(f)) + destPath := filepath.Join(workDir, filepath.Dir(f)) if err := os.MkdirAll(destPath, 0o755); err != nil { log.Fatalf("create runner temp dir %s: %v", filepath.Dir(f), err) } - destFile, err := os.OpenFile(filepath.Join(destPath, filepath.Base(f)), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) - if err != nil { - log.Fatalf("write llama runner %s: %v", f, err) - } - defer destFile.Close() - if _, err := io.Copy(destFile, srcFile); err != nil { - log.Fatalf("copy llama runner %s: %v", f, err) + destFile := filepath.Join(destPath, filepath.Base(f)) + + _, err = os.Stat(destFile) + switch { + case errors.Is(err, os.ErrNotExist): + destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755) + if err != nil { + log.Fatalf("write llama runner %s: %v", f, err) + } + defer destFile.Close() + + if _, err := io.Copy(destFile, srcFile); err != nil { + log.Fatalf("copy llama runner %s: %v", f, err) + } + case err != nil: + log.Fatalf("stat llama runner %s: %v", f, err) } } } @@ -107,7 +111,7 @@ func chooseRunners(runnerType string) []ModelRunner { // return the runners to try in priority order localRunnersByPriority := []ModelRunner{} for _, r := range runners { - localRunnersByPriority = append(localRunnersByPriority, ModelRunner{Path: path.Join(tmpDir, r)}) + localRunnersByPriority = append(localRunnersByPriority, ModelRunner{Path: path.Join(workDir, r)}) } return localRunnersByPriority diff --git a/llm/llm.go b/llm/llm.go index 537898f5e..25fce3ebf 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -21,7 +21,7 @@ type LLM interface { Ping(context.Context) error } -func New(model string, adapters []string, opts api.Options) (LLM, error) { +func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) { if _, err := os.Stat(model); err != nil { return nil, err } @@ -91,9 +91,9 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) { switch ggml.Name() { case "gguf": opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions - return newLlama(model, adapters, ggufRunner(), opts) + return newLlama(model, adapters, chooseRunners(workDir, "gguf"), opts) case "ggml", "ggmf", "ggjt", "ggla": - return newLlama(model, adapters, ggmlRunner(), opts) + return newLlama(model, adapters, chooseRunners(workDir, "ggml"), opts) default: return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily()) } diff --git a/server/images.go b/server/images.go index 75ff1aabd..57e786315 100644 --- a/server/images.go +++ b/server/images.go @@ -267,7 +267,7 @@ func filenameWithPath(path, f string) (string, error) { return f, nil } -func CreateModel(ctx context.Context, name string, path string, fn func(resp api.ProgressResponse)) error { +func CreateModel(ctx context.Context, workDir, name string, path string, fn func(resp api.ProgressResponse)) error { mp := ParseModelPath(name) var manifest *ManifestV2 @@ -524,7 +524,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api } // generate the embedding layers - embeddingLayers, err := embeddingLayers(embed) + embeddingLayers, err := embeddingLayers(workDir, embed) if err != nil { return err } @@ -581,7 +581,7 @@ type EmbeddingParams struct { } // embeddingLayers loads the associated LLM and generates the embeddings to be stored from an input file -func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { +func embeddingLayers(workDir string, e EmbeddingParams) ([]*LayerReader, error) { layers := []*LayerReader{} if len(e.files) > 0 { // check if the model is a file path or a model name @@ -594,7 +594,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) { model = &Model{ModelPath: e.model} } - if err := load(context.Background(), model, e.opts, defaultSessionDuration); err != nil { + if err := load(context.Background(), workDir, model, e.opts, defaultSessionDuration); err != nil { return nil, fmt.Errorf("load model to generate embeddings: %v", err) } diff --git a/server/routes.go b/server/routes.go index c463a1af0..c00bad65d 100644 --- a/server/routes.go +++ b/server/routes.go @@ -58,7 +58,7 @@ var loaded struct { var defaultSessionDuration = 5 * time.Minute // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function -func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { +func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]interface{}, sessionDuration time.Duration) error { opts := api.DefaultOptions() if err := opts.FromMap(model.Options); err != nil { log.Printf("could not load model options: %v", err) @@ -94,7 +94,7 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses loaded.Embeddings = model.Embeddings } - llmModel, err := llm.New(model.ModelPath, model.AdapterPaths, opts) + llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts) if err != nil { return err } @@ -130,6 +130,7 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses llmModel.SetOptions(opts) } } + loaded.expireAt = time.Now().Add(sessionDuration) if loaded.expireTimer == nil { @@ -150,6 +151,7 @@ func load(ctx context.Context, model *Model, reqOpts map[string]interface{}, ses loaded.digest = "" }) } + loaded.expireTimer.Reset(sessionDuration) return nil } @@ -172,8 +174,11 @@ func GenerateHandler(c *gin.Context) { return } - sessionDuration := defaultSessionDuration // TODO: set this duration from the request if specified - if err := load(c.Request.Context(), model, req.Options, sessionDuration); err != nil { + workDir := c.GetString("workDir") + + // TODO: set this duration from the request if specified + sessionDuration := defaultSessionDuration + if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } @@ -245,7 +250,9 @@ func EmbeddingHandler(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - if err := load(c.Request.Context(), model, req.Options, 5*time.Minute); err != nil { + + workDir := c.GetString("workDir") + if err := load(c.Request.Context(), workDir, model, req.Options, 5*time.Minute); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -335,6 +342,8 @@ func CreateModelHandler(c *gin.Context) { return } + workDir := c.GetString("workDir") + ch := make(chan any) go func() { defer close(ch) @@ -345,7 +354,7 @@ func CreateModelHandler(c *gin.Context) { ctx, cancel := context.WithCancel(c.Request.Context()) defer cancel() - if err := CreateModel(ctx, req.Name, req.Path, fn); err != nil { + if err := CreateModel(ctx, workDir, req.Name, req.Path, fn); err != nil { ch <- gin.H{"error": err.Error()} } }() @@ -519,8 +528,20 @@ func Serve(ln net.Listener, allowOrigins []string) error { ) } + workDir, err := os.MkdirTemp("", "ollama") + if err != nil { + return err + } + defer os.RemoveAll(workDir) + r := gin.Default() - r.Use(cors.New(config)) + r.Use( + cors.New(config), + func(c *gin.Context) { + c.Set("workDir", workDir) + c.Next() + }, + ) r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") @@ -546,12 +567,10 @@ func Serve(ln net.Listener, allowOrigins []string) error { // listen for a ctrl+c and stop any loaded llm signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGINT) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) go func() { <-signals - if loaded.llm != nil { - loaded.llm.Close() - } + os.RemoveAll(workDir) os.Exit(0) }()