use slice/chunks

This commit is contained in:
Michael Yang
2025-10-30 20:20:57 -07:00
parent eadae522dc
commit 4e3ad2c2ed

View File

@@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS: case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) return hiddenStates.Slice(ctx, 1, 0, 1, 1)
case TypeLast: case TypeLast:
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
return hiddenStates
default: default:
panic("unknown pooling type") panic("unknown pooling type")
} }