llama4
@@ -105,9 +105,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 	for range aspectRatio.Y {
 		for x := range aspectRatio.X {
-			view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-				projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-				patchesPerChunk)
+			view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 			var separator separator
 			if x < aspectRatio.X-1 {
 				separator.x = true // <|tile_x_separator|>
@@ -120,9 +118,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
 			}
 		}
 
-		view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
-			projectedOutputs.Dim(0), projectedOutputs.Stride(1),
-			patchesPerChunk)
+		view := projectedOutputs.Slice(ctx, 1, offset, offset+patchesPerChunk, 1)
 		multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
 
 	return multimodal, nil
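Both hunks make the same simplification: the old View call addresses the chunk by a raw byte offset (projectedOutputs.Stride(1)*offset) and spells out the resulting shape by hand, while Slice(ctx, 1, offset, offset+patchesPerChunk, 1) states the intent directly: take patchesPerChunk consecutive entries along dimension 1. Below is a minimal sketch of why the two select the same data, written against a plain row-major buffer rather than the real ml.Tensor API (the tensor2D type and its methods are hypothetical stand-ins for illustration):

package main

import "fmt"

// tensor2D is a hypothetical stand-in for an ml.Tensor: a row-major
// buffer of shape (dim0, dim1), so the element stride of dimension 1
// is dim0, analogous to Stride(1) in the diff.
type tensor2D struct {
	data       []float32
	dim0, dim1 int
}

// viewStyle mimics the old View call: start at a raw element offset
// (Stride(1)*offset) and take dim0*n elements.
func (t tensor2D) viewStyle(offset, n int) []float32 {
	start := t.dim0 * offset
	return t.data[start : start+t.dim0*n]
}

// sliceStyle mimics Slice(ctx, 1, start, end, 1): select entries
// start..end (exclusive) along dimension 1 with step 1.
func (t tensor2D) sliceStyle(start, end int) []float32 {
	return t.data[t.dim0*start : t.dim0*end]
}

func main() {
	t := tensor2D{
		dim0: 3, dim1: 4,
		data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
	}
	offset, patchesPerChunk := 1, 2
	fmt.Println(t.viewStyle(offset, patchesPerChunk))         // [3 4 5 6 7 8]
	fmt.Println(t.sliceStyle(offset, offset+patchesPerChunk)) // [3 4 5 6 7 8]
}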
@@ -37,27 +37,23 @@ type VisionAttention struct
 func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
 	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
 
-	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
-
 	// t1 = t[..., 0::2]
-	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
+	t1 := t.Slice(ctx, 0, 0, t.Dim(0), 2)
 
 	// t2 = t[..., 1::2]
-	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
-	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
+	t2 := t.Slice(ctx, 0, 1, t.Dim(0), 2)
 
 	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
 	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
-	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
-	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
+	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, -1)
+	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3)
+	cosOut = cosOut.Contiguous(ctx, width, height, channels, tiles)
 
 	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
-	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
-	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
-	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
+	sinOut := t2.Scale(ctx, -1).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
+	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, -1)
+	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3)
+	sinOut = sinOut.Contiguous(ctx, width, height, channels, tiles)
 
 	return cosOut.Add(ctx, sinOut)
 }
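Taken together, the Slice/Concat/Reshape/Permute sequence above is the standard interleaved rotary embedding: the even elements t1 = t[0::2] and odd elements t2 = t[1::2] of each pair are rotated as a complex number, so adding cosOut and sinOut leaves t1*cos - t2*sin in the even slots and t2*cos + t1*sin in the odd slots. A scalar sketch of that identity in plain Go, independent of the ml tensor API (rotaryInterleaved is a hypothetical helper for illustration only):

package main

import (
	"fmt"
	"math"
)

// rotaryInterleaved rotates consecutive (even, odd) pairs of t by theta,
// the same arithmetic the tensor code expresses with Slice/Concat/Permute:
//
//	out[2i]   = t[2i]*cos - t[2i+1]*sin // even slot of cosOut + sinOut
//	out[2i+1] = t[2i+1]*cos + t[2i]*sin // odd slot of cosOut + sinOut
func rotaryInterleaved(t []float64, theta float64) []float64 {
	cos, sin := math.Cos(theta), math.Sin(theta)
	out := make([]float64, len(t))
	for i := 0; i+1 < len(t); i += 2 {
		t1, t2 := t[i], t[i+1] // t1 = t[0::2], t2 = t[1::2]
		out[i] = t1*cos - t2*sin
		out[i+1] = t2*cos + t1*sin
	}
	return out
}

func main() {
	// Rotating the pair (1, 0) by 90 degrees gives (~0, 1),
	// i.e. multiplication by the unit complex number e^(i*theta).
	fmt.Println(rotaryInterleaved([]float64{1, 0}, math.Pi/2))
}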