batch: use tensors for outputs (#12185)

This cleans up the model interface slightly, without much impact on other areas.
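
In short, Batch.Outputs changes type from []int32 to ml.Tensor, so the output indices are uploaded once by the runner instead of being converted inside every model's Forward. A sketch of the fields this commit touches (other fields elided):

    // Fields this commit touches (other fields elided).
    // Outputs was previously declared as:  Outputs []int32
    type Batch struct {
        Inputs  ml.Tensor // input tokens; values may be injected later
        Outputs ml.Tensor // int32 indices into Inputs to return data for
    }
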
@@ -54,10 +54,9 @@ type Batch struct {
     // Inputs is the input tokens, including placeholders for multimodal inputs.
     Inputs ml.Tensor
 
-    // Multimodal is a set of multimodal embeddings previously created by
-    // EncodeMultimodal, along with an index into Inputs. Unused for text-only
-    // models or for batches without multimodal elements.
-    Multimodal []MultimodalIndex
+    // Outputs are the set of indicies into Inputs for which output data should
+    // be returned.
+    Outputs ml.Tensor
 
     // Positions is the position for each Input, relative to its sequence. Equal
     // in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
     // Sequences is the sequence for each Input. Equal in length to Inputs.
     Sequences []int
 
-    // Outputs are the set of indicies into Inputs for which output data should
-    // be returned.
-    Outputs []int32
+    // Multimodal is a set of multimodal embeddings previously created by
+    // EncodeMultimodal, along with an index into Inputs. Unused for text-only
+    // models or for batches without multimodal elements.
+    Multimodal []MultimodalIndex
 }
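
For orientation, a minimal sketch of how a caller can now populate the two tensor-valued fields. buildBatch is a hypothetical helper, not part of this commit; the two constructor calls mirror the runner hunks further down:

    // Hypothetical helper; assumes ctx, positions, seqs, outputIdx in scope.
    func buildBatch(ctx ml.Context, positions []int32, seqs []int, outputIdx []int32) input.Batch {
        return input.Batch{
            // Token values are injected later, so only the shape is reserved.
            Inputs: ctx.Input().Empty(ml.DTypeI32, len(positions)),
            // Output indices are known up front and are uploaded immediately.
            Outputs:   ctx.Input().FromIntSlice(outputIdx, len(outputIdx)),
            Positions: positions,
            Sequences: seqs,
        }
    }
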
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-    outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
     hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
     hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
         var lastLayerOutputs ml.Tensor
         if i == len(m.Layers)-1 {
-            lastLayerOutputs = outputs
+            lastLayerOutputs = batch.Outputs
         }
 
         hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
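
Only the final layer receives the output indices: every token still has to pass through all layers to fill the KV cache, but hidden states need to survive to the output projection only for the requested rows. A sketch of the trim this enables, assuming layer.Forward gathers rows much like the Rows call in the pooling hunk further down:

    // Inside a final layer (sketch): keep only the rows whose logits were
    // requested; outputs is nil for every other layer, so nothing is trimmed.
    if outputs != nil {
        hiddenState = hiddenState.Rows(ctx, outputs)
    }
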
@@ -22,7 +22,6 @@ type embedModel struct {
 }
 
 func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-    batch.Outputs = batch.Positions // return all positions
     hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
 
     switch m.PoolingType {
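
The deleted line made every position an output by aliasing the two []int32 slices, which no longer type-checks now that Outputs is a tensor. Its replacement is runner-side: forwardBatch (see the @@ -549 hunk below) appends an output index for every input of an embedding-only sequence:

    // From the runner change below: with seq.embeddingOnly set, the condition
    // holds for every input, reproducing the old "return all positions" behavior.
    if i+1 == len(seq.inputs) || seq.embeddingOnly {
        batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
    }
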
@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 
 func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
     positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-    outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
     hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
     hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 
         var lastLayerOutputs ml.Tensor
         if i == len(m.Layers)-1 {
-            lastLayerOutputs = outputs
+            lastLayerOutputs = batch.Outputs
         }
 
         hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 
     hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
     hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
-    hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
+    hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
 
     hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
     return m.Output.Forward(ctx, hiddenStates), nil
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
         }
 
         var outputs ml.Tensor
-        if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
-            outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+        if i == len(m.TransformerBlocks)-1 {
+            outputs = batch.Outputs
         }
 
         hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
         var outputs ml.Tensor
         if i == len(m.Layers)-1 {
-            outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+            outputs = batch.Outputs
         }
 
         hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-    outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
-    return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+    return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-    outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
-    return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+    return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     }
 
     positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-    outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
     // TODO: attention mask, cross attention mask
-    return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+    return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
         var outputs ml.Tensor
         if i == len(m.Layers)-1 {
-            outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+            outputs = batch.Outputs
         }
 
         hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
     positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-    outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
-    return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
+    return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
 }
 
 func init() {
@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
         var outputs ml.Tensor
         if i == len(m.Layers)-1 {
-            outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+            outputs = batch.Outputs
         }
 
         hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
@@ -467,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 
     // Prepare the seqs and batch, but defer the input token values as we may not be ready yet
     var batchInputs []*input.Input
+    var batchOutputs []int32
    var batch input.Batch
 
    resumeSeq := -1
@@ -549,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
             batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
             batch.Sequences = append(batch.Sequences, seq.cache.Id)
 
-            seq.iBatch = len(batch.Outputs)
-            if i+1 == len(seq.inputs) {
-                batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+            seq.iBatch = len(batchOutputs)
+            if i+1 == len(seq.inputs) || seq.embeddingOnly {
+                batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
             }
             logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
             seq.pendingInputs = append(seq.pendingInputs, inp)
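
A standalone sketch of the index bookkeeping above, with hypothetical names (collectOutputs, seqLens): each sequence notes which row of the gathered logits belongs to it before appending the index of its last token:

    // collectOutputs([]int{3, 2}) returns outputs [2 4] and iBatch [0 1]:
    // the logits rows for sequence 0 and sequence 1 respectively.
    func collectOutputs(seqLens []int) (outputs []int32, iBatch []int) {
        total := 0
        for _, n := range seqLens {
            total += n
            iBatch = append(iBatch, len(outputs))     // this sequence's logits row
            outputs = append(outputs, int32(total-1)) // index of its last token
        }
        return outputs, iBatch
    }
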
@@ -576,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 
     // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
     batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+    batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
     nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
     if err != nil {
         err = fmt.Errorf("failed to build graph: %w", err)
@@ -703,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
         }
 
         // sample a token
-        vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
-        logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+        vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
+        logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
         token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
         if err != nil {
             s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
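
Since Outputs is no longer a Go slice, its length has to come from the tensor: Dim(0) is the number of gathered rows. A worked example with assumed numbers (two output rows, vocabulary of 32000):

    rows := activeBatch.batch.Outputs.Dim(0) // 2 rows were gathered
    vocabSize := len(outputs) / rows         // 64000 / 2 = 32000
    row := iBatches[i]                       // this sequence's row (seq.iBatch)
    logits := outputs[row*vocabSize : (row+1)*vocabSize]
    _ = logits // passed to seq.sampler.Sample in the hunk above
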
@@ -1046,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
         batch.Positions[i] = int32(i)
     }
 
-    batch.Outputs = make([]int32, s.parallel)
-    for i := range batch.Outputs {
-        batch.Outputs[i] = int32(i)
-    }
-
     batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
+    batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
 
     cache := s.model.Config().Cache
     if cache != nil {
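
One subtlety in the last hunk: graph reservation only needs the shape of Outputs, not its values, so the placeholder indices (previously written with make and a loop) disappear entirely. The two construction paths after this commit, side by side:

    // Worst-case reservation: allocate shape only, no data upload.
    batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)

    // Real batch (forwardBatch): upload the concrete indices.
    batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))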