From 702d5c71c96d2f633ec3773f954154598ff09649 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Mon, 3 Nov 2025 12:54:47 -0800
Subject: [PATCH] llama4

---
 model/models/llama4/model.go        |  8 ++------
 model/models/llama4/model_vision.go | 22 +++++++++-------------
 2 files changed, 11 insertions(+), 19 deletions(-)

diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 5eeac07c2b..4a22bc4bb3 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 
 	for range aspectRatio.Y {
 		for x := range aspectRatio.X {
-			view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-				projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-				patchesPerChunk)
+			view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 			var separator separator
 			if x < aspectRatio.X-1 {
 				separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 		}
 	}
 
-	view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-		projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-		patchesPerChunk)
+	view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 	multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
 
 	return multimodal, nil
diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go
index 1aa50aec46..ff6b7fcf2a 100644
--- a/model/models/llama4/model_vision.go
+++ b/model/models/llama4/model_vision.go
@@ -37,27 +37,23 @@ type VisionAttention struct {
 
 func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
 	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
-	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
-
 	// t1 = t[..., 0::2]
-	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
+	t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)
 
 	// t2 = t[..., 1::2]
-	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
+	t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)
 
 	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
 	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
-	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
-	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
+	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
+	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
+	cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)
 
 	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
-	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
-	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
-	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
+	sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
+	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
+	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
+	sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)
 
 	return cosOut.Add(ctx, sinOut)
 }