mirror of
https://github.com/ollama/ollama.git
synced 2025-11-12 05:27:46 +01:00
batch: use tensors for outputs (#12185)
this cleans up the model interface slightly without too much impact in other areas
This commit is contained in:
@@ -54,10 +54,9 @@ type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs ml.Tensor
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs ml.Tensor
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
@@ -66,7 +65,8 @@ type Batch struct {
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs []int32
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
}
|
||||
|
||||
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
||||
|
||||
@@ -22,7 +22,6 @@ type embedModel struct {
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
batch.Outputs = batch.Positions // return all positions
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
switch m.PoolingType {
|
||||
|
||||
@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
lastLayerOutputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||
|
||||
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
||||
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
|
||||
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if i == len(m.TransformerBlocks)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
|
||||
|
||||
@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
||||
|
||||
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
// TODO: attention mask, cross attention mask
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
||||
|
||||
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
|
||||
@@ -467,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
|
||||
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||
var batchInputs []*input.Input
|
||||
var batchOutputs []int32
|
||||
var batch input.Batch
|
||||
|
||||
resumeSeq := -1
|
||||
@@ -549,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
seq.iBatch = len(batchOutputs)
|
||||
if i+1 == len(seq.inputs) || seq.embeddingOnly {
|
||||
batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
@@ -576,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
|
||||
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
|
||||
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
|
||||
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
|
||||
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to build graph: %w", err)
|
||||
@@ -703,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
@@ -1046,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
|
||||
batch.Outputs = make([]int32, s.parallel)
|
||||
for i := range batch.Outputs {
|
||||
batch.Outputs[i] = int32(i)
|
||||
}
|
||||
|
||||
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
||||
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
|
||||
|
||||
cache := s.model.Config().Cache
|
||||
if cache != nil {
|
||||
|
||||
Reference in New Issue
Block a user