From 6f7117145f56e00d16572ed11cb8f83c4b3af636 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Mon, 15 Sep 2025 14:33:06 -0700
Subject: [PATCH] batch: use tensors for outputs (#12185)

this cleans up the model interface slightly without too much impact in
other areas
---
 model/input/input.go               | 14 +++++++-------
 model/models/gemma2/model.go       |  3 +--
 model/models/gemma3/embed.go       |  1 -
 model/models/gemma3/model_text.go  |  3 +--
 model/models/gemma3n/model_text.go |  2 +-
 model/models/gptoss/model.go       |  4 ++--
 model/models/llama/model.go        |  2 +-
 model/models/llama4/model.go       |  4 +---
 model/models/mistral3/model.go     |  3 +--
 model/models/mllama/model.go       |  3 +--
 model/models/qwen2/model.go        |  2 +-
 model/models/qwen25vl/model.go     |  3 +--
 model/models/qwen3/model.go        |  2 +-
 runner/ollamarunner/runner.go      | 18 ++++++++----------
 14 files changed, 27 insertions(+), 37 deletions(-)

diff --git a/model/input/input.go b/model/input/input.go
index bd9b53ec67..35dc41b354 100644
--- a/model/input/input.go
+++ b/model/input/input.go
@@ -54,10 +54,9 @@ type Batch struct {
 	// Inputs is the input tokens, including placeholders for multimodal inputs.
 	Inputs ml.Tensor
 
-	// Multimodal is a set of multimodal embeddings previously created by
-	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
-	// models or for batches without multimodal elements.
-	Multimodal []MultimodalIndex
+	// Outputs are the set of indices into Inputs for which output data should
+	// be returned.
+	Outputs ml.Tensor
 
 	// Positions is the position for each Input, relative to its sequence. Equal
 	// in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
 	// Sequences is the sequence for each Input. Equal in length to Inputs.
 	Sequences []int
 
-	// Outputs are the set of indicies into Inputs for which output data should
-	// be returned.
-	Outputs []int32
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
+	Multimodal []MultimodalIndex
 }
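Read on its own, the input.go change above is the heart of the patch: Batch.Outputs moves from a CPU-side []int32 that every model re-uploaded on each Forward call to an ml.Tensor the runner builds once. The reordering also moves the optional Multimodal metadata to the end of the struct, after the fields every batch carries. A small self-contained Go sketch of what the field means, with plain slices standing in for the repo's tensor types and invented token values:

package main

import "fmt"

// Outputs holds indices into Inputs marking which tokens need logits. Plain
// slices stand in for ml.Tensor here; after this patch both fields are device
// tensors in the real Batch struct.
func main() {
	inputs := []int32{101, 102, 103, 201, 202} // two sequences, flattened into one batch
	outputs := []int32{2, 4}                   // the last token of each sequence

	for _, idx := range outputs {
		fmt.Printf("token %d at flat index %d needs logits\n", inputs[idx], idx)
	}
}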
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index e621d03ae2..84c89e1fe4 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
+			lastLayerOutputs = batch.Outputs
 		}
 
 		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go
index 16c299e22d..7d1e269ff5 100644
--- a/model/models/gemma3/embed.go
+++ b/model/models/gemma3/embed.go
@@ -22,7 +22,6 @@ type embedModel struct {
 }
 
 func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	batch.Outputs = batch.Positions // return all positions
 	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
 
 	switch m.PoolingType {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 2a3b23939c..5e515a9272 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 
 func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 
 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
+			lastLayerOutputs = batch.Outputs
 		}
 
 		hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go
index b75a2abb37..eeb9ab0286 100644
--- a/model/models/gemma3n/model_text.go
+++ b/model/models/gemma3n/model_text.go
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
 	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
 
-	hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
+	hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
 	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
 
 	return m.Output.Forward(ctx, hiddenStates), nil
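The gemma3n hunk above is the most direct consumer of the new field: hidden states are computed for the whole batch, and Rows then gathers only the rows named by batch.Outputs before the output norm and LM head run. A rough self-contained imitation of that gather, using slices where the real Rows operates on device tensors:

package main

import "fmt"

// rows imitates Tensor.Rows in spirit: gather a subset of rows by index, so
// only the positions that need logits flow through the output head.
func rows(hidden [][]float32, outputs []int32) [][]float32 {
	picked := make([][]float32, 0, len(outputs))
	for _, idx := range outputs {
		picked = append(picked, hidden[idx])
	}
	return picked
}

func main() {
	hidden := [][]float32{{0.1}, {0.2}, {0.3}, {0.4}, {0.5}} // one row per batch token
	fmt.Println(rows(hidden, []int32{2, 4}))                 // [[0.3] [0.5]]
}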
diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go
index 3ef078095d..a74f764876 100644
--- a/model/models/gptoss/model.go
+++ b/model/models/gptoss/model.go
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
 		}
 
 		var outputs ml.Tensor
-		if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+		if i == len(m.TransformerBlocks)-1 {
+			outputs = batch.Outputs
 		}
 
 		hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 77d8f36d3c..51273a0141 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}
 
 		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 99a898d2d9..9cb2efc87a 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {
diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index 408e54d3dd..435b1a304d 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index d0ad4670ef..239d999d50 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	}
 
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {
diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go
index 3c662f0682..93a5026125 100644
--- a/model/models/qwen2/model.go
+++ b/model/models/qwen2/model.go
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}
 
 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
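Two details stand out in the model changes so far. gptoss drops its len(batch.Outputs) > 0 guard: since the runner now always attaches an Outputs tensor (and, per the runner change below, embedding-only sequences mark every position), the last block can presumably take the tensor unconditionally. The rest (llama, qwen2, and qwen3 later on) share one pattern: every layer runs on the full batch, and only the final layer is told which rows to keep. A self-contained sketch of that control flow, with a toy forward standing in for layer.Forward and nil meaning "keep all rows":

package main

import "fmt"

// forward stands in for layer.Forward: a nil outputs slice keeps every row,
// while a non-nil slice trims the batch to just the rows that need logits.
func forward(hidden []float32, outputs []int32) []float32 {
	if outputs == nil {
		return hidden
	}
	trimmed := make([]float32, 0, len(outputs))
	for _, idx := range outputs {
		trimmed = append(trimmed, hidden[idx])
	}
	return trimmed
}

func main() {
	hidden := []float32{1, 2, 3, 4, 5}
	batchOutputs := []int32{4} // only the final token's row is needed
	numLayers := 3

	for i := 0; i < numLayers; i++ {
		var outputs []int32
		if i == numLayers-1 { // the pattern from llama, qwen2, qwen3
			outputs = batchOutputs
		}
		hidden = forward(hidden, outputs)
	}
	fmt.Println(hidden) // [5]
}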
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index d73f499d29..6c76305db8 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
 }
 
 func init() {
diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go
index 7a83e0d04a..ec2adaa7d7 100644
--- a/model/models/qwen3/model.go
+++ b/model/models/qwen3/model.go
@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
-			outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+			outputs = batch.Outputs
 		}
 
 		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
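The runner.go diff below is the producer side: forwardBatch accumulates output indices into a local []int32 while walking each sequence's pending inputs, then uploads them once. This self-contained snippet mirrors that accumulation; token values are invented:

package main

import "fmt"

func main() {
	// Pending inputs for two sequences, as in forwardBatch's outer loop.
	seqs := [][]int32{{101, 102, 103}, {201, 202}}

	var batchInputs []int32
	var batchOutputs []int32 // flat indices into batchInputs that need logits

	for _, inputs := range seqs {
		for i, tok := range inputs {
			batchInputs = append(batchInputs, tok)
			// Only the last pending token of each sequence produces output
			// (the real code also marks every position for embedding-only seqs).
			if i+1 == len(inputs) {
				batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
			}
		}
	}
	// In the runner this slice is then uploaded once:
	//   batch.Outputs = ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
	fmt.Println(batchInputs, batchOutputs) // [101 102 103 201 202] [2 4]
}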
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 1081a1f555..3a32384f86 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -467,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 
 	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
 	var batchInputs []*input.Input
+	var batchOutputs []int32
 	var batch input.Batch
 
 	resumeSeq := -1
@@ -549,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			batch.Sequences = append(batch.Sequences, seq.cache.Id)
 
-			seq.iBatch = len(batch.Outputs)
-			if i+1 == len(seq.inputs) {
-				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+			seq.iBatch = len(batchOutputs)
+			if i+1 == len(seq.inputs) || seq.embeddingOnly {
+				batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
 			}
 			logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
 			seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -576,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 
 	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
 	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+	batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
 	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
 	if err != nil {
 		err = fmt.Errorf("failed to build graph: %w", err)
@@ -703,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		}
 
 		// sample a token
-		vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
-		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+		vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
+		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
 		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
 			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
@@ -1046,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Positions[i] = int32(i)
 	}
 
-	batch.Outputs = make([]int32, s.parallel)
-	for i := range batch.Outputs {
-		batch.Outputs[i] = int32(i)
-	}
-
 	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
+	batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
 
 	cache := s.model.Config().Cache
 	if cache != nil {
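A final consequence shows up in computeBatch above: the output count now comes from the tensor's first dimension (Outputs.Dim(0)) instead of len(batch.Outputs), and the flat logits buffer is still sliced per sequence by seq.iBatch. A small sketch of that arithmetic with invented numbers, using a plain int where the real code calls Dim(0):

package main

import "fmt"

func main() {
	// Two output rows over a vocabulary of three tokens, flattened, as the
	// logits come back from the model. Values are invented.
	logits := []float32{0.1, 0.7, 0.2, 0.5, 0.3, 0.2}
	numOutputs := 2 // activeBatch.batch.Outputs.Dim(0) in the real code

	vocabSize := len(logits) / numOutputs
	iBatch := 1 // this sequence's output slot, tracked as seq.iBatch
	row := logits[iBatch*vocabSize : (iBatch+1)*vocabSize]
	fmt.Println(vocabSize, row) // 3 [0.5 0.3 0.2]
}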