perf: build graph for next batch in parallel to keep GPU busy
This refactors the ollama runner's main run loop so that the GPU-intensive work (Compute + Floats) runs in a goroutine, letting the runner prepare the next batch in parallel and reducing the time the GPU stalls while waiting for its next batch of work.
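The hand-off between graph construction and GPU execution is coordinated with three channels per batch (`inputsReadyCh`, `computeStartedCh`, `outputsReadyCh`; see `batchState` in the diff below). The following is a minimal, runnable sketch of that pipelining pattern using toy stand-ins, not the runner's actual types:

```go
package main

import "fmt"

// batch mirrors only the signaling fields of batchState in the diff below;
// everything else here is a toy stand-in.
type batch struct {
	id             int
	inputsReady    chan struct{} // sampled tokens from the prior batch are in place
	computeStarted chan struct{} // compute began; safe to build the next graph
	outputsReady   chan struct{} // logits done; chained to the next batch's inputsReady
}

// compute stands in for computeBatch: wait for inputs, signal start, "run" the
// GPU work, then signal outputs so the next batch can proceed.
func compute(b *batch) {
	<-b.inputsReady
	b.computeStarted <- struct{}{}
	fmt.Println("GPU compute for batch", b.id) // stand-in for Compute + Floats
	b.outputsReady <- struct{}{}
}

func main() {
	var pending *batch
	for id := 0; id < 4; id++ {
		nb := &batch{
			id:             id,
			inputsReady:    make(chan struct{}, 1),
			computeStarted: make(chan struct{}, 1),
			outputsReady:   make(chan struct{}, 1),
		}
		if pending != nil {
			// Like forwardBatch: wait until the prior compute is underway,
			// then chain its outputs to this batch's inputs.
			<-pending.computeStarted
			nb.inputsReady = pending.outputsReady
		} else {
			nb.inputsReady <- struct{}{} // first batch: inputs are ready immediately
		}
		// ...the graph for nb would be built here, overlapping pending's GPU work...
		go compute(nb)
		pending = nb
	}
	<-pending.outputsReady // drain the final batch
}
```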
@@ -2,10 +2,13 @@

 This directory contains integration tests to exercise Ollama end-to-end to verify behavior

-By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
+By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag: `go test -tags=integration ./...`. Some tests require additional tags to allow scoped testing and keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU). To view the current set of tag combinations, use `find integration -type f | xargs grep "go:build"`

 The integration tests have 2 modes of operating.

 1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
-2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
+2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
+
+> [!IMPORTANT]
+> Before running the tests locally without the "test existing" setting, compile ollama at the top of the source tree (`go build .`), along with GPU support via cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
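For example, a hypothetical invocation against an already-running server, combining only the tag and environment variables described above:

```
OLLAMA_TEST_EXISTING=1 OLLAMA_HOST=127.0.0.1:11434 go test -tags=integration ./...
```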
@@ -66,7 +66,7 @@ func TestContextExhaustion(t *testing.T) {
 	DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
 }

-// Send multiple requests with prior context and ensure the response is coherant and expected
+// Send multiple generate requests with prior context and ensure the response is coherent and expected
 func TestGenerateWithHistory(t *testing.T) {
 	modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
 	req, resp := GenerateRequests()
@@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
 		}(i)
 	}
 	wg.Wait()

 }

+// Send multiple chat requests with prior context and ensure the response is coherent and expected
+func TestChatWithHistory(t *testing.T) {
+	modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
+	req, resp := ChatRequests()
+	numParallel := 2
+	iterLimit := 2
+
+	softTimeout, hardTimeout := getTimeouts(t)
+	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	// Get the server running (if applicable) and warm the model up with a single initial empty request
+	slog.Info("loading", "model", modelOverride)
+	err := client.Generate(ctx,
+		&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+		func(response api.GenerateResponse) error { return nil },
+	)
+	if err != nil {
+		t.Fatalf("failed to load model %s: %s", modelOverride, err)
+	}
+
+	var wg sync.WaitGroup
+	wg.Add(numParallel)
+	for i := range numParallel {
+		go func(i int) {
+			defer wg.Done()
+			k := i % len(req)
+			req[k].Model = modelOverride
+			for j := 0; j < iterLimit; j++ {
+				if time.Since(started) > softTimeout {
+					slog.Info("exceeded soft timeout, winding down test")
+					return
+				}
+				slog.Info("Starting", "thread", i, "iter", j)
+				// On slower GPUs it can take a while to process the concurrent requests
+				// so we allow a much longer initial timeout
+				assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
+				if assistant == nil {
+					t.Fatalf("didn't get an assistant response for context")
+				}
+				req[k].Messages = append(req[k].Messages,
+					*assistant,
+					api.Message{Role: "user", Content: "tell me more!"},
+				)
+			}
+		}(i)
+	}
+	wg.Wait()
+}
@@ -19,6 +19,8 @@ import (
 )

 func TestMaxQueue(t *testing.T) {
+	t.Skip("this test needs to be re-evaluated to use a proper embedding model")
+
 	if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
 		t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
 		return
@@ -567,6 +567,76 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 	}
 }

+func ChatRequests() ([]api.ChatRequest, [][]string) {
+	genReqs, results := GenerateRequests()
+	reqs := make([]api.ChatRequest, len(genReqs))
+	for i := range reqs {
+		reqs[i].Model = genReqs[i].Model
+		reqs[i].Stream = genReqs[i].Stream
+		reqs[i].KeepAlive = genReqs[i].KeepAlive
+		reqs[i].Messages = []api.Message{
+			{
+				Role:    "user",
+				Content: genReqs[i].Prompt,
+			},
+		}
+	}
+	return reqs, results
+}
+
+func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
+	stallTimer := time.NewTimer(initialTimeout)
+	var buf bytes.Buffer
+	role := "assistant"
+	fn := func(response api.ChatResponse) error {
+		// fmt.Print(".")
+		role = response.Message.Role
+		buf.Write([]byte(response.Message.Content))
+		if !stallTimer.Reset(streamTimeout) {
+			return errors.New("stall was detected while streaming response, aborting")
+		}
+		return nil
+	}
+
+	stream := true
+	req.Stream = &stream
+	done := make(chan int)
+	var genErr error
+	go func() {
+		genErr = client.Chat(ctx, &req, fn)
+		done <- 0
+	}()
+
+	select {
+	case <-stallTimer.C:
+		if buf.Len() == 0 {
+			t.Errorf("generate never started. Timed out after: %s", initialTimeout.String())
+		} else {
+			t.Errorf("generate stalled. Response so far: %s", buf.String())
+		}
+	case <-done:
+		if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
+			slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
+			return nil
+		}
+		require.NoError(t, genErr, "failed with %s request Messages %s", req.Model, req.Messages)
+		// Verify the response contains the expected data
+		response := buf.String()
+		atLeastOne := false
+		for _, resp := range anyResp {
+			if strings.Contains(strings.ToLower(response), resp) {
+				atLeastOne = true
+				break
+			}
+		}
+		require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was: %v", req.Model, anyResp, response, req.Messages)
+		slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
+	case <-ctx.Done():
+		t.Error("outer test context done while waiting for generate")
+	}
+	return &api.Message{Role: role, Content: buf.String()}
+}
+
 func skipUnderMinVRAM(t *testing.T, gb uint64) {
 	// TODO use info API in the future
 	if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
@@ -400,6 +400,8 @@ type Tensor interface {
 	Bytes() []byte
 	Floats() []float32

+	BackendSetFromIntSlice(s []int32)
+
 	Neg(ctx Context) Tensor
 	Add(ctx Context, t2 Tensor) Tensor
 	Sub(ctx Context, t2 Tensor) Tensor
@@ -82,6 +82,7 @@ type Backend struct {
 	// to the name that is used by the model definition
 	tensorLoadTargets map[string][]string

+	schedMu       sync.Mutex // Only one Compute can run at a time
 	sched         C.ggml_backend_sched_t
 	schedBackends []C.ggml_backend_t
 	schedBufts    []C.ggml_backend_buffer_type_t
@@ -769,6 +770,8 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
 }

 func (c *Context) Compute(tensors ...ml.Tensor) {
+	c.b.schedMu.Lock()
+	defer c.b.schedMu.Unlock()
 	if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
 		panic(fmt.Errorf("error computing ggml graph: %v", status))
 	}
@@ -1037,6 +1040,12 @@ func (t *Tensor) Floats() (data []float32) {
 	return
 }

+func (t *Tensor) BackendSetFromIntSlice(s []int32) {
+	if len(s) > 0 {
+		C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
+	}
+}
+
 func (t *Tensor) DType() ml.DType {
 	switch t.t._type {
 	case C.GGML_TYPE_F32:
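The pattern this new method enables: forwardBatch builds the graph against an all-zero placeholder input (`model.Forward` is called with `make([]int32, len(batchInputs))`), and computeBatch injects the real token IDs via `BackendSetFromIntSlice` just before Compute. A minimal, self-contained sketch of that deferred-fill idea, using illustrative stand-ins rather than ollama's actual types:

```go
package main

import "fmt"

// tensor is a toy stand-in for ml.Tensor; setFromIntSlice mirrors the role of
// BackendSetFromIntSlice: overwrite the tensor's backing buffer with real values.
type tensor struct{ data []int32 }

func (t *tensor) setFromIntSlice(s []int32) { copy(t.data, s) }

func main() {
	// Graph construction time: allocate the input with placeholder zeros.
	in := &tensor{data: make([]int32, 4)}

	// Later, once the previous batch has produced its sampled tokens...
	sampled := []int32{101, 7, 42, 9}

	// ...inject the real values and only then run the prebuilt graph.
	in.setFromIntSlice(sampled)
	fmt.Println("compute sees:", in.data) // compute sees: [101 7 42 9]
}
```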
@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize([]input.Input) ([]input.Input, error)
+	PostTokenize([]*input.Input) ([]*input.Input, error)
 }

 // Base implements the common fields and methods for all models
@@ -278,13 +278,13 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }

-func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) {
 	if len(batch.Positions) != len(batch.Sequences) {
-		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
+		return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}

 	if len(batch.Positions) < 1 {
-		return nil, errors.New("batch size cannot be less than 1")
+		return nil, nil, errors.New("batch size cannot be less than 1")
 	}

 	batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
@@ -293,16 +293,16 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
 	if cache != nil {
 		err := cache.StartForward(ctx, batch, false)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 	}

 	t, err := m.Forward(ctx, batch)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}

-	ctx.Forward(t).Compute(t)
+	ctx.Forward(t)

-	return t, nil
+	return batch.Inputs, t, nil
 }
@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	return []input.Multimodal{{Tensor: visionOutputs}}, nil
 }

-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-	var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+	var result []*input.Input

 	for _, inp := range inputs {
 		if len(inp.Multimodal) == 0 {
@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 			inputMultimodal := inp.Multimodal[0].Tensor

 			result = append(result,
-				input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
-				input.Input{Token: 255999}, // "<start_of_image>""
-				input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
+				&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
+				&input.Input{Token: 255999}, // "<start_of_image>""
+				&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
 			)

 			// add image token placeholders
-			result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
+			result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)

 			result = append(result,
-				input.Input{Token: 256000}, // <end_of_image>
-				input.Input{Token: 108},    // "\n\n"
+				&input.Input{Token: 256000}, // <end_of_image>
+				&input.Input{Token: 108},    // "\n\n"
 			)
 		}
 	}
@@ -134,16 +134,16 @@ type separator struct {
 	y bool
 }

-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-	var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+	var result []*input.Input
 	for _, inp := range inputs {
 		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 			continue
 		}

-		var imageInputs []input.Input
-		imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
+		var imageInputs []*input.Input
+		imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>

 		for i, mm := range inp.Multimodal {
 			patchesPerChunk := mm.Tensor.Dim(1)
@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 			if i < len(inp.Multimodal)-1 {
 				separator := mm.Data.(*separator)

-				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+				imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)

 				if separator.x {
-					imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
+					imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
 				}
 				if separator.y {
-					imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
+					imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
 				}
 			} else {
-				imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
-				imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
-				imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
-				imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
+				imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
+				imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
+				imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
+				imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
 			}
 		}
@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
 // Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
 // that can be processed together.
-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-	var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+	var result []*input.Input
 	for _, inp := range inputs {
 		if len(inp.Multimodal) == 0 {
 			result = append(result, inp)
 		} else {
 			for i, row := range inp.Multimodal {
 				// [IMG]
-				result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
-				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
+				result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
+				result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
 				if i == len(inp.Multimodal)-1 {
 					// [IMG_END]
-					result = append(result, input.Input{Token: 13})
+					result = append(result, &input.Input{Token: 13})
 				} else {
 					// [IMG_BREAK]
-					result = append(result, input.Input{Token: 12})
+					result = append(result, &input.Input{Token: 12})
 				}
 			}
 		}
@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	return []input.Multimodal{{Tensor: projectedOutputs}}, nil
 }

-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 	for i := range inputs {
 		if inputs[i].Multimodal != nil {
 			inputs[i].Token = 128256 // <|image|>
@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 }

 // PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
-func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
-	var result []input.Input
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+	var result []*input.Input

 	var (
 		imageToken int32 = 151655
@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 				return nil, fmt.Errorf("failed to encode image prompt: %w", err)
 			}
 			for i := range pre {
-				result = append(result, input.Input{Token: pre[i]})
+				result = append(result, &input.Input{Token: pre[i]})
 			}

 			patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)

 			// First add the vision start token
-			result = append(result, input.Input{Token: visionStartToken})
+			result = append(result, &input.Input{Token: visionStartToken})

 			// Add the image token with the multimodal tensor data at the first position
-			result = append(result, input.Input{
+			result = append(result, &input.Input{
 				Token:          imageToken,
 				Multimodal:     inp.Multimodal,
 				MultimodalHash: inp.MultimodalHash,
@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 			})

 			// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
-			result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
+			result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)

-			result = append(result, input.Input{Token: visionEndToken})
+			result = append(result, &input.Input{Token: visionEndToken})
 		}
 	}
@@ -86,7 +86,7 @@ type InputCacheSlot struct {
 	Id int

 	// Inputs that are stored in the KV cache
-	Inputs []input.Input
+	Inputs []*input.Input

 	// is this cache actively being processed as part of a sequence?
 	InUse bool
@@ -95,7 +95,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }

-func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
 	return slot, prompt, nil
 }

-func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
 	longest := int32(-1)
 	var longestSlot *InputCacheSlot

@@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
 	return longestSlot, longest, nil
 }

-func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
 	oldest := time.Now()
 	var oldestSlot *InputCacheSlot

@@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
 	if longest > 0 && longestSlot != oldestSlot {
 		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
 			len(longestSlot.Inputs))
-		oldestSlot.Inputs = make([]input.Input, longest)
+		oldestSlot.Inputs = make([]*input.Input, longest)
 		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
 		if c.cache != nil {
 			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
 	return oldestSlot, longest, nil
 }

-func countCommonPrefix(a []input.Input, b []input.Input) int32 {
+func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
 	var count int32

 	for i := range a {
@@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
 }

 type ErrReprocessInputs struct {
-	Inputs []input.Input
+	Inputs []*input.Input
 }

 func (e *ErrReprocessInputs) Error() string {
@@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
 			"id", slot.Id, "error", err)

 		// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
-		newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
+		newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
 		copy(newInputs[:numKeep], slot.Inputs[:numKeep])
 		copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])

 		// Reset the cache
 		_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
-		slot.Inputs = []input.Input{}
+		slot.Inputs = []*input.Input{}

 		// Return error with inputs that need to be reprocessed
 		return &ErrReprocessInputs{Inputs: newInputs}
@@ -13,50 +13,50 @@ import (
 func TestCountCommon(t *testing.T) {
 	tests := []struct {
 		name     string
-		t1       []input.Input
-		t2       []input.Input
+		t1       []*input.Input
+		t2       []*input.Input
 		expected int32
 	}{
 		{
 			name:     "Equal",
-			t1:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
-			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t2:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 3,
 		},
 		{
 			name:     "Prefix",
-			t1:       []input.Input{{Token: 1}},
-			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []*input.Input{{Token: 1}},
+			t2:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Image Prefix",
-			t1:       []input.Input{{MultimodalHash: 1}},
-			t2:       []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
+			t1:       []*input.Input{{MultimodalHash: 1}},
+			t2:       []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Mixed",
-			t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
-			t2:       []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
+			t1:       []*input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:       []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 		},
 		{
 			name:     "Mixed, Same Length",
-			t1:       []input.Input{{Token: 1}, {MultimodalHash: 1}},
-			t2:       []input.Input{{Token: 1}, {MultimodalHash: 2}},
+			t1:       []*input.Input{{Token: 1}, {MultimodalHash: 1}},
+			t2:       []*input.Input{{Token: 1}, {MultimodalHash: 2}},
 			expected: 1,
 		},
 		{
 			name:     "Empty",
-			t1:       []input.Input{},
-			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []*input.Input{},
+			t2:       []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 0,
 		},
 		{
 			name:     "Both Empty",
-			t1:       []input.Input{},
-			t2:       []input.Input{},
+			t1:       []*input.Input{},
+			t2:       []*input.Input{},
 			expected: 0,
 		},
 	}
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
 	tests := []struct {
 		name    string
 		cache   InputCache
-		prompt  []input.Input
+		prompt  []*input.Input
 		longest expected
 		best    expected
 	}{
@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []input.Input{},
+					Inputs:   []*input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 				{
 					Id:       1,
-					Inputs:   []input.Input{},
+					Inputs:   []*input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []input.Input{{Token: 1}},
+			prompt:  []*input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 0, len: 0},
 		},
@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []input.Input{{Token: 1}},
+					Inputs:   []*input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []input.Input{{Token: 1}, {Token: 2}},
+			prompt:  []*input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 		},
@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []input.Input{},
+					Inputs:   []*input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []input.Input{{Token: 2}},
+			prompt:  []*input.Input{{Token: 2}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []input.Input{},
+						Inputs:   []*input.Input{},
 						InUse:    false,
 						lastUsed: time.Time{},
 					},
 				},
 			},
-			prompt:  []input.Input{{Token: 1}},
+			prompt:  []*input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 1},
 			best:    expected{result: 1, len: 1},
 		},
@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []input.Input{{Token: 1}},
+					Inputs:   []*input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []input.Input{{Token: 2}, {Token: 3}},
+			prompt:  []*input.Input{{Token: 2}, {Token: 3}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 					InUse:    true,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []input.Input{{Token: 1}},
+					Inputs:   []*input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []input.Input{{Token: 1}, {Token: 2}},
+			prompt:  []*input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 1},
 			best:    expected{result: 1, len: 2},
 		},
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
 	tests := []struct {
 		name           string
 		cache          InputCache
-		prompt         []input.Input
+		prompt         []*input.Input
 		wantErr        bool
 		expectedSlotId int
 		expectedPrompt int // expected length of remaining prompt
@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []input.Input{},
+						Inputs:   []*input.Input{},
 						InUse:    false,
 						lastUsed: time.Now().Add(-2 * time.Second),
 					},
 				},
 			},
-			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			prompt:         []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			wantErr:        false,
 			expectedSlotId: 0,
 			expectedPrompt: 1, // Only token 3 remains
@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []input.Input{},
+						Inputs:   []*input.Input{},
 						InUse:    false,
 						lastUsed: time.Now().Add(-2 * time.Second),
 					},
 				},
 			},
-			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			prompt:         []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			wantErr:        false,
 			expectedSlotId: 0,
 			expectedPrompt: 1, // Only token 3 remains
@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 				},
 			},
-			prompt:         []input.Input{{Token: 1}, {Token: 2}},
+			prompt:         []*input.Input{{Token: 1}, {Token: 2}},
 			wantErr:        false,
 			expectedSlotId: 0,
 			expectedPrompt: 1, // Should leave 1 token for sampling
@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []*input.Input{{Token: 1}, {Token: 2}},
 						InUse:    true,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 				},
 			},
-			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			prompt:         []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			wantErr:        true,
 			expectedSlotId: -1,
 			expectedPrompt: -1,
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
 	tests := []struct {
 		name     string
 		numCtx   int32
-		inputs   []input.Input
+		inputs   []*input.Input
 		numKeep  int32
 		cacheErr bool
 		wantErr  any
@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
 		{
 			name:     "Normal shift",
 			numCtx:   10,
-			inputs:   []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+			inputs:   []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
 			numKeep:  2,
 			cacheErr: false, // No error
 			wantErr:  nil,
@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
 		{
 			name:     "Cache removal fails",
 			numCtx:   10,
-			inputs:   []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
+			inputs:   []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
 			numKeep:  2,
 			cacheErr: true,
 			wantErr:  &ErrReprocessInputs{},
@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
 			}
 			slot := &InputCacheSlot{
 				Id:     123,
-				Inputs: make([]input.Input, len(tt.inputs)),
+				Inputs: make([]*input.Input, len(tt.inputs)),
 			}
 			copy(slot.Inputs, tt.inputs)
@@ -17,6 +17,7 @@ import (
 	"reflect"
 	"regexp"
 	"runtime"
+	"runtime/debug"
 	"strconv"
 	"strings"
 	"sync"
@@ -51,10 +52,10 @@ type Sequence struct {
 	iBatch int

 	// prompt inputs left to evaluate
-	inputs []input.Input
+	inputs []*input.Input

 	// inputs that have been added to a batch but not yet submitted to Forward
-	pendingInputs []input.Input
+	pendingInputs []*input.Input

 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
@@ -86,6 +87,12 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool

+	// true if the sequence is finished and marked for removal on the next pass
+	finished bool
+
+	// true if we have to skip this sequence to shift the cache
+	skipForShift bool
+
 	doneReason llm.DoneReason

 	// Metrics
@@ -182,8 +189,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
-	var inputs []input.Input
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
+	var inputs []*input.Input
 	var ctxs []ml.Context
 	var mmStore multimodalStore
@@ -210,7 +217,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 		}

 		for _, t := range tokens {
-			inputs = append(inputs, input.Input{Token: t})
+			inputs = append(inputs, &input.Input{Token: t})
 		}

 		// image - decode and store
@@ -243,7 +250,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [

 			mmStore.addMultimodal(imageEmbeddings)

-			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+			inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
 			postTokenize = true
 		}
 	}
@@ -259,6 +266,27 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
 	return inputs, ctxs, mmStore, nil
 }

+type batchState struct {
+	id          int
+	ctx         ml.Context
+	modelInput  ml.Tensor
+	modelOutput ml.Tensor
+	batchInputs []*input.Input
+	batch       input.Batch
+	seqs        []*Sequence // full set of seqs at the time this batch was initiated
+	initSeqIdx  int         // the initial value for the set of sequences evaluated (s.nextSeq - 1)
+
+	// Signaled when this batch's inputs are ready and compute can proceed
+	inputsReadyCh chan struct{}
+
+	// Signaled when Compute is about to begin on this batch and the
+	// seqs have been updated to prepare for the next batch
+	computeStartedCh chan struct{}
+
+	// Signaled when this batch's outputs are complete and the next batch can proceed
+	outputsReadyCh chan struct{}
+}
+
 type Server struct {
 	// modelPath is the location of the model to be loaded
 	modelPath string
@@ -290,6 +318,16 @@ type Server struct {
 	// TODO (jmorganca): make this n_batch
 	batchSize int

+	// Used to signal a hard failure during async processing, which will panic the runner
+	hardErrCh chan error
+
+	// A prior batch that's still being processed;
+	// only read or written by forwardBatch
+	pendingBatch *batchState
+
+	// Simple counter used only for trace logging batches
+	batchID int
+
 	// protects access to everything below this line
 	// this is context state needed for decoding
 	mu sync.Mutex
@@ -350,45 +388,132 @@ func flushPending(seq *Sequence) bool {
 	}
 }

-func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
+func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]

+	// finish could be called multiple times since we prepare 1 batch ahead
+	// and multiple scenarios can lead to finishing a sequence;
+	// ensure only the first finish call is processed
+	if seq.finished {
+		return
+	}
+
 	flushPending(seq)
 	seq.doneReason = reason
+	seq.finished = true
 	close(seq.responses)
 	close(seq.embedding)
 	seq.cache.InUse = false
 }

+func (s *Server) removeFinishedSequence(seqIndex int) {
+	s.seqs[seqIndex] = nil
+	s.seqsSem.Release(1)
+}
+
+// track batch state between forwardBatch, computeBatch and predictForwardBatch
+
 func (s *Server) run(ctx context.Context) {
 	s.ready.Wait()

+	var bs *batchState
 	for {
 		select {
 		case <-ctx.Done():
 			return
+		case err := <-s.hardErrCh:
+			panic(err)
 		default:
-			err := s.processBatch()
+			var err error
+			bs, err = s.forwardBatch()
 			if err != nil {
 				panic(err)
 			}
+			if bs == nil {
+				continue
+			}
+			go s.computeBatch(bs)
 		}
 	}
 }

-func (s *Server) processBatch() error {
+// forwardBatch prepares the next batch and builds its forward graph.
+func (s *Server) forwardBatch() (*batchState, error) {
+	inputsReady := false
+	var inputsReadyCh chan struct{}
+
+	// If we have a pending batch still processing, wait until Compute has started
+	// before setting up the next batch so the seqs inputs are ready to receive their
+	// token values and we get the correct input pointers for the batchInputs
+	if s.pendingBatch != nil {
+		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id)
+		<-s.pendingBatch.computeStartedCh
+		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID)
+		inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the outputs from the pending batch to the next batch's inputs
+	} else {
+		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
+		inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately
+		inputsReadyCh = make(chan struct{}, 1)
+	}
+
 	s.mu.Lock()
 	for s.allNil() {
 		s.cond.Wait() // Wait until an item is added
 	}
 	defer s.mu.Unlock()

-	ctx := s.model.Backend().NewContext()
-	defer ctx.Close()
+	// If new sequences have been added with an active batch we delay preparing the next batch
+	// until Compute has finished
+	if s.pendingBatch != nil {
+		for seqIdx := range s.seqs {
+			if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] {
+				slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id)
+				s.mu.Unlock() // release the lock so computeBatch can finish up
+				<-s.pendingBatch.outputsReadyCh
+				slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id)
+				s.mu.Lock()
+				inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs
+				break
+			}
+		}
+	}
+	// Clear pendingBatch - we'll set it if we have a batch with any inputs
+	s.pendingBatch = nil

-	var batchInputs []int32
+	// Remove any finished sequences before recording the active set of seqs in the batch
+	for seqIdx := range s.seqs {
+		seq := s.seqs[seqIdx]
+		if seq == nil {
+			continue
+		}
+		if seq.finished {
+			s.removeFinishedSequence(seqIdx)
+			continue
+		}
+		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
+			s.finishSequence(seqIdx, llm.DoneReasonLength)
+			s.removeFinishedSequence(seqIdx)
+			continue
+		}
+	}
+
+	// next batch
+	nb := &batchState{
+		id:               s.batchID,
+		initSeqIdx:       s.nextSeq - 1,
+		seqs:             make([]*Sequence, len(s.seqs)),
+		inputsReadyCh:    inputsReadyCh,
+		computeStartedCh: make(chan struct{}, 1),
+		outputsReadyCh:   make(chan struct{}, 1),
+	}
+	ctx := s.model.Backend().NewContext()
+	nb.ctx = ctx
+
+	// Record the sequences at the time we create the batch so we can detect if new sequences are added on the next pass
+	copy(nb.seqs, s.seqs)
+
+	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
+	var batchInputs []*input.Input
 	var batch input.Batch

 	resumeSeq := -1
@@ -396,20 +521,13 @@ func (s *Server) processBatch() error {
 	for range s.seqs {
 		seqIdx = (seqIdx + 1) % len(s.seqs)
 		seq := s.seqs[seqIdx]

 		if seq == nil {
 			continue
 		}

-		// if past the num predict limit
-		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, llm.DoneReasonLength)
-			continue
-		}
-
 		if !s.cache.enabled {
 			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-			seq.cache.Inputs = []input.Input{}
+			seq.cache.Inputs = []*input.Input{}
 		}

 		batchSize := s.batchSize
@@ -449,18 +567,21 @@ func (s *Server) processBatch() error {
 					// Prepend these inputs to the sequence's inputs queue for reprocessing
 					seq.inputs = append(reprocess.Inputs, seq.inputs...)
 					// Skip this sequence but continue processing the rest
+					seq.skipForShift = true // cleared in computeBatch below for the next batch
 					continue
 				} else {
-					return err
+					ctx.Close()
+					return nil, err
 				}
 			}
 		}

-		batchInputs = append(batchInputs, inp.Token)
+		batchInputs = append(batchInputs, seq.inputs[i])
 		if inp.Multimodal != nil {
 			mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
 			if err != nil {
-				return err
+				ctx.Close()
+				return nil, err
 			}
 			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
 		}
@@ -468,10 +589,13 @@ func (s *Server) processBatch() error {
 		batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 		batch.Sequences = append(batch.Sequences, seq.cache.Id)

+		// TODO BUG HERE!!!
+		// Somehow sometimes iBatch isn't set correctly
 		seq.iBatch = len(batch.Outputs)
 		if i+1 == len(seq.inputs) {
 			batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 		}
+		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
 		seq.pendingInputs = append(seq.pendingInputs, inp)
 	}
@@ -485,36 +609,138 @@ func (s *Server) processBatch() error {
 	}

 	if len(batchInputs) == 0 {
-		return nil
+		slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
+		ctx.Close()
+		return nil, nil
 	}
+	s.batchID++

-	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
+	var err error
+	// Actual batchInputs values will be injected into the modelInput tensor before calling Compute
+	nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch)
 	if err != nil {
-		return fmt.Errorf("failed to decode batch: %w", err)
+		ctx.Close()
+		return nil, fmt.Errorf("failed to build graph: %w", err)
 	}
+	nb.batchInputs = batchInputs
+	nb.batch = batch
+
+	// computeBatch will close the context in the batch upon completion
+	s.pendingBatch = nb
+
+	if inputsReady {
+		nb.inputsReadyCh <- struct{}{}
+	}

-	logits := modelOutput.Floats()
+	return nb, nil
+}

+// Async processing of the next batch
+func (s *Server) computeBatch(bs *batchState) {
+	if bs == nil || bs.ctx == nil {
+		// Nothing to compute
+		return
+	}
+	defer bs.ctx.Close()
+
+	// Wait until inputs are ready
+	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id)
+	<-bs.inputsReadyCh
+	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id)
+
+	// Once we complete, signal that the next batch's inputs are ready.
+	// This will unblock the next computeBatch, or forwardBatch if new seqs come in
+	defer func() {
+		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id)
+		bs.outputsReadyCh <- struct{}{}
+	}()
+
+	s.mu.Lock()
+
+	// Gather the actual input token values now that they're ready
+	batchInputs := make([]int32, len(bs.batchInputs))
+	for i := range batchInputs {
+		batchInputs[i] = bs.batchInputs[i].Token
+	}
+
+	// TODO the following logic could be run in a go routine to possibly speed up getting to Compute
+
+	// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
+	// so that forwardBatch can build a batchInputs set which will eventually contain the actual
+	// decoded tokens.
+	promptProcessing := make([]bool, len(s.seqs)) // track seqs we skip
+	nextBatchTokens := make([]*input.Input, len(s.seqs))
+	iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
+	for i, seq := range s.seqs {
+		iBatches[i] = -1
+		if seq == nil {
+			continue
+		}
+		// Skip over any newly added sequences
+		if bs.seqs[i] == nil {
+			continue
+		}

 		// After calling Forward, pending inputs are now in the cache
 		if len(seq.pendingInputs) > 0 {
 			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-			seq.pendingInputs = []input.Input{}
+			seq.pendingInputs = []*input.Input{}
 		}

 		// don't sample prompt processing
 		if len(seq.inputs) != 0 {
 			if !s.cache.enabled {
-				return errors.New("caching disabled but unable to fit entire input in a batch")
+				s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
+				return
 			}
+			// Record so we can skip during Decode
+			promptProcessing[i] = true
 			continue
 		}

 		seq.numPredicted++
+		nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
+		seq.inputs = []*input.Input{nextToken}
+		nextBatchTokens[i] = nextToken
+		iBatches[i] = seq.iBatch
+	}
+
+	// At this point the seqs are ready for forwardBatch to move forward so unblock
+	s.mu.Unlock()
+	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id)
+	bs.computeStartedCh <- struct{}{}
+
+	bs.modelInput.BackendSetFromIntSlice(batchInputs)
+	bs.ctx.Compute(bs.modelOutput)
+	logits := bs.modelOutput.Floats()
+
+	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id)
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id)
 	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
 		}
+		// Skip over any newly added sequences
+		if bs.seqs[i] == nil {
+			continue
+		}
+
+		// Detect if the sequence we're processing has already been completed and replaced
+		// with a new sequence
+		if seq != bs.seqs[i] {
+			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i)
+			continue
+		}

 		// don't sample prompt processing
+		if promptProcessing[i] {
 			continue
 		}

 		if seq.numPredicted == 1 {
 			seq.startGenerationTime = time.Now()
 		}
@@ -522,35 +748,46 @@ func (s *Server) processBatch() error {
 		// if done processing the prompt, generate an embedding and return
 		if seq.embeddingOnly {
 			// TODO(jessegross): Embedding support
-			slog.Warn("generation of embedding outputs not yet supported")
-			s.removeSequence(i, llm.DoneReasonStop)
+			slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i)
+			s.finishSequence(i, llm.DoneReasonStop)
 			continue
 		}

 		// sample a token
-		vocabSize := len(logits) / len(batch.Outputs)
-
-		token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
+		vocabSize := len(logits) / len(bs.batch.Outputs)
+		slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch)
+		token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
-			return fmt.Errorf("failed to sample token: %w", err)
+			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
+			return
 		}

+		nextBatchTokens[i].Token = token
+
 		// if it's an end of sequence token, break
 		if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
 			// TODO (jmorganca): we should send this back
 			// as it's important for the /api/generate context
 			// seq.responses <- piece

-			s.removeSequence(i, llm.DoneReasonStop)
+			slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i)
+			s.finishSequence(i, llm.DoneReasonStop)
 			continue
 		}

 		piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
 		if err != nil {
-			return err
+			s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
+			return
 		}

-		seq.inputs = []input.Input{{Token: token}}
+		if nextBatchTokens[i] == nil {
+			slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq)
+			s.hardErrCh <- fmt.Errorf("expected a single token during decode")
+			return
+		}
+
+		// fill in the final selected token value to replace the placeholder in the next batch
+		// nextBatchTokensWritten++

 		seq.pendingResponses = append(seq.pendingResponses, piece)
 		sequence := strings.Join(seq.pendingResponses, "")
@@ -575,9 +812,10 @@ func (s *Server) processBatch() error {
 			if tokenTruncated || origLen == newLen {
 				tokenLen--
 			}

 			seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
+
-			s.removeSequence(i, llm.DoneReasonStop)
+			s.finishSequence(i, llm.DoneReasonStop)
 			continue
 		}
@@ -590,11 +828,9 @@ func (s *Server) processBatch() error {
 		}

 		if !flushPending(seq) {
-			s.removeSequence(i, llm.DoneReasonConnectionClosed)
+			s.finishSequence(i, llm.DoneReasonConnectionClosed)
 		}
 	}

-	return nil
 }

 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -736,7 +972,10 @@ func (s *Server) reserveWorstCaseGraph() error {
 	defer ctx.Close()

 	var err error
-	inputs := make([]input.Input, s.batchSize)
+	inputs := make([]*input.Input, s.batchSize)
+	for i := range inputs {
+		inputs[i] = &input.Input{}
+	}
 	mmStore := newMultimodalStore()

 	// Multimodal strategy:
@@ -778,8 +1017,11 @@ func (s *Server) reserveWorstCaseGraph() error {
 		}

 		if len(inputs) < s.batchSize {
-			newInputs := make([]input.Input, s.batchSize)
+			newInputs := make([]*input.Input, s.batchSize)
 			copy(newInputs, inputs)
+			for i := len(inputs); i < s.batchSize; i++ {
+				newInputs[i] = &input.Input{}
+			}
 			inputs = newInputs
 		}
 	}
@@ -842,6 +1084,7 @@ func (s *Server) allocModel(
 	// Convert memory allocation panics to errors
 	defer func() {
 		if r := recover(); r != nil {
+			debug.PrintStack()
 			if err, ok := r.(error); ok {
 				panicErr = err
 			} else {
@@ -1011,6 +1254,7 @@ func Execute(args []string) error {
 	server := &Server{
 		modelPath: *mpath,
 		status:    llm.ServerStatusLaunched,
+		hardErrCh: make(chan error, 1),
 	}

 	server.cond = sync.NewCond(&server.mu)