mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 19:18:06 +01:00
bert
This commit is contained in:
@@ -29,7 +29,7 @@ type Model struct {
|
|||||||
// Forward implements model.Model.
|
// Forward implements model.Model.
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
|
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1))
|
||||||
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions))))
|
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions))))
|
||||||
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
|
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user