This commit is contained in:
Michael Yang
2025-11-03 12:54:47 -08:00
parent eb3e1de5c5
commit 702d5c71c9
2 changed files with 11 additions and 19 deletions

View File

@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
for range aspectRatio.Y {
for x := range aspectRatio.X {
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
var separator separator
if x < aspectRatio.X-1 {
separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
}
}
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
return multimodal, nil

View File

@@ -37,27 +37,23 @@ type VisionAttention struct {
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
// t1 = t[..., 0::2]
t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)
// t2 = t[..., 1::2]
t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)
// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)
// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)
return cosOut.Add(ctx, sinOut)
}