perf: build graph for next batch in parallel to keep GPU busy

This refactors the main run loop of the ollama runner to perform the main GPU
intensive tasks (Compute+Floats) in a go routine so we can prepare the next
batch in parallel to reduce the amount of time the GPU stalls waiting for the
next batch of work.
This commit is contained in:
Daniel Hiltgen
2025-08-13 12:32:25 -07:00
parent f804e8a460
commit 31f64183dc
15 changed files with 531 additions and 150 deletions

View File

@@ -2,10 +2,13 @@
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
The integration tests have 2 modes of operating.
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
> [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.

View File

@@ -66,7 +66,7 @@ func TestContextExhaustion(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
}
// Send multiple requests with prior context and ensure the response is coherant and expected
// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := GenerateRequests()
@@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
}(i)
}
wg.Wait()
}
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := ChatRequests()
numParallel := 2
iterLimit := 2
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial empty request
slog.Info("loading", "model", modelOverride)
err := client.Generate(ctx,
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", modelOverride, err)
}
var wg sync.WaitGroup
wg.Add(numParallel)
for i := range numParallel {
go func(i int) {
defer wg.Done()
k := i % len(req)
req[k].Model = modelOverride
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}
req[k].Messages = append(req[k].Messages,
*assistant,
api.Message{Role: "user", Content: "tell me more!"},
)
}
}(i)
}
wg.Wait()
}

View File

@@ -19,6 +19,8 @@ import (
)
func TestMaxQueue(t *testing.T) {
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
return

View File

@@ -567,6 +567,76 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
}
}
func ChatRequests() ([]api.ChatRequest, [][]string) {
genReqs, results := GenerateRequests()
reqs := make([]api.ChatRequest, len(genReqs))
for i := range reqs {
reqs[i].Model = genReqs[i].Model
reqs[i].Stream = genReqs[i].Stream
reqs[i].KeepAlive = genReqs[i].KeepAlive
reqs[i].Messages = []api.Message{
{
Role: "user",
Content: genReqs[i].Prompt,
},
}
}
return reqs, results
}
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
role := "assistant"
fn := func(response api.ChatResponse) error {
// fmt.Print(".")
role = response.Message.Role
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return errors.New("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
return nil
}
require.NoError(t, genErr, "failed with %s request Messages %s ", req.Model, req.Messages)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
return &api.Message{Role: role, Content: buf.String()}
}
func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {

View File

@@ -400,6 +400,8 @@ type Tensor interface {
Bytes() []byte
Floats() []float32
BackendSetFromIntSlice(s []int32)
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor

View File

@@ -82,6 +82,7 @@ type Backend struct {
// to the name that is used by the model definition
tensorLoadTargets map[string][]string
schedMu sync.Mutex // Only one Compute can run at a time
sched C.ggml_backend_sched_t
schedBackends []C.ggml_backend_t
schedBufts []C.ggml_backend_buffer_type_t
@@ -769,6 +770,8 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
}
func (c *Context) Compute(tensors ...ml.Tensor) {
c.b.schedMu.Lock()
defer c.b.schedMu.Unlock()
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status))
}
@@ -1037,6 +1040,12 @@ func (t *Tensor) Floats() (data []float32) {
return
}
func (t *Tensor) BackendSetFromIntSlice(s []int32) {
if len(s) > 0 {
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
}
}
func (t *Tensor) DType() ml.DType {
switch t.t._type {
case C.GGML_TYPE_F32:

View File

@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]input.Input) ([]input.Input, error)
PostTokenize([]*input.Input) ([]*input.Input, error)
}
// Base implements the common fields and methods for all models
@@ -278,13 +278,13 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
return nil, nil, errors.New("batch size cannot be less than 1")
}
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
@@ -293,16 +293,16 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
if cache != nil {
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, err
return nil, nil, err
}
}
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
return nil, nil, err
}
ctx.Forward(t).Compute(t)
ctx.Forward(t)
return t, nil
return batch.Inputs, t, nil
}

View File

@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
inputMultimodal := inp.Multimodal[0].Tensor
result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
&input.Input{Token: 255999}, // "<start_of_image>""
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
input.Input{Token: 256000}, // <end_of_image>
input.Input{Token: 108}, // "\n\n"
&input.Input{Token: 256000}, // <end_of_image>
&input.Input{Token: 108}, // "\n\n"
)
}
}

View File

@@ -134,16 +134,16 @@ type separator struct {
y bool
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
continue
}
var imageInputs []input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
var imageInputs []*input.Input
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
for i, mm := range inp.Multimodal {
patchesPerChunk := mm.Tensor.Dim(1)
@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
if i < len(inp.Multimodal)-1 {
separator := mm.Data.(*separator)
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if separator.x {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
}
if separator.y {
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
}
} else {
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
}
}

