model: Update encoder cache to use multimodal input processing handler

The encoder cache needs to know the position of images in the input
stream so that it knows when to delete them. Previously images didn't
have a position, so we implied one by breaking batches before an
image and then assuming the image was in the first position. However,
multimodal objects are now given explicit positions in the input
stream, so we can use that instead.

Breaking batches was also a way to simulate a cross attention mask
for mllama. However, given that it only supports a single sequence
and a single image, this mask doesn't serve any real purpose.
Removing the batch break does not appear to affect the quality of
the output.

Most of this is simply moving the input data structures to a new
package to avoid import cycles.
This commit is contained in:
Jesse Gross
2025-03-08 15:45:31 -08:00
committed by Jesse Gross
parent 4614fafae0
commit a1cda80bcb
13 changed files with 157 additions and 160 deletions

View File

@ -9,6 +9,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
@ -137,7 +138,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err

View File

@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
@ -101,8 +102,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
var images []model.Input
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
var images []input.Input
fnvHash := fnv.New64a()
for i := range inputs {
@ -125,15 +126,15 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Inpu
}
}
inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if opts.Multimodal != nil {
crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
if len(opts.Multimodal) > 0 {
crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
}
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))