mirror of https://github.com/ollama/ollama.git
Disable causal attention based on batch index
Currently we key these exceptions on positions, which are relative to a sequence and may not be unique within a batch. Batch indices are unique, so use those instead.
commit a8e83a7654
parent 475005504e
committed by Michael Yang
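A small self-contained illustration of the pitfall (hypothetical data, not from this commit): two sequences batched together can share a position, so an exception keyed on position can match tokens in more than one sequence, while a batch index names exactly one token.

package main

import "fmt"

func main() {
	// Hypothetical batch mixing two sequences: positions repeat across
	// sequences, but batch indices are unique.
	positions := []int32{3, 4, 5, 5, 6} // sequence 0 at positions 3..5, sequence 1 at 5..6
	sequences := []int{0, 0, 0, 1, 1}

	// An exception keyed on position 5 matches a token in BOTH sequences.
	for i, p := range positions {
		if p == 5 {
			fmt.Printf("position 5 -> batch index %d (sequence %d)\n", i, sequences[i])
		}
	}

	// An exception keyed on batch index 3 matches exactly one token.
	fmt.Printf("batch index 3 -> position %d (sequence %d)\n", positions[3], sequences[3])
}

The hunks below first re-key the exception list in the causal KV cache, then update the gemma3 text model to supply batch indices.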
@@ -144,6 +144,7 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 	c.curBatchSize = len(opts.Positions)
 	c.curSequences = opts.Sequences
 	c.curPositions = opts.Positions
+	c.opts.Except = nil
 
 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -234,7 +235,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
-		enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
+		enabled := !slices.Contains(c.opts.Except, i)
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
 				(enabled && c.cells[j].pos > c.curPositions[i]) ||
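To make the new rule concrete, here is a minimal standalone sketch of the per-token check, reduced to a single sequence (maskRow is a hypothetical helper, not part of buildMask):

package main

import (
	"fmt"
	"math"
	"slices"
)

// maskRow sketches the per-token rule from buildMask above: a token skips
// the causal check when its batch index i, not its position, is in except.
func maskRow(except []int, i int, curPos int32, cellPos []int32) []float32 {
	enabled := !slices.Contains(except, i)
	row := make([]float32, len(cellPos))
	for j, p := range cellPos {
		if enabled && p > curPos {
			row[j] = float32(math.Inf(-1)) // block attention to future cells
		}
	}
	return row
}

func main() {
	cells := []int32{3, 4, 5, 6} // cached positions
	fmt.Println(maskRow(nil, 1, 4, cells))      // causal: [0 0 -Inf -Inf]
	fmt.Println(maskRow([]int{1}, 1, 4, cells)) // excepted: [0 0 0 0]
}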
@@ -405,15 +406,12 @@ func (c *Causal) SetLayer(layer int) {
 }
 
 type CausalOptions struct {
-	// Enabled controls whether the causal mask is generated for a particular position.
-	Except []int32
+	// Except lists the batch indices for which the causal mask is not generated.
+	Except []int
 }
 
-// SetCausal enables or disables causal mask generation for subsequent calls to Get.
-// This state carries over to future forward passes. The default value is true.
-//
-// ctx may be set to nil if this is called from outside of a forward pass, for
-// example, when initializing the cache.
+// SetCausal disables causal mask generation for a particular range of indices in
+// the current batch for subsequent calls to Get. The state resets for the next forward pass.
 func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
 	if !slices.Equal(c.opts.Except, opts.Except) {
 		c.opts = opts
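Taken together with the reset added to StartForward, SetCausal is now per-batch state. A hedged usage sketch (the caller and function name are assumptions, not code from this diff):

// Hypothetical caller inside ollama (imports assumed to be
// "github.com/ollama/ollama/kvcache" and "github.com/ollama/ollama/ml";
// this function is not part of the diff). It marks the batch indices of
// bidirectional tokens, such as image embeddings, before attention runs.
func applyBidirectionalMask(ctx ml.Context, causal *kvcache.Causal, imageIndices []int) {
	causal.SetCausal(ctx, kvcache.CausalOptions{Except: imageIndices})
	// The relaxation lasts only for the current forward pass: the next
	// StartForward resets c.opts.Except to nil (first hunk above).
}

The remaining hunks update the gemma3 text model to supply those batch indices instead of positions.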
@@ -173,10 +173,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
+func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex) []int {
 	var embedding ml.Tensor
 	var src, dst, length int
-	var except []int32
+	var except []int
 
 	for _, image := range multimodal {
 		imageToken := image.Multimodal.(imageToken)
@@ -204,7 +204,7 @@ func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []inpu
 			length = 1
 		}
 
-		except = append(except, positions[imageDst])
+		except = append(except, imageDst)
 	}
 
 	if embedding != nil {
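For concreteness, with illustrative numbers only (this snippet is not from the commit):

// An image encoded as 256 embedding vectors written immediately after four
// text tokens occupies batch indices 4..259, and the loop above collects
// exactly those indices regardless of each token's per-sequence position.
var except []int
for imageDst := 4; imageDst < 4+256; imageDst++ {
	except = append(except, imageDst)
}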
@@ -219,7 +219,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
-	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
+	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)
 
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
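End to end, the model-side flow after this change looks roughly like this (a sketch under assumptions about the surrounding gemma3 code, which the diff shows only partially):

// Hedged condensation of TextModel.Forward after this change. The diff
// shows only fragments of the layer loop, so the cache wiring and the
// layer.Forward arguments here are assumptions.
except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal)

for i, layer := range m.Layers {
	cache.SetLayer(i)
	causal.SetCausal(ctx, kvcache.CausalOptions{Except: except}) // assumed reachable causal cache
	hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, cache, m.TextOptions)
}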