mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 17:17:56 +01:00
use slice/chunks
This commit is contained in:
@@ -32,10 +32,9 @@ func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
case TypeCLS:
|
||||
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
||||
return hiddenStates.Slice(ctx, 1, 0, 1, 1)
|
||||
case TypeLast:
|
||||
hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
|
||||
return hiddenStates
|
||||
return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
|
||||
default:
|
||||
panic("unknown pooling type")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user