diff --git a/cache/cache.go b/cache/cache.go index 210aaf234..ef9cf30f7 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,63 +1,420 @@ package cache import ( + "errors" + "fmt" + "log/slog" + "math" + "slices" + "github.com/ollama/ollama/ml" ) -type Options struct { - Position int -} +var ErrNotSupported = errors.New("model does not support operation") type Cache interface { + // ** used by model implementations ** + + // Returns an instance of the cache for layer 'i' Sub(i int) Cache - Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) + + // Returns the history of key and value tensors plus a mask + // + // The tensors are of shape embed dim, kv heads, batch size + // The mask is of shape history size, batch size + Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) + + // Stores a batch of key and value in the cache + // + // The tensors must be of shape embed dim, kv heads, batch size + Put(ctx ml.Context, key, value ml.Tensor) + + // ** cache management ** + + // Closes the cache and frees resources associated with it + Close() + + // Called before the start of the model's forward pass. For each + // token in the coming batch, there must be a corresponding entry + // in positions and seqs. + StartForward(ctx ml.Context, positions []int32, seqs []int) error + + // Copies tokens in the range [0, len) from srcSeq to dstSeq + CopyPrefix(srcSeq, dstSeq int, len int32) + + // Removes tokens in the range [beginIndex, endIndex) from seq. Set + // endIndex to math.MaxInt32 to remove everything starting at beginIndex + Remove(seq int, beginIndex, endIndex int32) error } -type Simple struct { +type Causal struct { DType ml.DType - Capacity int + Capacity int32 + // current forward pass + curLayer int + curLoc int + curBatchSize int + curMask ml.Tensor + curCellRange cellRange + + // metadata + cells []cacheCell + cellRanges map[int]cellRange + + // cache data storage + backend ml.Backend + cacheCtx ml.Context keys, values []ml.Tensor } -func (c *Simple) Sub(i int) Cache { +type seqCell struct { + seq int + pos int32 +} + +type cacheCell struct { + sequences []seqCell +} + +type cellRange struct { + min int + max int +} + +func (cell cacheCell) findSeq(seq int) *seqCell { + for i := range cell.sequences { + if cell.sequences[i].seq == seq { + return &cell.sequences[i] + } + } + return nil +} + +func NewCausalCache(backend ml.Backend, dtype ml.DType, capacity int32) Cache { + return &Causal{ + Capacity: capacity, + DType: dtype, + cells: make([]cacheCell, capacity), + cellRanges: make(map[int]cellRange), + backend: backend, + cacheCtx: backend.NewContext(), + } +} + +func (c *Causal) Close() { + c.cacheCtx.Close() +} + +var ErrKvCacheFull = errors.New("could not find a kv cache slot") + +func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { + if len(positions) != len(seqs) { + return fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(positions), len(seqs)) + } + + c.curBatchSize = len(positions) + + if c.curBatchSize < 1 { + return errors.New("batch size cannot be less than 1") + } + + var err error + c.curLoc, err = c.findStartLoc() + if errors.Is(err, ErrKvCacheFull) { + c.defrag() + c.curLoc, err = c.findStartLoc() + } + if err != nil { + return err + } + + c.curCellRange = newRange() + for i, pos := range positions { + seq := seqs[i] + + c.cells[c.curLoc+i] = cacheCell{sequences: []seqCell{{seq: seq, pos: pos}}} + + ranges, ok := c.cellRanges[seq] + if !ok { + ranges = newRange() + } + + if c.curLoc+i > 
ranges.max { + ranges.max = c.curLoc + i + } + if ranges.max > c.curCellRange.max { + c.curCellRange.max = ranges.max + } + + if c.curLoc+i < ranges.min { + ranges.min = c.curLoc + i + } + if ranges.min < c.curCellRange.min { + c.curCellRange.min = ranges.min + } + c.cellRanges[seq] = ranges + } + + c.curMask, err = c.buildMask(ctx, positions, seqs) + + return err +} + +func newRange() cellRange { + return cellRange{ + min: math.MaxInt, + max: 0, + } +} + +func (c *Causal) findStartLoc() (int, error) { + var start, count int + for i := range c.cells { + if len(c.cells[i].sequences) == 0 { + count++ + if count >= c.curBatchSize { + return start, nil + } + } else { + start = i + 1 + count = 0 + } + } + + return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) +} + +func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { + // TODO(jessegross): This makes a number of simplifications such as no padding, + // which could be an issue for CUDA graphs and/or flash attention + len := c.curCellRange.max - c.curCellRange.min + 1 + mask := make([]float32, c.curBatchSize*len) + + for i := range c.curBatchSize { + for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { + cellSeq := c.cells[j].findSeq(seqs[i]) + if cellSeq == nil || cellSeq.pos > positions[i] { + mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1)) + } + } + } + + return ctx.FromFloatSlice(mask, len, c.curBatchSize) +} + +func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) { + for _, obj := range objs { + srcView := obj.View(ctx, int(obj.Stride(2))*src, int(obj.Dim(0)*obj.Dim(1))*len) + dstView := obj.View(ctx, int(obj.Stride(2))*dst, int(obj.Dim(0)*obj.Dim(1))*len) + + ctx.Forward(srcView.Copy(ctx, dstView)) + } +} + +func (c *Causal) defrag() { + slog.Debug("defragmenting kv cache") + + // Defrag strategy: + // - Search for empty holes at the beginning of the cache, + // filling them with active data starting at the end + // - If there are contiguous elements that need to be moved, + // combine them into a single operation by holding new moves + // until we see the next one is non-contiguous + // - Fill up the context with the maximum number of operations it + // can hold then compute that and continue with a new context + // + // We could try to optimize placement by grouping blocks from + // the same sequences together but most likely the next forward + // pass will disrupt this anyways, so the real world benefit + // seems limited as this time. + + ctx := c.backend.NewContext() + + // For every move, 6 tensors are required per layer (2 views and a + // copy for each of k and v). For efficiency, we try to group + // multiple contiguous blocks into a single move. However, if we + // exceed the maximum number of tensors then we need to compute + // what we have and start a new batch. 
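+	// Editor's note (illustrative only, not part of the patch): MaxTensors() reports the
+	// graph size chosen in ggml's NewContext, which is at least 8192 nodes. For a
+	// hypothetical 32-layer model that allows roughly 8192/(6*32) ≈ 42 grouped
+	// moves per context before we must compute and start a new one.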
+ maxMoves := ctx.MaxTensors() / (6 * len(c.keys)) + moves := 0 + + var pendingSrc, pendingDst, pendingLen int + + for dst := range c.cells { + if len(c.cells[dst].sequences) == 0 { + for src := len(c.cells) - 1; src > dst; src-- { + if len(c.cells[src].sequences) != 0 { + c.cells[dst] = c.cells[src] + c.cells[src] = cacheCell{} + + if pendingLen > 0 { + if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen { + pendingSrc = src + pendingLen++ + break + } else { + moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) + moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + moves++ + } + } + + pendingSrc = src + pendingDst = dst + pendingLen = 1 + + break + } + } + } + + if moves >= maxMoves { + ctx.Compute(nil) + ctx.Close() + ctx = c.backend.NewContext() + + moves = 0 + } + } + + if pendingLen > 0 { + moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) + moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + moves++ + } + + if moves > 0 { + ctx.Compute(nil) + } + ctx.Close() + + for seq := range c.cellRanges { + seqRange := newRange() + + for i, cell := range c.cells { + if cell.findSeq(seq) != nil { + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + + c.cellRanges[seq] = seqRange + } +} + +func (c *Causal) Sub(i int) Cache { if i >= len(c.keys) { c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...) c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...) } - return &Simple{ - keys: c.keys[i : i+1], - values: c.values[i : i+1], - Capacity: c.Capacity, - DType: c.DType, - } + c.curLayer = i + + return c } -func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) { - if c.keys[0] == nil || c.values[0] == nil { - c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity) - c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity) - } +func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + key := c.keys[c.curLayer] + value := c.values[c.curLayer] - ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2))))) - ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2))))) - - n := min(c.Capacity, int(key.Dim(2))+opts.Position) - - key = c.keys[0].View(ctx, 0, + key = key.View(ctx, int(key.Stride(2))*c.curCellRange.min, int(key.Dim(0)), int(key.Stride(1)), int(key.Dim(1)), int(key.Stride(2)), - n, + int(c.curMask.Dim(0)), ) - value = c.values[0].View(ctx, 0, + value = value.View(ctx, int(key.Stride(2))*c.curCellRange.min, int(value.Dim(0)), int(value.Stride(1)), int(value.Dim(1)), int(value.Stride(2)), - n, + int(c.curMask.Dim(0)), ) - // TODO shift context if necessary - - return key, value + return key, value, c.curMask +} + +func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { + if c.curBatchSize != int(key.Dim(2)) { + panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, int(key.Dim(2)))) + } + + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { + c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int64(c.Capacity)) + c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int64(c.Capacity)) + } + + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, int(key.Stride(2))*c.curLoc, int(key.Dim(0)*key.Dim(1)*key.Dim(2))))) + 
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, int(value.Stride(2))*c.curLoc, int(value.Dim(0)*value.Dim(1)*value.Dim(2))))) +} + +func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { + seqRange := newRange() + + for i := range c.cells { + srcCellSeq := c.cells[i].findSeq(srcSeq) + dstCellSeq := c.cells[i].findSeq(dstSeq) + + if dstCellSeq != nil { + c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == dstSeq }) + } + + if srcCellSeq != nil && srcCellSeq.pos < len { + c.cells[i].sequences = append(c.cells[i].sequences, seqCell{seq: dstSeq, pos: srcCellSeq.pos}) + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + + c.cellRanges[dstSeq] = seqRange +} + +func (c *Causal) shift(seq int, beginIndex, offset int32) error { + panic("Shift not yet implemented") +} + +func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { + var offset int32 + if endIndex != math.MaxInt32 { + offset = beginIndex - endIndex + } + + seqRange := newRange() + + for i := range c.cells { + cellSeq := c.cells[i].findSeq(seq) + if cellSeq != nil { + if cellSeq.pos >= beginIndex && cellSeq.pos < endIndex { + c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s seqCell) bool { return s.seq == seq }) + } else { + if cellSeq.pos >= endIndex { + cellSeq.pos += offset + } + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + } + + if endIndex != math.MaxInt32 { + err := c.shift(seq, endIndex, offset) + if err != nil { + return err + } + } + + c.cellRanges[seq] = seqRange + + return nil } diff --git a/cache/tensor.go b/cache/tensor.go new file mode 100644 index 000000000..b714c9a73 --- /dev/null +++ b/cache/tensor.go @@ -0,0 +1,47 @@ +package cache + +import ( + "github.com/ollama/ollama/ml" +) + +type TensorCache struct { + curLayer int + + cacheCtx ml.Context + keys, values []ml.Tensor +} + +func NewTensorCache(backend ml.Backend) *TensorCache { + return &TensorCache{ + cacheCtx: backend.NewContext(), + } +} + +func (c *TensorCache) Close() { + c.cacheCtx.Close() +} + +func (c *TensorCache) Sub(i int) *TensorCache { + if i >= len(c.keys) { + c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...) + c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...) + } + + c.curLayer = i + + return c +} + +func (c *TensorCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.keys[c.curLayer], c.values[c.curLayer], nil +} + +func (c *TensorCache) Put(ctx ml.Context, key, value ml.Tensor) { + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { + c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...) + c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) 
+ } + + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer])) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer])) +} diff --git a/cmd/cmd.go b/cmd/cmd.go index 17c607171..8c70e1550 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -35,9 +35,9 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llama" - "github.com/ollama/ollama/llama/runner" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/runner" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -338,7 +338,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - opts.MultiModal = len(info.ProjectorInfo) != 0 + // TODO(jessegross): We should either find another way to know if this is + // a vision model or remove the logic. Also consider that other modalities will + // need different behavior anyways. + opts.MultiModal = true opts.ParentModel = info.Details.ParentModel if interactive { diff --git a/cmd/runner/main.go b/cmd/runner/main.go index 34b0e9d21..fbfafc7ff 100644 --- a/cmd/runner/main.go +++ b/cmd/runner/main.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - "github.com/ollama/ollama/llama/runner" + "github.com/ollama/ollama/runner" ) func main() { diff --git a/envconfig/config.go b/envconfig/config.go index c10095a64..f3796c71e 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -165,6 +165,8 @@ var ( IntelGPU = Bool("OLLAMA_INTEL_GPU") // MultiUserCache optimizes prompt caching for multi-user scenarios MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") + // Enable the new Ollama engine + NewRunners = Bool("OLLAMA_NEW_RUNNERS") ) func String(s string) func() string { @@ -250,6 +252,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, + "OLLAMA_NEW_RUNNERS": {"OLLAMA_NEW_RUNNERS", NewRunners(), "Enable the new Ollama engine"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/llm/server.go b/llm/server.go index 350e14268..9d9922b80 100644 --- a/llm/server.go +++ b/llm/server.go @@ -252,6 +252,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range } finalParams := []string{"runner"} + if envconfig.NewRunners() { + finalParams = append(finalParams, "--new-runner") + } finalParams = append(finalParams, params...) 
finalParams = append(finalParams, "--port", strconv.Itoa(port)) diff --git a/ml/backend.go b/ml/backend.go index 1ffd2f631..a4e603f68 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -43,12 +43,13 @@ func NewBackend(f *os.File) (Backend, error) { } type Context interface { - Zeros(dtype DType, shape ...int) Tensor + Zeros(dtype DType, shape ...int64) Tensor FromFloatSlice(s []float32, shape ...int) (Tensor, error) FromIntSlice(s []int32, shape ...int) (Tensor, error) Forward(Tensor) Compute(Tensor) Tensor + MaxTensors() int Close() error } diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 8274f3ebb..20ecb5dc9 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -23,7 +23,7 @@ import ( "github.com/ollama/ollama/ml" "golang.org/x/sync/errgroup" - "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" ) type device struct { @@ -198,10 +198,9 @@ func (b *Backend) Get(name string) ml.Tensor { func (b *Backend) NewContext() ml.Context { nodes := max(8192, len(b.meta.Tensors().Items())*5) - bts := make([]byte, C.size_t(nodes)*C.ggml_tensor_overhead()+C.ggml_graph_overhead_custom(C.size_t(nodes), false)) c := C.ggml_init(C.struct_ggml_init_params{ - mem_buffer: unsafe.Pointer(&bts[0]), - mem_size: C.size_t(len(bts)), + mem_buffer: nil, + mem_size: C.size_t(nodes)*C.ggml_tensor_overhead() + C.ggml_graph_overhead_custom(C.size_t(nodes), false), no_alloc: true, }) @@ -244,17 +243,23 @@ func (c *Context) Forward(t ml.Tensor) { } func (c *Context) Compute(t ml.Tensor) ml.Tensor { - c.Forward(t) C.ggml_backend_sched_graph_compute_async(c.sched, c.graph) - backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) + if t != nil && C.ggml_nbytes(t.(*Tensor).t) != 0 { + backend := C.ggml_backend_sched_get_tensor_backend(c.sched, t.(*Tensor).t) + + t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) + C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + } - t.(*Tensor).data = make([]byte, C.ggml_nbytes(t.(*Tensor).t)) - C.ggml_backend_tensor_get_async(backend, t.(*Tensor).t, unsafe.Pointer(&t.(*Tensor).data[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) return t } -func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { +func (c *Context) MaxTensors() int { + return c.nodes +} + +func (c Context) Zeros(dtype ml.DType, shape ...int64) ml.Tensor { if len(shape) < 1 || len(shape) > 4 { panic("unsupported number of dimensions") } @@ -283,6 +288,13 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) { n := len(s) + + if n == 0 { + shape := 0 + t := C.ggml_new_tensor(ctx.ctx, dtype, 1, (*C.int64_t)(unsafe.Pointer(&shape))) + return &Tensor{t: t}, nil + } + for _, v := range shape { n /= v } diff --git a/model/cmd/main.go b/model/cmd/main.go deleted file mode 100644 index ed7901c7e..000000000 --- a/model/cmd/main.go +++ /dev/null @@ -1,160 +0,0 @@ -package main - -import ( - "errors" - "flag" - "fmt" - "image" - "io" - "log/slog" - "os" - "path/filepath" - "strings" - - "github.com/ollama/ollama/cache" - "github.com/ollama/ollama/ml" - "github.com/ollama/ollama/model" - _ "github.com/ollama/ollama/model/llama" - _ "github.com/ollama/ollama/model/mllama" - "github.com/ollama/ollama/sample" -) - -var args struct { - n int - debug bool - image string - cache bool -} - -func temp() error { - 
flag.IntVar(&args.n, "n", 10, "number of samples") - flag.BoolVar(&args.debug, "debug", false, "enable debug logging") - flag.StringVar(&args.image, "image", "", "path to image file") - flag.BoolVar(&args.cache, "cache", false, "enable KV cache") - - flag.Parse() - - var prompt string - if n := len(flag.Args()); n == 1 { - bts, err := io.ReadAll(os.Stdin) - if err != nil { - return err - } - - prompt = string(bts) - } else if n > 1 { - prompt = strings.Join(flag.Args()[1:], " ") - } else { - return fmt.Errorf("usage: %s path/to/file longest { + longest = count + longestSlot = &c.slots[i] + } + } + + if longestSlot == nil { + return nil, 0, errors.New("no available cache slots") + } + + return longestSlot, longest, nil +} + +func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { + oldest := time.Now() + var oldestSlot *InputCacheSlot + + longest := int32(-1) + var longestSlot *InputCacheSlot + + for i, s := range c.slots { + count := countCommonPrefix(s.Inputs, prompt) + if count > longest { + longest = count + longestSlot = &c.slots[i] + } + + if s.lastUsed.Compare(oldest) < 0 && !s.InUse { + oldest = s.lastUsed + oldestSlot = &c.slots[i] + } + } + + if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse { + return longestSlot, longest, nil + } + + if oldestSlot.InUse { + return nil, 0, errors.New("no available cache slots") + } + + if len(oldestSlot.Inputs) != 0 { + slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs), + "used", oldestSlot.lastUsed) + } + + if longest > 0 && longestSlot != oldestSlot { + slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", + len(longestSlot.Inputs)) + oldestSlot.Inputs = make([]input, longest) + copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) + // This is only nil for unit tests + if c.cache != nil { + c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) + } + } + + return oldestSlot, longest, nil +} + +func countCommonPrefix(a []input, b []input) int32 { + var count int32 + + for i := range a { + if i >= len(b) { + break + } + + if !reflect.DeepEqual(a[i], b[i]) { + break + } + + count++ + } + + return count +} + +func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { + targetFree := (c.numCtx - numKeep) / 2 + targetFree = max(targetFree, 1) + + currentFree := c.numCtx - inputLen + discard := targetFree - currentFree + + if discard < 0 { + discard = 0 + } + + return discard +} + +// Frees up space in the KV cache by deleting the oldest half of history and shifting +// the newest half into that space (saving numKeep inputs at the beginning). +// +// Assumes that at least 1 entry can be freed up by shifting (i.e. 
numKeep < numCtx) +func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { + if numKeep >= c.numCtx { + return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) + } + + inputLen := int32(len(slot.Inputs)) + discard := c.ShiftDiscard(inputLen, numKeep) + + if discard <= 0 { + return nil + } + + slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), + "keep", numKeep, "discard", discard) + + // TODO (jessegross): KV cache removal can fail for certain types of models + err := c.cache.Remove(slot.Id, numKeep, numKeep+discard) + if err != nil { + return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err) + } + + for i := numKeep + discard; i < inputLen; i++ { + slot.Inputs[i-discard] = slot.Inputs[i] + } + slot.Inputs = slot.Inputs[:inputLen-discard] + + return nil +} diff --git a/runner/newrunner/cache_test.go b/runner/newrunner/cache_test.go new file mode 100644 index 000000000..b411c7b76 --- /dev/null +++ b/runner/newrunner/cache_test.go @@ -0,0 +1,291 @@ +package newrunner + +import ( + "image" + "testing" + "time" +) + +func TestCountCommon(t *testing.T) { + imgA := image.NewRGBA(image.Rect(0, 0, 100, 100)) + imgB := image.NewRGBA(image.Rect(0, 0, 50, 50)) + imgC := image.NewRGBA(image.Rect(50, 50, 100, 100)) + + tests := []struct { + name string + t1 []input + t2 []input + expected int32 + }{ + { + name: "Equal", + t1: []input{{token: 1}, {token: 2}, {token: 3}}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 3, + }, + { + name: "Prefix", + t1: []input{{token: 1}}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 1, + }, + { + name: "Image Prefix", + t1: []input{{image: imgA}}, + t2: []input{{image: imgA}, {image: imgB}, {image: imgC}}, + expected: 1, + }, + { + name: "Mixed", + t1: []input{{token: 1}, {image: imgA}}, + t2: []input{{token: 1}, {image: imgA}, {token: 5}}, + expected: 2, + }, + { + name: "Empty", + t1: []input{}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 0, + }, + { + name: "Both Empty", + t1: []input{}, + t2: []input{}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := countCommonPrefix(tt.t1, tt.t2) + if result != tt.expected { + t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected) + } + }) + } +} + +func TestFindCacheSlot(t *testing.T) { + type expected struct { + result int + len int32 + } + + tests := []struct { + name string + cache InputCache + prompt []input + longest expected + best expected + }{ + { + name: "Empty", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }}, + prompt: []input{{token: 1}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 0, len: 0}, + }, + { + name: "Extend", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 1}, {token: 2}}, + longest: expected{result: 1, len: 2}, + best: expected{result: 1, len: 2}, + }, + { + name: "New", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + 
Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }}, + prompt: []input{{token: 2}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 1, len: 0}, + }, + { + name: "Fork", + cache: InputCache{ + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }, + }, + prompt: []input{{token: 1}}, + longest: expected{result: 0, len: 1}, + best: expected{result: 1, len: 1}, + }, + { + name: "Evict", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 2}, {token: 3}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 1, len: 0}, + }, + { + name: "In use", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: true, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 1}, {token: 2}}, + longest: expected{result: 1, len: 1}, + best: expected{result: 1, len: 2}, + }, + } + + for _, tt := range tests { + t.Run("Longest-"+tt.name, func(t *testing.T) { + result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt) + if err != nil { + t.Errorf("findLongestCacheSlot: err %v", err) + } else if result.Id != tt.longest.result || resultLen != tt.longest.len { + t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v", + result.Id, tt.longest.result, resultLen, tt.longest.len) + } + }) + } + + for _, tt := range tests { + t.Run("Best-"+tt.name, func(t *testing.T) { + result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt) + if err != nil { + t.Errorf("findBestCacheSlot: err %v", err) + } else if result.Id != tt.best.result || resultLen != tt.best.len { + t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v", + result.Id, tt.best.result, resultLen, tt.best.len) + } + }) + } +} + +func TestShiftDiscard(t *testing.T) { + tests := []struct { + name string + numCtx int32 + numKeep int32 + inputLen int32 + expected int32 + }{ + { + name: "Shift", + numCtx: 2048, + numKeep: 5, + inputLen: 2048, + expected: 1021, + }, + { + name: "Max Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 2048, + expected: 1, + }, + { + name: "No Keep", + numCtx: 2048, + numKeep: 0, + inputLen: 2048, + expected: 1024, + }, + { + name: "Truncate", + numCtx: 2048, + numKeep: 5, + inputLen: 5000, + expected: 3973, + }, + { + name: "Truncate Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 5000, + expected: 2953, + }, + { + name: "No Op", + numCtx: 2048, + numKeep: 5, + inputLen: 512, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := InputCache{numCtx: tt.numCtx} + result := c.ShiftDiscard(tt.inputLen, tt.numKeep) + if result != tt.expected { + t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected) + } + }) + } +} diff --git a/runner/newrunner/runner.go b/runner/newrunner/runner.go new file mode 
100644 index 000000000..e872ecf22 --- /dev/null +++ b/runner/newrunner/runner.go @@ -0,0 +1,941 @@ +package newrunner + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "image" + "io" + "log" + "log/slog" + "net" + "net/http" + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "golang.org/x/sync/semaphore" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/runner/common" + "github.com/ollama/ollama/sample" + + _ "github.com/ollama/ollama/model/llama" + _ "github.com/ollama/ollama/model/mllama" +) + +// input is an element of the prompt to process, either a token or an image +type input struct { + token int32 + + image image.Image +} + +type Sequence struct { + // batch index + iBatch int + + // prompt inputs left to evaluate + inputs []input + + // inputs that have been added to a batch but not yet submitted to Forward + pendingInputs []input + + // tokens that have been generated but not returned yet (e.g. for stop sequences) + pendingResponses []string + + // input cache being used by this sequence + cache *InputCacheSlot + + // channel to send responses over + responses chan string + + // channel to stop decoding (such as if the remote connection is closed) + quit chan bool + + // number of tokens to predict + numPredict int + + // set of samplers to run on generated logits + samplers []sample.Sampler + + // channel to send back the embedding if embedding only + embedding chan []float32 + + // stop sequences + stop []string + + // number of inputs to keep at the beginning when shifting context window + numKeep int32 + + // true if an embedding are to be returned instead of text generation + embeddingOnly bool + + doneReason string + + // Metrics + startProcessingTime time.Time + startGenerationTime time.Time + numPredicted int + numPromptInputs int +} + +type NewSequenceParams struct { + numPredict int + stop []string + numKeep int32 + samplers []sample.Sampler + embedding bool +} + +func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { + s.ready.Wait() + + startTime := time.Now() + + inputs, err := s.inputs(prompt, images) + if err != nil { + return nil, fmt.Errorf("failed to process inputs: %w", err) + } else if len(inputs) == 0 { + return nil, errors.New("no input provided") + } + + if params.numKeep < 0 { + params.numKeep = int32(len(inputs)) + } + + // Ensure that at least 1 input can be discarded during shift + params.numKeep = min(params.numKeep, s.cache.numCtx-1) + + if int32(len(inputs)) > s.cache.numCtx { + discard := int32(len(inputs)) - s.cache.numCtx + newInputs := inputs[:params.numKeep] + newInputs = append(newInputs, inputs[params.numKeep+discard:]...) 
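+		// Editor's note (illustrative numbers, not part of the patch): with
+		// numCtx=2048, numKeep=5 and 2100 inputs, discard=52, so we keep
+		// inputs[:5] plus inputs[57:], leaving exactly 2048 inputs.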
+ + slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs)) + inputs = newInputs + } + + // TODO(jessegross): Ingest cached history for grammar + + return &Sequence{ + inputs: inputs, + numPromptInputs: len(inputs), + startProcessingTime: startTime, + numPredict: params.numPredict, + pendingResponses: make([]string, 0), + responses: make(chan string, 100), + quit: make(chan bool, 1), + embedding: make(chan []float32, 1), + samplers: params.samplers, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, + }, nil +} + +// inputs processes the prompt and images into a list of inputs +// by splitting the prompt on [img-] tags, tokenizing text and +// decoding images +func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { + var inputs []input + var parts []string + var matches [][]string + + // TODO(jessegross): This can sometimes trigger for matching text in the + // user's prompt. We previously tried to avoid it by only looking for images + // on image models. We don't have a clear indication now but it would be better + // to properly escape it in any case. + re := regexp.MustCompile(`\[img-(\d+)\]`) + parts = re.Split(prompt, -1) + matches = re.FindAllStringSubmatch(prompt, -1) + + for i, part := range parts { + // text - tokenize + tokens, err := s.model.(model.TextProcessor).Encode(part) + if err != nil { + return nil, err + } + + for _, t := range tokens { + inputs = append(inputs, input{token: t}) + } + + // image - decode and store + if i < len(matches) { + n, _ := strconv.Atoi(matches[i][1]) + + imageIndex := -1 + for j := range images { + if images[j].ID == n { + imageIndex = j + break + } + } + + if imageIndex < 0 { + return nil, fmt.Errorf("invalid image index: %d", n) + } + + image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data)) + if err != nil { + return nil, err + } + + inputs = append(inputs, input{image: image}) + } + } + + return inputs, nil +} + +type Server struct { + // is the server ready to process requests? + // protects access to model and image + ready sync.WaitGroup + + // loaded model + model model.Model + + // status for external health reporting - loading, ready to serve, etc. + status ServerStatus + + // current progress on loading the model + progress float32 + + // number of simultaneous requests to handle + parallel int + + // maximum number of elements in a batch (per sequence) + // TODO (jmorganca): make this n_batch + batchSize int + + // protects access to everything below this line + // this is context state needed for decoding + mu sync.Mutex + + // indicates that data is ready for processing + cond *sync.Cond + + // the list of simultaneous sequences being evaluated + seqs []*Sequence + + // seqs can have a maximum of parallel entries, which + // is enfoced by seqSem + seqsSem *semaphore.Weighted + + // KV cache + cache *InputCache +} + +func (s *Server) allNil() bool { + for _, item := range s.seqs { + if item != nil { + return false + } + } + return true +} + +func flushPending(seq *Sequence) bool { + joined := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = []string{} + + // Check if there are any partial UTF-8 characters remaining. + // We already check and queue as we are generating but some may + // still make it here: + // - Sequence is ending, e.g. 
generation limit has been hit + // - Invalid characters in the middle of a string + // This is a stricter check to ensure we never output invalid Unicode. + for !utf8.ValidString(joined) { + joined = joined[:len(joined)-1] + } + + if len(joined) == 0 { + return true + } + + select { + case seq.responses <- joined: + return true + case <-seq.quit: + return false + } +} + +func (s *Server) removeSequence(seqIndex int, reason string) { + seq := s.seqs[seqIndex] + + flushPending(seq) + seq.doneReason = reason + close(seq.responses) + close(seq.embedding) + seq.cache.InUse = false + s.seqs[seqIndex] = nil + s.seqsSem.Release(1) +} + +func (s *Server) run(ctx context.Context) { + s.ready.Wait() + + for { + select { + case <-ctx.Done(): + return + default: + err := s.processBatch() + if err != nil { + panic(err) + } + } + } +} + +func (s *Server) processBatch() error { + s.mu.Lock() + for s.allNil() { + s.cond.Wait() // Wait until an item is added + } + defer s.mu.Unlock() + + var inputIDs []int32 + var pos []int32 + var outputs []int32 + var seqs []int + + var image image.Image + + for i, seq := range s.seqs { + if seq == nil { + continue + } + + // if past the num predict limit + if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { + s.removeSequence(i, "limit") + continue + } + + for j, input := range seq.inputs { + if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { + if len(seq.pendingInputs) == 0 { + err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + if err != nil { + return err + } + } else { + break + } + } + + if j >= s.batchSize { + break + } + + if input.image != nil { + if image != nil { + break + } + image = input.image + seq.pendingInputs = append(seq.pendingInputs, input) + continue + } + + inputIDs = append(inputIDs, input.token) + pos = append(pos, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) + seqs = append(seqs, seq.cache.Id) + + seq.iBatch = len(outputs) + if j+1 == len(seq.inputs) { + outputs = append(outputs, int32(len(inputIDs)-1)) + } + seq.pendingInputs = append(seq.pendingInputs, input) + } + + seq.inputs = seq.inputs[len(seq.pendingInputs):] + } + + if len(inputIDs) == 0 { + return nil + } + + var options []model.OptionsFunc + if image != nil { + options = append(options, model.WithImage(image)) + } + + ctx := s.model.Backend().NewContext() + defer ctx.Close() + + logit, err := model.Forward(ctx, s.model, append(options, model.WithCache(s.cache.cache), model.WithInputIDs(inputIDs), model.WithPositions(pos), model.WithOutputs(outputs), model.WithSequences(seqs))...) + if err != nil { + return err + } + + f32s := logit.Floats() + + for i, seq := range s.seqs { + if seq == nil { + continue + } + + // After calling Forward, pending inputs are now in the cache + if len(seq.pendingInputs) > 0 { + seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) 
+ seq.pendingInputs = []input{} + } + + // don't sample prompt processing + if len(seq.inputs) != 0 { + continue + } + + seq.numPredicted++ + if seq.numPredicted == 1 { + seq.startGenerationTime = time.Now() + } + + // if done processing the prompt, generate an embedding and return + if seq.embeddingOnly { + /*embed := s.lc.GetEmbeddingsSeq(seq.cache.Id) + if embed == nil { + embed = s.lc.GetEmbeddingsIth(seq.iBatch) + } + + seq.embedding <- embed*/ + s.removeSequence(i, "") + continue + } + + vocabSize := len(f32s) / len(outputs) + seqLogits := f32s[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize] + + // TODO(jessegross): The data type and number of outputs for the samplers seem inconsistent + f64s := make([]float64, vocabSize) + for j, f32 := range seqLogits { + f64s[j] = float64(f32) + } + + // do sampling + f64s, err = sample.Sample(f64s, seq.samplers...) + if err != nil { + return err + } + + var outputIDs []int32 + for _, f64 := range f64s { + if !s.model.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) { + outputIDs = append(outputIDs, int32(f64)) + } else { + s.removeSequence(i, "stop") + continue + } + } + + if len(outputIDs) == 0 { + continue + } + + piece, err := s.model.(model.TextProcessor).Decode(outputIDs) + if errors.Is(err, io.EOF) { + continue + } else if err != nil { + return err + } + + for _, id := range outputIDs { + seq.inputs = append(seq.inputs, input{token: id}) + } + + seq.pendingResponses = append(seq.pendingResponses, piece) + sequence := strings.Join(seq.pendingResponses, "") + + if ok, stop := common.FindStop(sequence, seq.stop); ok { + slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) + + var tokenTruncated bool + origLen := len(seq.pendingResponses) + seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) + newLen := len(seq.pendingResponses) + + // Update the cache based on the tokens that will be returned: + // - We have more tokens than are currently in the cache because + // the last ones generated weren't submitted to Forward + // - Remove any stop sequences that we stripped out + // - If truncateStop removed a portion of a token, drop that + // - As defense-in-depth, if truncatedToken didn't find a stop token + // remove the extra ones that we added to the cache len + tokenLen := len(seq.cache.Inputs) + len(outputIDs) + tokenLen -= origLen - newLen + if tokenTruncated { + tokenLen-- + } + if origLen == newLen { + tokenLen = len(seq.cache.Inputs) + } + seq.cache.Inputs = seq.cache.Inputs[:tokenLen] + + s.removeSequence(i, "stop") + continue + } + + if common.ContainsStopSuffix(sequence, seq.stop) { + continue + } + + if common.IncompleteUnicode(sequence) { + continue + } + + if !flushPending(seq) { + s.removeSequence(i, "connection") + } + } + + return nil +} + +// TODO (jmorganca): use structs from the api package to avoid duplication +// this way the api acts as a proxy instead of using a different api for the +// runner +type Options struct { + api.Runner + + NumKeep int `json:"n_keep"` + Seed int `json:"seed"` + NumPredict int `json:"n_predict"` + TopK int `json:"top_k"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + TypicalP float32 `json:"typical_p"` + RepeatLastN int `json:"repeat_last_n"` + Temperature float32 `json:"temperature"` + RepeatPenalty float32 `json:"repeat_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + FrequencyPenalty float32 `json:"frequency_penalty"` + Mirostat int `json:"mirostat"` + MirostatTau float32 `json:"mirostat_tau"` + 
MirostatEta float32 `json:"mirostat_eta"` + Stop []string `json:"stop"` +} + +type ImageData struct { + Data []byte `json:"data"` + ID int `json:"id"` + AspectRatioID int `json:"aspect_ratio_id"` +} + +type CompletionRequest struct { + Prompt string `json:"prompt"` + Images []ImageData `json:"image_data"` + Grammar string `json:"grammar"` + CachePrompt bool `json:"cache_prompt"` + + Options +} + +type Timings struct { + PredictedN int `json:"predicted_n"` + PredictedMS float64 `json:"predicted_ms"` + PromptN int `json:"prompt_n"` + PromptMS float64 `json:"prompt_ms"` +} + +type CompletionResponse struct { + Content string `json:"content"` + Stop bool `json:"stop"` + + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + StoppedLimit bool `json:"stopped_limit,omitempty"` + PredictedN int `json:"predicted_n,omitempty"` + PredictedMS float64 `json:"predicted_ms,omitempty"` + PromptN int `json:"prompt_n,omitempty"` + PromptMS float64 `json:"prompt_ms,omitempty"` + + Timings Timings `json:"timings"` +} + +func getSamplers(_ CompletionRequest) []sample.Sampler { + // TODO(jessegross): Waiting for sampling code + + /*var samplingParams llama.SamplingParams + samplingParams.TopK = req.TopK + samplingParams.TopP = req.TopP + samplingParams.MinP = req.MinP + samplingParams.TypicalP = req.TypicalP + samplingParams.Temp = req.Temperature + samplingParams.RepeatLastN = req.RepeatLastN + samplingParams.PenaltyRepeat = req.RepeatPenalty + samplingParams.PenaltyFreq = req.FrequencyPenalty + samplingParams.PenaltyPresent = req.PresencePenalty + samplingParams.Mirostat = req.Mirostat + samplingParams.MirostatTau = req.MirostatTau + samplingParams.MirostatEta = req.MirostatEta + samplingParams.Seed = uint32(req.Seed) + samplingParams.Grammar = req.Grammar*/ + + return []sample.Sampler{sample.Greedy()} +} + +func (s *Server) completion(w http.ResponseWriter, r *http.Request) { + var req CompletionRequest + req.Options = Options(api.DefaultOptions()) + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Set the headers to indicate streaming + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Transfer-Encoding", "chunked") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ + numPredict: req.NumPredict, + stop: req.Stop, + numKeep: int32(req.NumKeep), + samplers: getSamplers(req), + embedding: false, + }) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + // Ensure there is a place to put the sequence, released when removed from s.seqs + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting completion request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + if err != nil { + s.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) + return + } + + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + http.Error(w, "could not find an available 
sequence", http.StatusInternalServerError) + return + } + + for { + select { + case <-r.Context().Done(): + close(seq.quit) + return + case content, ok := <-seq.responses: + if ok { + if err := json.NewEncoder(w).Encode(&CompletionResponse{ + Content: content, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + close(seq.quit) + return + } + + flusher.Flush() + } else { + // Send the final response + if err := json.NewEncoder(w).Encode(&CompletionResponse{ + Stop: true, + StoppedLimit: seq.doneReason == "limit", + Timings: Timings{ + PromptN: seq.numPromptInputs, + PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), + PredictedN: seq.numPredicted, + PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), + }, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) + } + + return + } + } + } +} + +type EmbeddingRequest struct { + Content string `json:"content"` + CachePrompt bool `json:"cache_prompt"` +} + +type EmbeddingResponse struct { + Embedding []float32 `json:"embedding"` +} + +func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { + var req EmbeddingRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + slog.Debug("embedding request", "content", req.Content) + + seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true}) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + // Ensure there is a place to put the sequence, released when removed from s.seqs + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting embeddings request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + if err != nil { + s.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) + return + } + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + + embedding := <-seq.embedding + + if err := json.NewEncoder(w).Encode(&EmbeddingResponse{ + Embedding: embedding, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + +type HealthResponse struct { + Status string `json:"status"` + Progress float32 `json:"progress"` +} + +type ServerStatus int + +const ( + ServerStatusReady ServerStatus = iota + ServerStatusLoadingModel + ServerStatusError +) + +func (s ServerStatus) ToString() string { + switch s { + case ServerStatusReady: + return "ok" + case ServerStatusLoadingModel: + return "loading model" + default: + return "server error" + } +} + +func (s *Server) health(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(&HealthResponse{ + Status: s.status.ToString(), + Progress: s.progress, + }); err != nil { + 
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + +type multiLPath []string + +func (m *multiLPath) Set(value string) error { + *m = append(*m, value) + return nil +} + +func (m *multiLPath) String() string { + return strings.Join(*m, ", ") +} + +func (s *Server) loadModel( + mpath string, + lpath multiLPath, + kvCacheType string, + kvSize int, + multiUserCache bool, +) { + var err error + s.model, err = model.New(mpath) + if err != nil { + panic(err) + } + + // TODO(jessegross): LoRA loading + if lpath.String() != "" { + panic("loras are not yet implemented") + } + + s.cache, err = NewInputCache(s.model.Backend(), kvCacheType, int32(kvSize), s.parallel, multiUserCache) + if err != nil { + panic(err) + } + + s.status = ServerStatusReady + s.ready.Done() +} + +func Execute(args []string) error { + fs := flag.NewFlagSet("runner", flag.ExitOnError) + mpath := fs.String("model", "", "Path to model binary file") + parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously") + batchSize := fs.Int("batch-size", 512, "Batch size") + _ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") + _ = fs.Int("main-gpu", 0, "Main GPU") + _ = fs.Bool("flash-attn", false, "Enable flash attention") + kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") + kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") + port := fs.Int("port", 8080, "Port to expose the server on") + _ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") + verbose := fs.Bool("verbose", false, "verbose output (default: disabled)") + _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)") + _ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing") + _ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") + multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") + + var lpaths multiLPath + fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)") + + fs.Usage = func() { + fmt.Fprintf(fs.Output(), "Runner usage\n") + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return err + } + level := slog.LevelInfo + if *verbose { + level = slog.LevelDebug + } + handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + AddSource: true, + ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.SourceKey { + source := attr.Value.Any().(*slog.Source) + source.File = filepath.Base(source.File) + } + return attr + }, + }) + slog.SetDefault(slog.New(handler)) + slog.Info("starting ollama engine") + // TODO(jessegross): Some system info would be useful + + server := &Server{ + batchSize: *batchSize, + parallel: *parallel, + seqs: make([]*Sequence, *parallel), + seqsSem: semaphore.NewWeighted(int64(*parallel)), + status: ServerStatusLoadingModel, + } + + // TODO(jessegross): Parameters that need to be implemented: + // n-gpu-layers + // main-gpu + // flash-attn + // threads + // no-mmap + // mlock + // tensor-split + + /*var tensorSplitFloats []float32 + if *tensorSplit != "" { + stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1) + + tensorSplitFloats = make([]float32, 0, len(stringFloats)) + for _, s := range stringFloats { + f, _ := strconv.ParseFloat(s, 32) + 
tensorSplitFloats = append(tensorSplitFloats, float32(f)) + } + }*/ + + server.ready.Add(1) + go server.loadModel(*mpath, lpaths, *kvCacheType, *kvSize, *multiUserCache) + + server.cond = sync.NewCond(&server.mu) + + ctx, cancel := context.WithCancel(context.Background()) + go server.run(ctx) + + addr := "127.0.0.1:" + strconv.Itoa(*port) + listener, err := net.Listen("tcp", addr) + if err != nil { + fmt.Println("Listen error:", err) + cancel() + return err + } + defer listener.Close() + + mux := http.NewServeMux() + mux.HandleFunc("/embedding", server.embeddings) + mux.HandleFunc("/completion", server.completion) + mux.HandleFunc("/health", server.health) + + httpServer := http.Server{ + Handler: mux, + } + + log.Println("Server listening on", addr) + if err := httpServer.Serve(listener); err != nil { + log.Fatal("server error:", err) + return err + } + + cancel() + return nil +} diff --git a/llama/runner/cache.go b/runner/oldrunner/cache.go similarity index 99% rename from llama/runner/cache.go rename to runner/oldrunner/cache.go index e8a2d2994..445d0f685 100644 --- a/llama/runner/cache.go +++ b/runner/oldrunner/cache.go @@ -1,4 +1,4 @@ -package runner +package oldrunner import ( "errors" diff --git a/llama/runner/cache_test.go b/runner/oldrunner/cache_test.go similarity index 99% rename from llama/runner/cache_test.go rename to runner/oldrunner/cache_test.go index 9c838ed33..355e3a1e4 100644 --- a/llama/runner/cache_test.go +++ b/runner/oldrunner/cache_test.go @@ -1,4 +1,4 @@ -package runner +package oldrunner import ( "testing" diff --git a/llama/runner/image.go b/runner/oldrunner/image.go similarity index 99% rename from llama/runner/image.go rename to runner/oldrunner/image.go index c1932443c..89e2ae8ac 100644 --- a/llama/runner/image.go +++ b/runner/oldrunner/image.go @@ -1,4 +1,4 @@ -package runner +package oldrunner import ( "errors" diff --git a/llama/runner/image_test.go b/runner/oldrunner/image_test.go similarity index 99% rename from llama/runner/image_test.go rename to runner/oldrunner/image_test.go index d5c3bc1e2..6f5d6829d 100644 --- a/llama/runner/image_test.go +++ b/runner/oldrunner/image_test.go @@ -1,4 +1,4 @@ -package runner +package oldrunner import ( "reflect" diff --git a/llama/runner/runner.go b/runner/oldrunner/runner.go similarity index 98% rename from llama/runner/runner.go rename to runner/oldrunner/runner.go index 60ae88dac..ccfe6107f 100644 --- a/llama/runner/runner.go +++ b/runner/oldrunner/runner.go @@ -1,4 +1,4 @@ -package runner +package oldrunner import ( "context" @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/runner/common" ) // input is an element of the prompt to process, either @@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "") - if ok, stop := findStop(sequence, seq.stop); ok { + if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) var tokenTruncated bool origLen := len(seq.pendingResponses) - seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) + seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) newLen := len(seq.pendingResponses) // Update the cache based on the tokens that will be returned: @@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, 
embedBatch *llama.Batch) continue } - if containsStopSuffix(sequence, seq.stop) { + if common.ContainsStopSuffix(sequence, seq.stop) { continue } - if incompleteUnicode(sequence) { + if common.IncompleteUnicode(sequence) { continue } @@ -885,9 +886,6 @@ func (s *Server) loadModel( } func Execute(args []string) error { - if args[0] == "runner" { - args = args[1:] - } fs := flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") ppath := fs.String("mmproj", "", "Path to projector binary file") diff --git a/runner/runner.go b/runner/runner.go new file mode 100644 index 000000000..b67e160c9 --- /dev/null +++ b/runner/runner.go @@ -0,0 +1,24 @@ +package runner + +import ( + "github.com/ollama/ollama/runner/newrunner" + "github.com/ollama/ollama/runner/oldrunner" +) + +func Execute(args []string) error { + if args[0] == "runner" { + args = args[1:] + } + + var newRunner bool + if args[0] == "--new-runner" { + args = args[1:] + newRunner = true + } + + if newRunner { + return newrunner.Execute(args) + } else { + return oldrunner.Execute(args) + } +} diff --git a/server/prompt.go b/server/prompt.go index cc69fe8cf..49ecbd8e4 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/mllama" "github.com/ollama/ollama/template" @@ -92,26 +93,33 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. var imgData llm.ImageData if isMllama { - data, opts, err := mllama.Preprocess(bytes.NewReader(i)) - if err != nil { - return "", nil, err - } + if envconfig.NewRunners() { + imgData = llm.ImageData{ + ID: len(images), + Data: i, + } + } else { + data, opts, err := mllama.Preprocess(bytes.NewReader(i)) + if err != nil { + return "", nil, err + } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.LittleEndian, data) - if err != nil { - return "", nil, err - } + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, data) + if err != nil { + return "", nil, err + } - ar, ok := opts["aspectRatioIndex"].(int) - if !ok { - return "", nil, fmt.Errorf("missing aspect ratio for image") - } + ar, ok := opts["aspectRatioIndex"].(int) + if !ok { + return "", nil, fmt.Errorf("missing aspect ratio for image") + } - imgData = llm.ImageData{ - ID: len(images), - Data: buf.Bytes(), - AspectRatioID: ar, + imgData = llm.ImageData{ + ID: len(images), + Data: buf.Bytes(), + AspectRatioID: ar, + } } imgPrompt = "<|image|>" } else { diff --git a/server/routes.go b/server/routes.go index 2670ca954..fe5032a40 100644 --- a/server/routes.go +++ b/server/routes.go @@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { - if isMllama { + if isMllama && !envconfig.NewRunners() { data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i])) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
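
Appendix (editor's sketch, not part of the patch): the snippet below illustrates how a model forward pass might drive the new cache.Cache interface from cache/cache.go — StartForward declares the batch, then each layer stores its key/value tensors with Put and reads back the history plus the causal mask with Get through its Sub() view. The package name and the buildKV/attend helpers are hypothetical stand-ins; only the cache.Cache, ml.Backend, ml.Context and ml.Tensor calls come from this diff.

package cachesketch

import (
	"github.com/ollama/ollama/cache"
	"github.com/ollama/ollama/ml"
)

// buildKV and attend are hypothetical stubs standing in for real model code;
// they exist only so the sketch is self-contained.
func buildKV(ctx ml.Context, layer int) (key, value ml.Tensor) { return nil, nil }
func attend(ctx ml.Context, keys, values, mask ml.Tensor)      {}

// forwardOneBatch shows the intended call sequence for a single forward pass:
// one position/sequence entry per token in the batch, then per-layer Put/Get
// through that layer's view of the cache.
func forwardOneBatch(backend ml.Backend, kv cache.Cache, positions []int32, seqs []int, numLayers int) error {
	ctx := backend.NewContext()
	defer ctx.Close()

	// Must run before the layers; positions and seqs describe the incoming batch.
	if err := kv.StartForward(ctx, positions, seqs); err != nil {
		return err
	}

	for i := 0; i < numLayers; i++ {
		layerCache := kv.Sub(i)

		// Store this batch's K/V (shape: embed dim, kv heads, batch size)...
		key, value := buildKV(ctx, i)
		layerCache.Put(ctx, key, value)

		// ...then read back the history and the mask (history size, batch size)
		// for the attention computation.
		histKeys, histValues, mask := layerCache.Get(ctx)
		attend(ctx, histKeys, histValues, mask)
	}

	return nil
}

Sequence management (CopyPrefix when forking a slot, Remove when shifting the context window, Close on teardown) is driven from the runner's InputCache in runner/newrunner/cache.go rather than from model code. The new engine itself is opt-in: setting OLLAMA_NEW_RUNNERS=1 makes llm/server.go pass --new-runner to the runner subprocess, which runner.Execute then routes to the newrunner package.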