From 8934324b72c962fdff69ecee9c13d5d69580a1b7 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 7 Mar 2025 17:38:36 -0800
Subject: [PATCH] use fast attention

---
 ml/backend/ggml/ggml.go             |  4 ++--
 model/models/gemma3/model.go        |  4 ++--
 model/models/gemma3/model_vision.go | 14 ++++----------
 3 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 487ea524a..9ff4446b3 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -958,9 +958,9 @@ func (t *Tensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) m
 	var tt *C.struct_ggml_tensor
 	switch len(strides) {
 	case 0:
-		tt = C.ggml_set_1d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
+		tt = C.ggml_set_1d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset))
 	case 1:
-		tt = C.ggml_set_2d_inplace(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
+		tt = C.ggml_set_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.size_t(offset), C.size_t(strides[0]))
 	default:
 		panic("unsupported number of dimensions")
 	}
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 2bb5232ba..f9beccc24 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -138,8 +138,8 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 			{Token: 255999}, // """
 		}
 
-		//
-		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 262144}}, 256)...)
+		// pad inputs with placeholders for image embeddings
+		imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...)
 
 		//
 		imageInputs = append(imageInputs, input.Input{Token: 256000})
diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go
index ee6e3b6f0..49f9a5d29 100644
--- a/model/models/gemma3/model_vision.go
+++ b/model/models/gemma3/model_vision.go
@@ -24,17 +24,11 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
 	key := sa.Key.Forward(ctx, hiddenState)
 	value := sa.Value.Forward(ctx, hiddenState)
 
-	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
-	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
-	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
+	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
+	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
 
-	scores := key.Mulmat(ctx, query)
-	scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
-	scores = scores.Softmax(ctx)
-
-	attention := value.Mulmat(ctx, scores)
-	attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
-	attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+	attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
 	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
 
 	hiddenState = sa.Output.Forward(ctx, attention)
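
Note on the model_vision.go hunk: the removed lines built scaled dot-product attention by hand (Kᵀ·Q, scaled by 1/sqrt(headDim), softmax, then multiplied into V, plus the surrounding Permute/Contiguous reshuffles), while the replacement hands the reshaped query/key/value to nn.Attention with the same scale and what appears to be a nil KV cache, presumably so the backend can use a fused "fast" attention kernel where one is available, as the subject line suggests. The sketch below is only a minimal, self-contained illustration of that math on plain float64 slices for a single head; it is not the ggml-backed code, and the attention helper here is hypothetical, written just to show what the removed Mulmat/Scale/Softmax/Mulmat sequence computed.

// Illustrative sketch only: single-head scaled dot-product attention on
// plain float64 slices. Not the ollama/ggml implementation.
package main

import (
	"fmt"
	"math"
)

// attention computes softmax(Q·Kᵀ / sqrt(headDim)) · V for one head.
// q, k, v are [seqLen][headDim] matrices.
func attention(q, k, v [][]float64, headDim int) [][]float64 {
	scale := 1.0 / math.Sqrt(float64(headDim))
	out := make([][]float64, len(q))
	for i := range q {
		// scores[j] = <q_i, k_j> / sqrt(headDim)
		scores := make([]float64, len(k))
		for j := range k {
			for d := 0; d < headDim; d++ {
				scores[j] += q[i][d] * k[j][d]
			}
			scores[j] *= scale
		}
		// softmax over the scores (subtract the max for stability)
		maxScore := scores[0]
		for _, s := range scores {
			if s > maxScore {
				maxScore = s
			}
		}
		var sum float64
		for j, s := range scores {
			scores[j] = math.Exp(s - maxScore)
			sum += scores[j]
		}
		for j := range scores {
			scores[j] /= sum
		}
		// out_i = sum_j scores[j] * v_j
		out[i] = make([]float64, headDim)
		for j := range v {
			for d := 0; d < headDim; d++ {
				out[i][d] += scores[j] * v[j][d]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	fmt.Println(attention(q, k, v, 2))
}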