diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go index 2d78710f79..f2dd1deb4d 100644 --- a/model/models/bert/embed.go +++ b/model/models/bert/embed.go @@ -29,7 +29,7 @@ type Model struct { // Forward implements model.Model. func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) - hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize)) + hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1)) hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions)))) hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)