Jesse Gross a1cda80bcb model: Update encoder cache to use multimodal input processing handler
The encoder cache needs to know the position of images in the input
stream so that it knows when to delete them. Previously, images didn't
have a position, so we implied one by breaking batches before an
image and then assuming the image was in the first position. However,
multimodal objects are now given explicit positions in the input
stream, so we can use those positions instead.

Breaking batches was also a way to simulate a cross-attention mask
for mllama. However, given that mllama only supports a single sequence
and a single image, this mask doesn't serve any real purpose.
Removing the batch break does not appear to affect the quality of
the output.

Most of this is simply moving the input data structures to a new
package to avoid import cycles.
2025-03-09 17:05:26 -07:00

177 lines
4.2 KiB

package model
import (
fs "github.com/ollama/ollama/fs/ggml"
// TestParseTags checks that ParseTags turns a gguf struct-tag value into a
// Tag. The leading element ("output") becomes Name. The
// "output,alt:token_embd" case is also expected to populate Alternate (its
// contents are not visible in this view; presumably "token_embd").
func TestParseTags(t *testing.T) {
// Each case pairs a raw tag value with the Tag that ParseTags should return.
cases := []struct {
value string
want Tag
value: "output",
want: Tag{
Name: "output",
value: "output,alt:token_embd",
want: Tag{
Name: "output",
Alternate: []string{
// Run every case as a subtest named after its input and report any
// mismatch as a cmp diff.
for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) {
got := ParseTags(tt.value)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
// fakeBackend is a test double passed to populateFields via Base. Its Get
// method returns a fakeTensor only for tensor names listed in names.
// NOTE(review): this view does not show how fakeBackend satisfies the rest of
// the backend interface (e.g. through an embedded field); confirm against the
// full file.
type fakeBackend struct {
names []string
// fakeTensor stands in for a real tensor and records only the name it was
// requested under, so tests can compare which tensor name each field received.
// NOTE(review): it is returned as an ml.Tensor by fakeBackend.Get, but the way
// it satisfies that interface (presumably an embedded field) is not visible
// here.
type fakeTensor struct {
Name string
// Get returns a fakeTensor carrying name when name is one of m.names, and
// nil otherwise, which lets tests simulate tensors that are absent from the
// model file.
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {
return &fakeTensor{Name: name}
// A literal nil (not a typed nil pointer), so callers see a nil interface.
return nil
// TestPopulateFields checks that populateFields fills the nn layer fields of
// a model struct from the backend. Tensor names are built from the gguf
// struct tags: "<tag>.weight" for top-level fields (e.g. "output_norm.weight")
// and "<parent>.<index>.<tag>.weight" for elements of an array field such as
// Layers (e.g. "blk.0.attn_q.weight").
func TestPopulateFields(t *testing.T) {
// fakeLayer mimics one attention block; each field's gguf tag is the tensor
// name suffix within its "blk.<i>" prefix.
type fakeLayer struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_o"`
type fakeModel struct {
Input *nn.Embedding `gguf:"input"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output"`
Layers [2]fakeLayer `gguf:"blk"`
var m fakeModel
v := reflect.ValueOf(&m)
// populateFields returns a reflect.Value; Set writes it back into m through
// the addressable element of the pointer.
v.Elem().Set(populateFields(Base{b: &fakeBackend{
names: []string{
}}, v.Elem()))
// NOTE(review): no Output (attn_o) entries are visible in the expected
// layers below. Presumably the blk.*.attn_o.weight names are not registered
// with the fake backend, so those fields stay nil. The names list is not
// visible in this view; confirm.
if diff := cmp.Diff(fakeModel{
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
OutputNorm: &nn.RMSNorm{Weight: &fakeTensor{Name: "output_norm.weight"}},
Output: &nn.Linear{Weight: &fakeTensor{Name: "output.weight"}},
Layers: [2]fakeLayer{
Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_q.weight"}},
Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_k.weight"}},
Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_v.weight"}},
Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_q.weight"}},
Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_k.weight"}},
Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_v.weight"}},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
// TestPopulateFieldsAlternateName checks that a field whose gguf tag carries
// an ",alt:" alternate name can be populated from the alternate tensor.
// Output is tagged "output,alt:input" and is expected to receive the tensor
// named "input.weight".
func TestPopulateFieldsAlternateName(t *testing.T) {
type fakeModel struct {
Input *nn.Embedding `gguf:"input"`
Output *nn.Linear `gguf:"output,alt:input"`
m := fakeModel{}
v := reflect.ValueOf(&m)
// NOTE(review): the backend's names list is not visible in this view.
// Presumably it omits "output.weight", so the lookup must fall back to the
// alternate "input"; confirm against the full file.
v.Elem().Set(populateFields(Base{b: &fakeBackend{
names: []string{
}}, v.Elem()))
// Both fields are expected to carry the tensor named "input.weight".
if diff := cmp.Diff(fakeModel{
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
Output: &nn.Linear{Weight: &fakeTensor{Name: "input.weight"}},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
// TestGetTextProcessor checks two error paths of getTextProcessor:
//   - a KV with no registered architecture fails with an error mentioning
//     "unsupported model architecture";
//   - a registered architecture whose model is not a TextProcessor fails with
//     an error mentioning "not a TextProcessor".
// In both cases the returned processor must be nil.
func TestGetTextProcessor(t *testing.T) {
// An empty KV has no "general.architecture" entry.
tp, err := getTextProcessor(fs.KV{})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
// Register a throwaway architecture whose constructor returns a model that
// does not implement TextProcessor.
// NOTE(review): this writes to the package-level models registry, and no
// cleanup is visible in this view. Consider
// t.Cleanup(func() { delete(models, "dummy") }) so the entry does not leak
// into other tests.
models["dummy"] = func(ml.Config) (Model, error) {
return notTextProcessorModel{}, nil
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
// notTextProcessorModel is the model returned for the "dummy" architecture in
// TestGetTextProcessor. It is usable as a Model but is expected not to satisfy
// TextProcessor, which drives getTextProcessor's "not a TextProcessor" error
// path.
type notTextProcessorModel struct{}
// Forward is presumably required by the Model interface; its body is not
// visible in this view.
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
// Backend is presumably required by the Model interface; its body is not
// visible in this view.
func (notTextProcessorModel) Backend() ml.Backend {
func (notTextProcessorModel) Config() config {