From 4e3ad2c2edf10fa2135b91c396018a88a55f9910 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 30 Oct 2025 20:20:57 -0700 Subject: [PATCH] use slice/chunks --- ml/nn/pooling/pooling.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go index 63b63b3af3..47af874635 100644 --- a/ml/nn/pooling/pooling.go +++ b/ml/nn/pooling/pooling.go @@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) case TypeCLS: - return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) + return hiddenStates.Slice(ctx, 1, 0, 1, 1) case TypeLast: - hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) - return hiddenStates + return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1) default: panic("unknown pooling type") }