no errgroup

This commit is contained in:
Michael Yang
2023-07-11 14:57:17 -07:00
parent 948323fa78
commit a806b03f62
3 changed files with 24 additions and 41 deletions

View File

@@ -16,7 +16,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/lithammer/fuzzysearch/fuzzy"
"golang.org/x/sync/errgroup"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llama"
@@ -56,12 +55,8 @@ func generate(c *gin.Context) {
req.Model = path.Join(cacheDir(), "models", req.Model+".bin")
}
llm, err := llama.New(req.Model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer llm.Close()
ch := make(chan any)
go stream(c, ch)
templateNames := make([]string, 0, len(templates.Templates()))
for _, template := range templates.Templates() {
@@ -79,24 +74,22 @@ func generate(c *gin.Context) {
req.Prompt = sb.String()
}
ch := make(chan any)
g, _ := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
defer close(ch)
return llm.Predict(req.Prompt, func(s string) {
ch <- api.GenerateResponse{Response: s}
})
})
g.Go(func() error {
stream(c, ch)
return nil
})
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
llm, err := llama.New(req.Model, req.Options)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
defer llm.Close()
fn := func(s string) {
ch <- api.GenerateResponse{Response: s}
}
if err := llm.Predict(req.Prompt, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
func pull(c *gin.Context) {
@@ -113,24 +106,17 @@ func pull(c *gin.Context) {
}
ch := make(chan any)
g, _ := errgroup.WithContext(c.Request.Context())
g.Go(func() error {
defer close(ch)
return saveModel(remote, func(total, completed int64) {
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(total) / float64(completed) * 100,
}
})
})
go stream(c, ch)
g.Go(func() error {
stream(c, ch)
return nil
})
fn := func(total, completed int64) {
ch <- api.PullProgress{
Total: total,
Completed: completed,
Percent: float64(total) / float64(completed) * 100,
}
}
if err := g.Wait(); err != nil && !errors.Is(err, io.EOF) {
if err := saveModel(remote, fn); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}