diff --git a/integration/README.md b/integration/README.md index e2bdd6b21d..e52ba71ee7 100644 --- a/integration/README.md +++ b/integration/README.md @@ -2,10 +2,13 @@ This directory contains integration tests to exercise Ollama end-to-end to verify behavior -By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` +By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable, which allows scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU). To view the current set of tag combinations, use `find integration -type f | xargs grep "go:build"` The integration tests have 2 modes of operating. 1. By default, they will start the server on a random port, run the tests, and then shutdown the server. -2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote +2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable. + +> [!IMPORTANT] +> Before running the tests locally without the "test existing" setting, compile ollama at the top of the source tree with `go build .`, and build GPU support with cmake first if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree. diff --git a/integration/context_test.go b/integration/context_test.go index b28d11380a..88d29f73f7 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -66,7 +66,7 @@ func TestContextExhaustion(t *testing.T) { DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second) } -// Send multiple requests with prior context and ensure the response is coherant and expected +// Send multiple generate requests with prior context and ensure the response is coherent and expected func TestGenerateWithHistory(t *testing.T) { modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model req, resp := GenerateRequests() @@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) { }(i) } wg.Wait() - +} + +// Send multiple chat requests with prior context and ensure the response is coherent and expected +func TestChatWithHistory(t *testing.T) { + modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model + req, resp := ChatRequests() + numParallel := 2 + iterLimit := 2 + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) and warm the model up with a single initial empty request + slog.Info("loading", "model", modelOverride) + err := client.Generate(ctx, + &api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", modelOverride, err) + } + + var wg sync.WaitGroup + wg.Add(numParallel) + started := time.Now() + for i := range numParallel { + go func(i int) {
defer wg.Done() + k := i % len(req) + req[k].Model = modelOverride + for j := 0; j < iterLimit; j++ { + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + slog.Info("Starting", "thread", i, "iter", j) + // On slower GPUs it can take a while to process the concurrent requests + // so we allow a much longer initial timeout + assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second) + if assistant == nil { + t.Fatalf("didn't get an assistant response for context") + } + req[k].Messages = append(req[k].Messages, + *assistant, + api.Message{Role: "user", Content: "tell me more!"}, + ) + } + }(i) + } + wg.Wait() } diff --git a/integration/max_queue_test.go b/integration/max_queue_test.go index 7bb9336a0c..e68a08b1cf 100644 --- a/integration/max_queue_test.go +++ b/integration/max_queue_test.go @@ -19,6 +19,8 @@ import ( ) func TestMaxQueue(t *testing.T) { + t.Skip("this test needs to be re-evaluated to use a proper embedding model") + if os.Getenv("OLLAMA_TEST_EXISTING") != "" { t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size") return diff --git a/integration/utils_test.go b/integration/utils_test.go index d7e3790b1a..5f753f8824 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -567,6 +567,76 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { } } +func ChatRequests() ([]api.ChatRequest, [][]string) { + genReqs, results := GenerateRequests() + reqs := make([]api.ChatRequest, len(genReqs)) + for i := range reqs { + reqs[i].Model = genReqs[i].Model + reqs[i].Stream = genReqs[i].Stream + reqs[i].KeepAlive = genReqs[i].KeepAlive + reqs[i].Messages = []api.Message{ + { + Role: "user", + Content: genReqs[i].Prompt, + }, + } + } + return reqs, results +} + +func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message { + stallTimer := time.NewTimer(initialTimeout) + var buf bytes.Buffer + role := "assistant" + fn := func(response api.ChatResponse) error { + // fmt.Print(".") + role = response.Message.Role + buf.Write([]byte(response.Message.Content)) + if !stallTimer.Reset(streamTimeout) { + return errors.New("stall was detected while streaming response, aborting") + } + return nil + } + + stream := true + req.Stream = &stream + done := make(chan int) + var genErr error + go func() { + genErr = client.Chat(ctx, &req, fn) + done <- 0 + }() + + select { + case <-stallTimer.C: + if buf.Len() == 0 { + t.Errorf("generate never started. Timed out after :%s", initialTimeout.String()) + } else { + t.Errorf("generate stalled. 
Response so far:%s", buf.String()) + } + case <-done: + if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") { + slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr) + return nil + } + require.NoError(t, genErr, "failed with %s request Messages %s ", req.Model, req.Messages) + // Verify the response contains the expected data + response := buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages) + slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response) + case <-ctx.Done(): + t.Error("outer test context done while waiting for generate") + } + return &api.Message{Role: role, Content: buf.String()} +} + func skipUnderMinVRAM(t *testing.T, gb uint64) { // TODO use info API in the future if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" { diff --git a/ml/backend.go b/ml/backend.go index 638a05d144..22ac157c7a 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -400,6 +400,8 @@ type Tensor interface { Bytes() []byte Floats() []float32 + BackendSetFromIntSlice(s []int32) + Neg(ctx Context) Tensor Add(ctx Context, t2 Tensor) Tensor Sub(ctx Context, t2 Tensor) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index f8121cc530..98b1d8e0de 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -82,6 +82,7 @@ type Backend struct { // to the name that is used by the model definition tensorLoadTargets map[string][]string + schedMu sync.Mutex // Only one Compute can run at a time sched C.ggml_backend_sched_t schedBackends []C.ggml_backend_t schedBufts []C.ggml_backend_buffer_type_t @@ -769,6 +770,8 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { } func (c *Context) Compute(tensors ...ml.Tensor) { + c.b.schedMu.Lock() + defer c.b.schedMu.Unlock() if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS { panic(fmt.Errorf("error computing ggml graph: %v", status)) } @@ -1037,6 +1040,12 @@ func (t *Tensor) Floats() (data []float32) { return } +func (t *Tensor) BackendSetFromIntSlice(s []int32) { + if len(s) > 0 { + C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t)) + } +} + func (t *Tensor) DType() ml.DType { switch t.t._type { case C.GGML_TYPE_F32: diff --git a/model/model.go b/model/model.go index d0fe26d7e1..9df46a0ad4 100644 --- a/model/model.go +++ b/model/model.go @@ -64,7 +64,7 @@ type MultimodalProcessor interface { // This function is also responsible for updating MultimodalHash for any Multimodal // that is modified to ensure that there is a unique hash value that accurately // represents the contents. 
- PostTokenize([]input.Input) ([]input.Input, error) + PostTokenize([]*input.Input) ([]*input.Input, error) } // Base implements the common fields and methods for all models @@ -278,13 +278,13 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) { +func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) { if len(batch.Positions) != len(batch.Sequences) { - return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) + return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) } if len(batch.Positions) < 1 { - return nil, errors.New("batch size cannot be less than 1") + return nil, nil, errors.New("batch size cannot be less than 1") } batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs)) @@ -293,16 +293,16 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten if cache != nil { err := cache.StartForward(ctx, batch, false) if err != nil { - return nil, err + return nil, nil, err } } t, err := m.Forward(ctx, batch) if err != nil { - return nil, err + return nil, nil, err } - ctx.Forward(t).Compute(t) + ctx.Forward(t) - return t, nil + return batch.Inputs, t, nil } diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 53bf827587..3ff7a3ada9 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return []input.Multimodal{{Tensor: visionOutputs}}, nil } -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input for _, inp := range inputs { if len(inp.Multimodal) == 0 { @@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { inputMultimodal := inp.Multimodal[0].Tensor result = append(result, - input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" - input.Input{Token: 255999}, // """ - input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder + &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n" + &input.Input{Token: 255999}, // """ + &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder ) // add image token placeholders - result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) + result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...) 
result = append(result, - input.Input{Token: 256000}, // - input.Input{Token: 108}, // "\n\n" + &input.Input{Token: 256000}, // + &input.Input{Token: 108}, // "\n\n" ) } } diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 8084760b0c..187f1d6d7b 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -134,16 +134,16 @@ type separator struct { y bool } -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input for _, inp := range inputs { if len(inp.Multimodal) == 0 { result = append(result, inp) continue } - var imageInputs []input.Input - imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|> + var imageInputs []*input.Input + imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|> for i, mm := range inp.Multimodal { patchesPerChunk := mm.Tensor.Dim(1) @@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { if i < len(inp.Multimodal)-1 { separator := mm.Data.(*separator) - imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> - imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) + imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> + imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...) if separator.x { - imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|> + imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|> } if separator.y { - imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|> + imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|> } } else { - imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|> - imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> - imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...) - imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|> + imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|> + imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|> + imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...) 
+ imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|> } } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 9d662fc110..712b3b505c 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] // Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings // that can be processed together. -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input for _, inp := range inputs { if len(inp.Multimodal) == 0 { result = append(result, inp) } else { for i, row := range inp.Multimodal { // [IMG] - result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)}) - result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...) + result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)}) + result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...) if i == len(inp.Multimodal)-1 { // [IMG_END] - result = append(result, input.Input{Token: 13}) + result = append(result, &input.Input{Token: 13}) } else { // [IMG_BREAK] - result = append(result, input.Input{Token: 12}) + result = append(result, &input.Input{Token: 12}) } } } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 45cb3e02c2..033e369814 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return []input.Multimodal{{Tensor: projectedOutputs}}, nil } -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { for i := range inputs { if inputs[i].Multimodal != nil { inputs[i].Token = 128256 // <|image|> diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index ee38cad924..c24e2dccbc 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input } // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass -func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { - var result []input.Input +func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { + var result []*input.Input var ( imageToken int32 = 151655 @@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { return nil, fmt.Errorf("failed to encode image prompt: %w", err) } for i := range pre { - result = append(result, input.Input{Token: pre[i]}) + result = append(result, &input.Input{Token: pre[i]}) } patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1) // First add the vision start token - result = append(result, input.Input{Token: visionStartToken}) + result = append(result, &input.Input{Token: visionStartToken}) // Add the image token with the multimodal tensor data at the first position - result = append(result, input.Input{ + result = 
append(result, &input.Input{ Token: imageToken, Multimodal: inp.Multimodal, MultimodalHash: inp.MultimodalHash, @@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { }) // Add the placeholder tokens for the remaining positions (tokensPerGrid-1) - result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...) + result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...) - result = append(result, input.Input{Token: visionEndToken}) + result = append(result, &input.Input{Token: visionEndToken}) } } diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 8c8a29d85f..e974ba4d7b 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -86,7 +86,7 @@ type InputCacheSlot struct { Id int // Inputs that are stored in the KV cache - Inputs []input.Input + Inputs []*input.Input // is this cache actively being processed as part of a sequence? InUse bool @@ -95,7 +95,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp return slot, prompt, nil } -func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) { longest := int32(-1) var longestSlot *InputCacheSlot @@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot return longestSlot, longest, nil } -func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) { oldest := time.Now() var oldestSlot *InputCacheSlot @@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i if longest > 0 && longestSlot != oldestSlot { slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", len(longestSlot.Inputs)) - oldestSlot.Inputs = make([]input.Input, longest) + oldestSlot.Inputs = make([]*input.Input, longest) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) if c.cache != nil { c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) @@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i return oldestSlot, longest, nil } -func countCommonPrefix(a []input.Input, b []input.Input) int32 { +func countCommonPrefix(a []*input.Input, b []*input.Input) int32 { var count int32 for i := range a { @@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { } type ErrReprocessInputs struct { - Inputs []input.Input + Inputs []*input.Input } func (e *ErrReprocessInputs) Error() string { @@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { "id", slot.Id, "error", err) // Create new input slice with preserved tokens (numKeep + remaining tokens after discard) - newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard)) + newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard)) copy(newInputs[:numKeep], 
slot.Inputs[:numKeep]) copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:]) // Reset the cache _ = c.cache.Remove(slot.Id, 0, math.MaxInt32) - slot.Inputs = []input.Input{} + slot.Inputs = []*input.Input{} // Return error with inputs that need to be reprocessed return &ErrReprocessInputs{Inputs: newInputs} diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 6897b5e463..49cb6c5474 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -13,50 +13,50 @@ import ( func TestCountCommon(t *testing.T) { tests := []struct { name string - t1 []input.Input - t2 []input.Input + t1 []*input.Input + t2 []*input.Input expected int32 }{ { name: "Equal", - t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, - t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 3, }, { name: "Prefix", - t1: []input.Input{{Token: 1}}, - t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []*input.Input{{Token: 1}}, + t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 1, }, { name: "Image Prefix", - t1: []input.Input{{MultimodalHash: 1}}, - t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}}, + t1: []*input.Input{{MultimodalHash: 1}}, + t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}}, expected: 1, }, { name: "Mixed", - t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, - t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}}, + t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}}, + t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}}, expected: 2, }, { name: "Mixed, Same Length", - t1: []input.Input{{Token: 1}, {MultimodalHash: 1}}, - t2: []input.Input{{Token: 1}, {MultimodalHash: 2}}, + t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}}, + t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}}, expected: 1, }, { name: "Empty", - t1: []input.Input{}, - t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []*input.Input{}, + t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 0, }, { name: "Both Empty", - t1: []input.Input{}, - t2: []input.Input{}, + t1: []*input.Input{}, + t2: []*input.Input{}, expected: 0, }, } @@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) { tests := []struct { name string cache InputCache - prompt []input.Input + prompt []*input.Input longest expected best expected }{ @@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []input.Input{{Token: 1}}, + prompt: []*input.Input{{Token: 1}}, longest: expected{result: 0, len: 0}, best: expected{result: 0, len: 0}, }, @@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}}, + Inputs: []*input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input.Input{{Token: 1}, {Token: 2}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 2}, 
best: expected{result: 1, len: 2}, }, @@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []input.Input{{Token: 2}}, + prompt: []*input.Input{{Token: 2}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Time{}, }, }, }, - prompt: []input.Input{{Token: 1}}, + prompt: []*input.Input{{Token: 1}}, longest: expected{result: 0, len: 1}, best: expected{result: 1, len: 1}, }, @@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}}, + Inputs: []*input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input.Input{{Token: 2}, {Token: 3}}, + prompt: []*input.Input{{Token: 2}, {Token: 3}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: true, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{{Token: 1}}, + Inputs: []*input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []input.Input{{Token: 1}, {Token: 2}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 1}, best: expected{result: 1, len: 2}, }, @@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) { tests := []struct { name string cache InputCache - prompt []input.Input + prompt []*input.Input wantErr bool expectedSlotId int expectedPrompt int // expected length of remaining prompt @@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, wantErr: false, expectedSlotId: 0, expectedPrompt: 1, // Only token 3 remains @@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []input.Input{}, + Inputs: []*input.Input{}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + prompt: 
[]*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, wantErr: false, expectedSlotId: 0, expectedPrompt: 1, // Only token 3 remains @@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}}, wantErr: false, expectedSlotId: 0, expectedPrompt: 1, // Should leave 1 token for sampling @@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []input.Input{{Token: 1}, {Token: 2}}, + Inputs: []*input.Input{{Token: 1}, {Token: 2}}, InUse: true, lastUsed: time.Now().Add(-time.Second), }, }, }, - prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, wantErr: true, expectedSlotId: -1, expectedPrompt: -1, @@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) { tests := []struct { name string numCtx int32 - inputs []input.Input + inputs []*input.Input numKeep int32 cacheErr bool wantErr any @@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) { { name: "Normal shift", numCtx: 10, - inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, numKeep: 2, cacheErr: false, // No error wantErr: nil, @@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) { { name: "Cache removal fails", numCtx: 10, - inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, + inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}}, numKeep: 2, cacheErr: true, wantErr: &ErrReprocessInputs{}, @@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) { } slot := &InputCacheSlot{ Id: 123, - Inputs: make([]input.Input, len(tt.inputs)), + Inputs: make([]*input.Input, len(tt.inputs)), } copy(slot.Inputs, tt.inputs) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 2f41f68f22..25f0633b03 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -17,6 +17,7 @@ import ( "reflect" "regexp" "runtime" + "runtime/debug" "strconv" "strings" "sync" @@ -51,10 +52,10 @@ type Sequence struct { iBatch int // prompt inputs left to evaluate - inputs []input.Input + inputs []*input.Input // inputs that have been added to a batch but not yet submitted to Forward - pendingInputs []input.Input + pendingInputs []*input.Input // tokens that have been generated but not returned yet (e.g. 
for stop sequences) pendingResponses []string @@ -86,6 +87,12 @@ type Sequence struct { // true if an embedding are to be returned instead of text generation embeddingOnly bool + // true if the sequence is finished and marked for removal on the next pass + finished bool + + // True if we have to skip this sequence to shift the cache + skipForShift bool + doneReason llm.DoneReason // Metrics @@ -182,8 +189,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) { - var inputs []input.Input +func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) { + var inputs []*input.Input var ctxs []ml.Context var mmStore multimodalStore @@ -210,7 +217,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ } for _, t := range tokens { - inputs = append(inputs, input.Input{Token: t}) + inputs = append(inputs, &input.Input{Token: t}) } // image - decode and store @@ -243,7 +250,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ mmStore.addMultimodal(imageEmbeddings) - inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) + inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) postTokenize = true } } @@ -259,6 +266,27 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ return inputs, ctxs, mmStore, nil } +type batchState struct { + id int + ctx ml.Context + modelInput ml.Tensor + modelOutput ml.Tensor + batchInputs []*input.Input + batch input.Batch + seqs []*Sequence // full set of seqs at the time this batch was initiated + initSeqIdx int // The initial value for the set of sequences evaluated (s.nextSeq - 1) + + // Signaled when this batch's inputs are ready and compute can proceed + inputsReadyCh chan struct{} + + // Signaled when Compute is about to begin on this batch, and + // seqs have been updated to prepare for the next batch + computeStartedCh chan struct{} + + // Signaled when this batch's outputs are complete and the next batch can proceed + outputsReadyCh chan struct{} +} + type Server struct { // modelPath is the location of the model to be loaded modelPath string @@ -290,6 +318,16 @@ type Server struct { // TODO (jmorganca): make this n_batch batchSize int + // Used to signal a hard failure during async processing which will panic the runner + hardErrCh chan error + + // A prior batch that's still being processed + // only read or written by forwardBatch + pendingBatch *batchState + + // Simple counter used only for trace logging batches + batchID int + // protects access to everything below this line // this is context state needed for decoding mu sync.Mutex @@ -350,45 +388,132 @@ func flushPending(seq *Sequence) bool { } } -func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { +func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) { seq := s.seqs[seqIndex] + // finish could be called multiple times since we prepare 1 batch ahead + // and multiple scenarios can lead to finishing a sequence + // ensure only the first finish call is processed + if seq.finished { + return + } + flushPending(seq) seq.doneReason = reason + seq.finished = true
close(seq.responses) close(seq.embedding) seq.cache.InUse = false +} + +func (s *Server) removeFinishedSequence(seqIndex int) { s.seqs[seqIndex] = nil s.seqsSem.Release(1) } +// track batch state between forwardBatch and computeBatch + func (s *Server) run(ctx context.Context) { s.ready.Wait() + var bs *batchState for { select { case <-ctx.Done(): return + case err := <-s.hardErrCh: + panic(err) default: - err := s.processBatch() + var err error + bs, err = s.forwardBatch() if err != nil { panic(err) } + if bs == nil { + continue + } + go s.computeBatch(bs) } } } -func (s *Server) processBatch() error { +// forwardBatch prepares the next batch and builds its forward pass graph, which computeBatch later runs asynchronously. +func (s *Server) forwardBatch() (*batchState, error) { + inputsReady := false + var inputsReadyCh chan struct{} + + // If we have a pending batch still processing, wait until Compute has started + // before setting up the next batch so the seqs' inputs are ready to receive their + // token values and we get the correct input pointers for the batchInputs + if s.pendingBatch != nil { + slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id) + <-s.pendingBatch.computeStartedCh + slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID) + inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next batch's inputs + } else { + slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID) + inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately + inputsReadyCh = make(chan struct{}, 1) + } + s.mu.Lock() for s.allNil() { s.cond.Wait() // Wait until an item is added } defer s.mu.Unlock() - ctx := s.model.Backend().NewContext() - defer ctx.Close() + // If new sequences have been added with an active batch we delay preparing the next batch + // until Compute has finished + if s.pendingBatch != nil { + for seqIdx := range s.seqs { + if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] { + slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id) + s.mu.Unlock() // release the lock so computeBatch can finish up + <-s.pendingBatch.outputsReadyCh + slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id) + s.mu.Lock() + inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs + break + } + } + } + // Clear pendingBatch - we'll set it again if we build a batch with any inputs + s.pendingBatch = nil - var batchInputs []int32 + // Remove any finished sequences before recording the active set of seqs in the batch + for seqIdx := range s.seqs { + seq := s.seqs[seqIdx] + if seq == nil { + continue + } + if seq.finished { + s.removeFinishedSequence(seqIdx) + continue + } + if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { + s.finishSequence(seqIdx, llm.DoneReasonLength) + s.removeFinishedSequence(seqIdx) + continue + } + } + + // next batch + nb := &batchState{ + id: s.batchID, + initSeqIdx: s.nextSeq - 1, + seqs: make([]*Sequence, len(s.seqs)), + inputsReadyCh: inputsReadyCh, + computeStartedCh: make(chan struct{}, 1), + outputsReadyCh: make(chan struct{}, 1), + } + ctx := s.model.Backend().NewContext() + nb.ctx = ctx + + // Record the sequences at the time we
create the batch so we can detect if new sequences are added on the next pass + copy(nb.seqs, s.seqs) + + // Prepare the seqs and batch, but defer the input token values as we may not be ready yet + var batchInputs []*input.Input var batch input.Batch resumeSeq := -1 @@ -396,20 +521,13 @@ func (s *Server) processBatch() error { for range s.seqs { seqIdx = (seqIdx + 1) % len(s.seqs) seq := s.seqs[seqIdx] - if seq == nil { continue } - // if past the num predict limit - if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, llm.DoneReasonLength) - continue - } - if !s.cache.enabled { seq.inputs = append(seq.cache.Inputs, seq.inputs...) - seq.cache.Inputs = []input.Input{} + seq.cache.Inputs = []*input.Input{} } batchSize := s.batchSize @@ -449,18 +567,21 @@ func (s *Server) processBatch() error { // Prepend these inputs to the sequence's inputs queue for reprocessing seq.inputs = append(reprocess.Inputs, seq.inputs...) // Skip this sequence but continue processing the rest + seq.skipForShift = true // cleared in computeBatch below for the next batch continue } else { - return err + ctx.Close() + return nil, err } } } - batchInputs = append(batchInputs, inp.Token) + batchInputs = append(batchInputs, seq.inputs[i]) if inp.Multimodal != nil { mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false) if err != nil { - return err + ctx.Close() + return nil, err } batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm}) } @@ -468,10 +589,13 @@ func (s *Server) processBatch() error { batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batch.Sequences = append(batch.Sequences, seq.cache.Id) + // TODO BUG HERE!!! 
+ // Somehow sometimes iBatch isn't set correctly seq.iBatch = len(batch.Outputs) if i+1 == len(seq.inputs) { batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) } + slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs)) seq.pendingInputs = append(seq.pendingInputs, inp) } @@ -485,36 +609,138 @@ func (s *Server) processBatch() error { } if len(batchInputs) == 0 { - return nil + slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID) + ctx.Close() + return nil, nil } + s.batchID++ - modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch) + var err error + // Actual batchInputs values will be injected into the modelInput tensor before calling Compute + nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch) if err != nil { - return fmt.Errorf("failed to decode batch: %w", err) + ctx.Close() + return nil, fmt.Errorf("failed to build graph: %w", err) + } + nb.batchInputs = batchInputs + nb.batch = batch + + // computeBatch will close the context in the batch upon completion + s.pendingBatch = nb + + if inputsReady { + nb.inputsReadyCh <- struct{}{} } - logits := modelOutput.Floats() + return nb, nil +} +// Async processing of the next batch +func (s *Server) computeBatch(bs *batchState) { + if bs == nil || bs.ctx == nil { + // Nothing to compute + return + } + defer bs.ctx.Close() + + // Wait until inputs are ready + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id) + <-bs.inputsReadyCh + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id) + + // Once we complete, signal the next batch of inputs are ready + // This will unblock the next computeBatch, or forwardBatch if new seqs come in + defer func() { + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id) + bs.outputsReadyCh <- struct{}{} + }() + + s.mu.Lock() + + // Gather the actual input token values now that they're ready + batchInputs := make([]int32, len(bs.batchInputs)) + for i := range batchInputs { + batchInputs[i] = bs.batchInputs[i].Token + } + + // TODO the following logic could be run in a go routine to possibly speed up getting to Compute + + // Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens + // so that forwardBatch can build a batchInputs set which will eventually contain the actual + // decoded tokens. + promptProcessing := make([]bool, len(s.seqs)) // track seq's we skip + nextBatchTokens := make([]*input.Input, len(s.seqs)) + iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock for i, seq := range s.seqs { + iBatches[i] = -1 if seq == nil { continue } + // Skip over any newly added sequences + if bs.seqs[i] == nil { + continue + } // After calling Forward, pending inputs are now in the cache if len(seq.pendingInputs) > 0 { seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) 
- seq.pendingInputs = []input.Input{} + seq.pendingInputs = []*input.Input{} } // don't sample prompt processing if len(seq.inputs) != 0 { if !s.cache.enabled { - return errors.New("caching disabled but unable to fit entire input in a batch") + s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch") + return } + // Record so we can skip during Decode + promptProcessing[i] = true continue } seq.numPredicted++ + nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats + seq.inputs = []*input.Input{nextToken} + nextBatchTokens[i] = nextToken + iBatches[i] = seq.iBatch + } + + // At this point the seqs are ready for forwardBatch to move forward so unblock + s.mu.Unlock() + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id) + bs.computeStartedCh <- struct{}{} + + bs.modelInput.BackendSetFromIntSlice(batchInputs) + bs.ctx.Compute(bs.modelOutput) + logits := bs.modelOutput.Floats() + + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id) + + s.mu.Lock() + defer s.mu.Unlock() + + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id) + for i, seq := range s.seqs { + if seq == nil { + continue + } + // Skip over any newly added sequences + if bs.seqs[i] == nil { + continue + } + + // Detect if the sequence we're processing has already been completed and replaced + // with a new sequence + if seq != bs.seqs[i] { + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i) + continue + } + + // don't sample prompt processing + if promptProcessing[i] { + continue + } + if seq.numPredicted == 1 { seq.startGenerationTime = time.Now() } @@ -522,35 +748,46 @@ func (s *Server) processBatch() error { // if done processing the prompt, generate an embedding and return if seq.embeddingOnly { // TODO(jessegross): Embedding support - slog.Warn("generation of embedding outputs not yet supported") - s.removeSequence(i, llm.DoneReasonStop) + slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i) + s.finishSequence(i, llm.DoneReasonStop) continue } // sample a token - vocabSize := len(logits) / len(batch.Outputs) - - token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) + vocabSize := len(logits) / len(bs.batch.Outputs) + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch) + token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) if err != nil { - return fmt.Errorf("failed to sample token: %w", err) + s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) + return } + nextBatchTokens[i].Token = token + // if it's an end of sequence token, break if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { // TODO (jmorganca): we should send this back // as it's important for the /api/generate context // seq.responses <- piece - - s.removeSequence(i, llm.DoneReasonStop) + slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i) + s.finishSequence(i, llm.DoneReasonStop) continue } piece, err := s.model.(model.TextProcessor).Decode([]int32{token}) if err != nil { - return err + s.hardErrCh <- fmt.Errorf("failed to decode token: 
%w", err) + return } - seq.inputs = []input.Input{{Token: token}} + if nextBatchTokens[i] == nil { + slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq) + s.hardErrCh <- fmt.Errorf("expected a single token during decode") + return + } + + // fill in the final selected token value to replace the placeholder in the next batch + // nextBatchTokensWritten++ seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "") @@ -575,9 +812,10 @@ func (s *Server) processBatch() error { if tokenTruncated || origLen == newLen { tokenLen-- } + seq.cache.Inputs = seq.cache.Inputs[:tokenLen] - s.removeSequence(i, llm.DoneReasonStop) + s.finishSequence(i, llm.DoneReasonStop) continue } @@ -590,11 +828,9 @@ func (s *Server) processBatch() error { } if !flushPending(seq) { - s.removeSequence(i, llm.DoneReasonConnectionClosed) + s.finishSequence(i, llm.DoneReasonConnectionClosed) } } - - return nil } func (s *Server) completion(w http.ResponseWriter, r *http.Request) { @@ -736,7 +972,10 @@ func (s *Server) reserveWorstCaseGraph() error { defer ctx.Close() var err error - inputs := make([]input.Input, s.batchSize) + inputs := make([]*input.Input, s.batchSize) + for i := range inputs { + inputs[i] = &input.Input{} + } mmStore := newMultimodalStore() // Multimodal strategy: @@ -778,8 +1017,11 @@ func (s *Server) reserveWorstCaseGraph() error { } if len(inputs) < s.batchSize { - newInputs := make([]input.Input, s.batchSize) + newInputs := make([]*input.Input, s.batchSize) copy(newInputs, inputs) + for i := len(inputs); i < s.batchSize; i++ { + newInputs[i] = &input.Input{} + } inputs = newInputs } } @@ -842,6 +1084,7 @@ func (s *Server) allocModel( // Convert memory allocation panics to errors defer func() { if r := recover(); r != nil { + debug.PrintStack() if err, ok := r.(error); ok { panicErr = err } else { @@ -1011,6 +1254,7 @@ func Execute(args []string) error { server := &Server{ modelPath: *mpath, status: llm.ServerStatusLaunched, + hardErrCh: make(chan error, 1), } server.cond = sync.NewCond(&server.mu)