mirror of
https://github.com/ollama/ollama.git
synced 2025-03-18 05:41:43 +01:00
use non-causal mask only for image positions
This commit is contained in:
parent
9d2a20a763
commit
e95278932b
@ -21,9 +21,10 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
|
||||
type Causal struct {
|
||||
DType ml.DType
|
||||
Capacity int32
|
||||
causal bool
|
||||
windowSize int32
|
||||
|
||||
opts CausalOptions
|
||||
|
||||
// config controls mostly backend-specific optimizations
|
||||
config *ml.CacheConfig
|
||||
|
||||
@ -79,7 +80,6 @@ type cellRange struct {
|
||||
|
||||
func NewCausalCache(shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
causal: true,
|
||||
windowSize: math.MaxInt32,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
@ -90,7 +90,6 @@ func NewCausalCache(shift shiftFn) *Causal {
|
||||
|
||||
func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
return &Causal{
|
||||
causal: true,
|
||||
windowSize: windowSize,
|
||||
shiftFn: shift,
|
||||
ctxs: make(map[int]ml.Context),
|
||||
@ -235,9 +234,10 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
enabled := !slices.Contains(c.opts.Except, c.curPositions[i])
|
||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
|
||||
(c.causal && c.cells[j].pos > c.curPositions[i]) ||
|
||||
(enabled && c.cells[j].pos > c.curPositions[i]) ||
|
||||
c.cells[j].pos < c.curPositions[i]-c.windowSize {
|
||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
}
|
||||
@ -404,15 +404,19 @@ func (c *Causal) SetLayer(layer int) {
|
||||
c.curLayer = layer
|
||||
}
|
||||
|
||||
type CausalOptions struct {
|
||||
// Enabled controls whether the causal mask is generated for a particular position.
|
||||
Except []int32
|
||||
}
|
||||
|
||||
// SetCausal enables or disables causal mask generation for subsequent calls to Get.
|
||||
// This state carries over to future forward passes. The default value is true.
|
||||
//
|
||||
// ctx may be set to nil if this is called from outside of a forward pass, for
|
||||
// example, when initializing the cache.
|
||||
func (c *Causal) SetCausal(ctx ml.Context, causal bool) {
|
||||
if c.causal != causal {
|
||||
c.causal = causal
|
||||
|
||||
func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||
if !slices.Equal(c.opts.Except, opts.Except) {
|
||||
c.opts = opts
|
||||
if ctx != nil {
|
||||
var err error
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
|
@ -183,8 +183,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
|
||||
|
||||
if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
|
||||
causal.SetCausal(ctx, false)
|
||||
defer causal.SetCausal(ctx, true)
|
||||
except := make([]int32, visionOutputs.Dim(1))
|
||||
for i := 0; i < visionOutputs.Dim(1); i++ {
|
||||
except[i] = int32(offset + i)
|
||||
}
|
||||
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user