View File

@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
for i, row := range inp.Multimodal {
// [IMG]
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
if i == len(inp.Multimodal)-1 {
// [IMG_END]
result = append(result, input.Input{Token: 13})
result = append(result, &input.Input{Token: 13})
} else {
// [IMG_BREAK]
result = append(result, input.Input{Token: 12})
result = append(result, &input.Input{Token: 12})
}
}
}

View File

@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
for i := range inputs {
if inputs[i].Multimodal != nil {
inputs[i].Token = 128256 // <|image|>

View File

@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
}
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
var (
imageToken int32 = 151655
@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
}
for i := range pre {
result = append(result, input.Input{Token: pre[i]})
result = append(result, &input.Input{Token: pre[i]})
}
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
// First add the vision start token
result = append(result, input.Input{Token: visionStartToken})
result = append(result, &input.Input{Token: visionStartToken})
// Add the image token with the multimodal tensor data at the first position
result = append(result, input.Input{
result = append(result, &input.Input{
Token: imageToken,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
})
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, input.Input{Token: visionEndToken})
result = append(result, &input.Input{Token: visionEndToken})
}
}

View File

@@ -86,7 +86,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
Inputs []input.Input
Inputs []*input.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@@ -95,7 +95,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input.Input, longest)
oldestSlot.Inputs = make([]*input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
return oldestSlot, longest, nil
}
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
var count int32
for i := range a {
@@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
}
type ErrReprocessInputs struct {
Inputs []input.Input
Inputs []*input.Input
}
func (e *ErrReprocessInputs) Error() string {
@@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
"id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
slot.Inputs = []input.Input{}
slot.Inputs = []*input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}

View File

@@ -13,50 +13,50 @@ import (
func TestCountCommon(t *testing.T) {
tests := []struct {
name string
t1 []input.Input
t2 []input.Input
t1 []*input.Input
t2 []*input.Input
expected int32
}{
{
name: "Equal",
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []input.Input{{Token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []*input.Input{{Token: 1}},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []input.Input{{MultimodalHash: 1}},
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
t1: []*input.Input{{MultimodalHash: 1}},
t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
expected: 1,
},
{
name: "Empty",
t1: []input.Input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []*input.Input{},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []input.Input{},
t2: []input.Input{},
t1: []*input.Input{},
t2: []*input.Input{},
expected: 0,
},
}
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
prompt []*input.Input
longest expected
best expected
}{
@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input.Input{{Token: 1}},
prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}},
Inputs: []*input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []input.Input{{Token: 2}},
prompt: []*input.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []input.Input{{Token: 1}},
prompt: []*input.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}},
Inputs: []*input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input.Input{{Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{{Token: 1}},
Inputs: []*input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
prompt []*input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
Inputs: []*input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []input.Input
inputs []*input.Input
numKeep int32
cacheErr bool
wantErr any
@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Normal shift",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Cache removal fails",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]input.Input, len(tt.inputs)),
Inputs: make([]*input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)

View File

@@ -17,6 +17,7 @@ import (
"reflect"
"regexp"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
@@ -51,10 +52,10 @@ type Sequence struct {
iBatch int
// prompt inputs left to evaluate
inputs []input.Input
inputs []*input.Input
// inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []input.Input
pendingInputs []*input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -86,6 +87,12 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
// true if the sequence if finished and marked for removal on next pass
finished bool
// True if we have to skip this sequence to shift the cache
skipForShift bool
doneReason llm.DoneReason
// Metrics
@@ -182,8 +189,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
var inputs []input.Input
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
var inputs []*input.Input
var ctxs []ml.Context
var mmStore multimodalStore
@@ -210,7 +217,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
}
for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t})
inputs = append(inputs, &input.Input{Token: t})
}
// image - decode and store
@@ -243,7 +250,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
mmStore.addMultimodal(imageEmbeddings)
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
@@ -259,6 +266,27 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
return inputs, ctxs, mmStore, nil
}
type batchState struct {
id int
ctx ml.Context
modelInput ml.Tensor
modelOutput ml.Tensor
batchInputs []*input.Input
batch input.Batch
seqs []*Sequence // full set of seqs at the time this batch was initiated
initSeqIdx int // The initial value for the set of sequences evaluated (s.nextSeq - 1)
// Signaled when this batches inputs are ready and compute can proceed
inputsReadyCh chan struct{}
// Signaling when Compute is about to begin on this batch, and
// seqs have been updated to prepare for the next batch
computeStartedCh chan struct{}
// Signaled when this batches outputs are complete and the next batch can proceed
outputsReadyCh chan struct{}
}
type Server struct {
// modelPath is the location of the model to be loaded
modelPath string
@@ -290,6 +318,16 @@ type Server struct {
// TODO (jmorganca): make this n_batch
batchSize int
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// A prior batch that's still being processed
// only read or written by forwardBatch
pendingBatch *batchState
// Simple counter used only for trace logging batches
batchID int
// protects access to everything below this line
// this is context state needed for decoding
mu sync.Mutex
@@ -350,45 +388,132 @@ func flushPending(seq *Sequence) bool {
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
// finish could be called multiple times since we prepare 1 batch ahead
// and multiple scenarios can lead to finishing a sequence
// ensure only the first finish called is processed
if seq.finished {
return
}
flushPending(seq)
seq.doneReason = reason
seq.finished = true
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
}
func (s *Server) removeFinishedSequence(seqIndex int) {
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
// track batch state between forwardBatch, computeBatch and predictForwardBatch
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
var bs *batchState
for {
select {
case <-ctx.Done():
return
case err := <-s.hardErrCh:
panic(err)
default:
err := s.processBatch()
var err error
bs, err = s.forwardBatch()
if err != nil {
panic(err)
}
if bs == nil {
continue
}
go s.computeBatch(bs)
}
}
}
func (s *Server) processBatch() error {
// forwardBatch will calculate a batch.
func (s *Server) forwardBatch() (*batchState, error) {
inputsReady := false
var inputsReadyCh chan struct{}
// If we have a pending batch still processing, wait until Compute has started
// before setting up the next batch so the seqs inputs are ready to receive their
// token values and we get the correct input pointers for the batchInputs
if s.pendingBatch != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id)
<-s.pendingBatch.computeStartedCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID)
inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately
inputsReadyCh = make(chan struct{}, 1)
}
s.mu.Lock()
for s.allNil() {
s.cond.Wait() // Wait until an item is added
}
defer s.mu.Unlock()
ctx := s.model.Backend().NewContext()
defer ctx.Close()
// If new sequences have been added with an active batch we delay preparing the next batch
// until Compute has finished
if s.pendingBatch != nil {
for seqIdx := range s.seqs {
if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id)
s.mu.Unlock() // release the lock so computeBatch can finish up
<-s.pendingBatch.outputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id)
s.mu.Lock()
inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs
break
}
}
}
// Clear pending Batch - we'll set it if we have a batch with any inputs
s.pendingBatch = nil
var batchInputs []int32
// Remove any finished sequences before recording the active set of seqs in the batch
for seqIdx := range s.seqs {
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
if seq.finished {
s.removeFinishedSequence(seqIdx)
continue
}
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.finishSequence(seqIdx, llm.DoneReasonLength)
s.removeFinishedSequence(seqIdx)
continue
}
}
// next batch
nb := &batchState{
id: s.batchID,
initSeqIdx: s.nextSeq - 1,
seqs: make([]*Sequence, len(s.seqs)),
inputsReadyCh: inputsReadyCh,
computeStartedCh: make(chan struct{}, 1),
outputsReadyCh: make(chan struct{}, 1),
}
ctx := s.model.Backend().NewContext()
nb.ctx = ctx
// Record the sequences at the time we create the batch so we can detect if new sequences are added on the next pass
copy(nb.seqs, s.seqs)
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
var batch input.Batch
resumeSeq := -1
@@ -396,20 +521,13 @@ func (s *Server) processBatch() error {
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []input.Input{}
seq.cache.Inputs = []*input.Input{}
}
batchSize := s.batchSize
@@ -449,18 +567,21 @@ func (s *Server) processBatch() error {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
seq.skipForShift = true // cleared in computeBatch below for the next batch
continue
} else {
return err
ctx.Close()
return nil, err
}
}
}
batchInputs = append(batchInputs, inp.Token)
batchInputs = append(batchInputs, seq.inputs[i])
if inp.Multimodal != nil {
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
if err != nil {
return err
ctx.Close()
return nil, err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
}
@@ -468,10 +589,13 @@ func (s *Server) processBatch() error {
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
// TODO BUG HERE!!!
// Somehow sometimes iBatch isn't set correctly
seq.iBatch = len(batch.Outputs)
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -485,36 +609,138 @@ func (s *Server) processBatch() error {
}
if len(batchInputs) == 0 {
return nil
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
ctx.Close()
return nil, nil
}
s.batchID++
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
var err error
// Actual batchInputs values will be injected into the modelInput tensor before calling Compute
nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
ctx.Close()
return nil, fmt.Errorf("failed to build graph: %w", err)
}
nb.batchInputs = batchInputs
nb.batch = batch
// computeBatch will close the context in the batch upon completion
s.pendingBatch = nb
if inputsReady {
nb.inputsReadyCh <- struct{}{}
}
logits := modelOutput.Floats()
return nb, nil
}
// Async processing of the next batch
func (s *Server) computeBatch(bs *batchState) {
if bs == nil || bs.ctx == nil {
// Nothing to compute
return
}
defer bs.ctx.Close()
// Wait until inputs are ready
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id)
<-bs.inputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id)
// Once we complete, signal the next batch of inputs are ready
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id)
bs.outputsReadyCh <- struct{}{}
}()
s.mu.Lock()
// Gather the actual input token values now that they're ready
batchInputs := make([]int32, len(bs.batchInputs))
for i := range batchInputs {
batchInputs[i] = bs.batchInputs[i].Token
}
// TODO the following logic could be run in a go routine to possibly speed up getting to Compute
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
// decoded tokens.
promptProcessing := make([]bool, len(s.seqs)) // track seq's we skip
nextBatchTokens := make([]*input.Input, len(s.seqs))
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
for i, seq := range s.seqs {
iBatches[i] = -1
if seq == nil {
continue
}
// Skip over any newly added sequences
if bs.seqs[i] == nil {
continue
}
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input.Input{}
seq.pendingInputs = []*input.Input{}
}
// don't sample prompt processing
if len(seq.inputs) != 0 {
if !s.cache.enabled {
return errors.New("caching disabled but unable to fit entire input in a batch")
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
return
}
// Record so we can skip during Decode
promptProcessing[i] = true
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
iBatches[i] = seq.iBatch
}
// At this point the seqs are ready for forwardBatch to move forward so unblock
s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id)
bs.computeStartedCh <- struct{}{}
bs.modelInput.BackendSetFromIntSlice(batchInputs)
bs.ctx.Compute(bs.modelOutput)
logits := bs.modelOutput.Floats()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id)
s.mu.Lock()
defer s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id)
for i, seq := range s.seqs {
if seq == nil {
continue
}
// Skip over any newly added sequences
if bs.seqs[i] == nil {
continue
}
// Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != bs.seqs[i] {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i)
continue
}
// don't sample prompt processing
if promptProcessing[i] {
continue
}
if seq.numPredicted == 1 {
seq.startGenerationTime = time.Now()
}
@@ -522,35 +748,46 @@ func (s *Server) processBatch() error {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, llm.DoneReasonStop)
slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i)
s.finishSequence(i, llm.DoneReasonStop)
continue
}
// sample a token
vocabSize := len(logits) / len(batch.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
vocabSize := len(logits) / len(bs.batch.Outputs)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch)
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
}
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
s.removeSequence(i, llm.DoneReasonStop)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i)
s.finishSequence(i, llm.DoneReasonStop)
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
return err
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
return
}
seq.inputs = []input.Input{{Token: token}}
if nextBatchTokens[i] == nil {
slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq)
s.hardErrCh <- fmt.Errorf("expected a single token during decode")
return
}
// fill in the final selected token value to replace the placeholder in the next batch
// nextBatchTokensWritten++
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
@@ -575,9 +812,10 @@ func (s *Server) processBatch() error {
if tokenTruncated || origLen == newLen {
tokenLen--
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, llm.DoneReasonStop)
s.finishSequence(i, llm.DoneReasonStop)
continue
}
@@ -590,11 +828,9 @@ func (s *Server) processBatch() error {
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
s.finishSequence(i, llm.DoneReasonConnectionClosed)
}
}
return nil
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -736,7 +972,10 @@ func (s *Server) reserveWorstCaseGraph() error {
defer ctx.Close()
var err error
inputs := make([]input.Input, s.batchSize)
inputs := make([]*input.Input, s.batchSize)
for i := range inputs {
inputs[i] = &input.Input{}
}
mmStore := newMultimodalStore()
// Multimodal strategy:
@@ -778,8 +1017,11 @@ func (s *Server) reserveWorstCaseGraph() error {
}
if len(inputs) < s.batchSize {
newInputs := make([]input.Input, s.batchSize)
newInputs := make([]*input.Input, s.batchSize)
copy(newInputs, inputs)
for i := len(inputs); i < s.batchSize; i++ {
newInputs[i] = &input.Input{}
}
inputs = newInputs
}
}
@@ -842,6 +1084,7 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors
defer func() {
if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok {
panicErr = err
} else {
@@ -1011,6 +1254,7 @@ func Execute(args []string) error {
server := &Server{
modelPath: *mpath,
status: llm.ServerStatusLaunched,
hardErrCh: make(chan error, 1),
}
server.cond = sync.NewCond(&server.mu)