This commit is contained in:
Michael Yang
2025-01-22 10:51:42 -08:00
parent 49df03da9a
commit 44b39749d5
56 changed files with 3690 additions and 475 deletions

155
model/llama/model.go Normal file
View File

@@ -0,0 +1,155 @@
package llama
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type Options struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int64
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c ml.Config) (model.Model, error) {
return &Model{
BytePairEncoding: model.BytePairEncoding{
Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
Vocabulary: &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: c.Uint("tokenizer.ggml.bos_token_id"),
EOS: c.Uint("tokenizer.ggml.eos_token_id"),
},
},
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int64(c.Uint("embedding_length")),
numHeads: int64(c.Uint("attention.head_count")),
numKVHeads: int64(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k, v = cache.Put(ctx, k, v, cache.Options)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.Output.Forward(ctx, hiddenState)
outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("llama", New)
}

90
model/mllama/model.go Normal file
View File

@@ -0,0 +1,90 @@
package mllama
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type Model struct {
model.Base
*VisionModel `gguf:"v,vision"`
*TextModel
Projector *nn.Linear `gguf:"mm.0"`
ImageProcessor
TextProcessor
}
func New(c ml.Config) (model.Model, error) {
return &Model{
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextProcessor: newTextProcessor(c),
TextModel: newTextModel(c),
}, nil
}
func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if opts.Images != nil {
f32s, aspectRatioID, err := m.ImageProcessor.ProcessImage(opts.Images[0])
if err != nil {
return nil, err
}
pixelValues, err := ctx.FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
m.ImageProcessor.maxNumTiles,
)
if err != nil {
return nil, err
}
aspectRatio, err := ctx.FromIntSlice([]int32{int32(aspectRatioID)}, 1)
if err != nil {
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
crossAttentionStates = m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates)
}
inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs()))
if err != nil {
return nil, err
}
positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions()))
if err != nil {
return nil, err
}
// TODO: attention mask, cross attention mask
hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache)
outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1)
if err != nil {
return nil, err
}
return hiddenState.Rows(ctx, outputs), nil
}
func init() {
model.Register("mllama", New)
}

225
model/mllama/model_text.go Normal file
View File

@@ -0,0 +1,225 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
)
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key, value = cache.Put(ctx, key, value, cache.Options)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
if mask != nil {
scores = scores.Add(ctx, mask)
}
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type TextSelfAttentionDecoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *TextSelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
}
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type TextCrossAttention struct {
QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
Query *nn.Linear `gguf:"cross_attn_q_proj"`
KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
Key *nn.Linear `gguf:"cross_attn_k_proj"`
Value *nn.Linear `gguf:"cross_attn_v_proj"`
Output *nn.Linear `gguf:"cross_attn_o_proj"`
}
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
query := ca.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
key := ca.Key.Forward(ctx, crossAttentionStates)
key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
key = ca.KeyNorm.Forward(ctx, key, opts.eps)
value := ca.Value.Forward(ctx, crossAttentionStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
// TODO cache key, value
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)
}
type TextCrossAttentionDecoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
CrossAttention *TextCrossAttention
AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *TextMLP
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
}
func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
return hiddenState.Add(ctx, residual)
}
type TextDecoderLayer interface {
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor
}
type TextDecoder struct {
Layers []TextDecoderLayer
}
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers {
if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts)
}
}
return hiddenState
}
type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int64
eps, ropeBase, ropeScale float32
ropeDim uint32
crossAttentionLayers []uint32
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Transformer *TextDecoder `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output"`
*TextModelOptions
}
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func newTextModel(c ml.Config) *TextModel {
var decoderLayers []TextDecoderLayer
for i := range c.Uint("block_count") {
var textDecoderLayer TextDecoderLayer
if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
textDecoderLayer = &TextCrossAttentionDecoderLayer{}
} else {
textDecoderLayer = &TextSelfAttentionDecoderLayer{}
}
decoderLayers = append(decoderLayers, textDecoderLayer)
}
return &TextModel{
Transformer: &TextDecoder{Layers: decoderLayers},
TextModelOptions: &TextModelOptions{
hiddenSize: int64(c.Uint("embedding_length")),
numHeads: int64(c.Uint("attention.head_count")),
numKVHeads: int64(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
},
}
}

View File

@@ -0,0 +1,234 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int64 = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
Gate ml.Tensor `gguf:"attn_gate"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)
if sa.Gate != nil {
hiddenState = hiddenState.Mul(ctx, sa.Gate)
}
return hiddenState
}
type VisionMLP struct {
Down *nn.Linear `gguf:"ffn_down"`
Up *nn.Linear `gguf:"ffn_up"`
Gate ml.Tensor `gguf:"ffn_gate"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.Down.Forward(ctx, hiddenState).GELU(ctx)
hiddenState = mlp.Up.Forward(ctx, hiddenState)
if mlp.Gate != nil {
hiddenState = hiddenState.Mul(ctx, mlp.Gate)
}
return hiddenState
}
type VisionEncoderLayer struct {
AttentionNorm *nn.LayerNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
MLPNorm *nn.LayerNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionEncoder struct {
Layers []VisionEncoderLayer
}
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
var intermediateHiddenStates []ml.Tensor
for i, layer := range e.Layers {
if slices.Contains(intermediateLayersIndices, uint32(i)) {
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int64{1}, hiddenState.Shape()...)...))
}
hiddenState = layer.Forward(ctx, hiddenState, opts)
}
return hiddenState, intermediateHiddenStates
}
type PrecomputedAspectRatioEmbedding struct {
Embedding *nn.Embedding
Gate ml.Tensor `gguf:"gate"`
}
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
if e.Gate != nil {
embeddings = embeddings.Mul(ctx, e.Gate)
}
return hiddenState.Add(ctx, embeddings)
}
type PrecomputedPositionEmbedding struct {
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PositionEmbeddingGate ml.Tensor `gguf:"position_embd.gate"`
TilePositionEmbedding *nn.Embedding `gguf:"tile_position_embd"`
TilePositionEmbeddingGate ml.Tensor `gguf:"tile_position_embd.gate"`
}
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs, aspectRatioIDs ml.Tensor, numPositions int64, opts *VisionModelOptions) ml.Tensor {
positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
if e.PositionEmbeddingGate != nil {
positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
}
hiddenState = hiddenState.Add(ctx, positionEmbedding)
tilePositionEmbedding := e.TilePositionEmbedding.Forward(ctx, aspectRatioIDs)
tilePositionEmbedding = tilePositionEmbedding.Reshape(ctx, opts.hiddenSize, numPositions, opts.numTiles)
if e.TilePositionEmbeddingGate != nil {
tilePositionEmbedding = tilePositionEmbedding.Mul(ctx, e.TilePositionEmbeddingGate)
}
return hiddenState.Add(ctx, tilePositionEmbedding)
}
type VisionModelOptions struct {
hiddenSize, numHeads, numTiles int64
imageSize, patchSize int
eps float32
intermediateLayersIndices []uint32
}
type VisionModel struct {
PatchEmbeddings *nn.Conv2D `gguf:"patch_embd"`
PreTilePositionEmbedding *PrecomputedAspectRatioEmbedding `gguf:"pre_tile_position_embd"`
PostTilePositionEmbedding *PrecomputedAspectRatioEmbedding `gguf:"post_tile_position_embd"`
PositionEmbedding *PrecomputedPositionEmbedding
PreLayerNorm *nn.LayerNorm `gguf:"pre_ln"`
PostLayerNorm *nn.LayerNorm `gguf:"post_ln"`
ClassEmbedding ml.Tensor `gguf:"class_embd"`
Transformer *VisionEncoder `gguf:"blk"`
GlobalTransformer *VisionEncoder `gguf:"global.blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRatioIDs ml.Tensor) ml.Tensor {
numPatches := int64((m.imageSize / m.patchSize) * (m.imageSize / m.patchSize))
numPositions := numPatches
if m.ClassEmbedding != nil {
numPositions++
}
hiddenState := m.PatchEmbeddings.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize, m.numTiles)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, int(m.numTiles)-1)...).Concat(ctx, hiddenState, 1)
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)
numPaddingPatches := 8 - (hiddenState.Dim(1)%8)%8
hiddenState = hiddenState.Pad(ctx, 0, numPaddingPatches, 0, 0)
hiddenState = hiddenState.Reshape(ctx, hiddenState.Dim(0), hiddenState.Dim(1)*hiddenState.Dim(2), batchSize)
hiddenState, intermediateHiddenStates := m.Transformer.Forward(ctx, hiddenState, m.intermediateLayersIndices, m.VisionModelOptions)
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenState = m.PostTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, m.numTiles*(numPositions+numPaddingPatches), batchSize)
hiddenState, _ = m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions)
hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...)
hiddenStates = hiddenStates.Reshape(ctx, int64(len(intermediateHiddenStates))*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenState = hiddenState.Unpad(ctx, 0, numPaddingPatches, 0, 0)
return hiddenState.Concat(ctx, hiddenStates, 0)
}
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Transformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},
VisionModelOptions: &VisionModelOptions{
hiddenSize: int64(c.Uint("vision.embedding_length")),
numHeads: int64(c.Uint("vision.attention.head_count")),
numTiles: int64(c.Uint("vision.max_num_tiles")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
eps: c.Float("vision.attention.layer_norm_epsilon"),
intermediateLayersIndices: c.Uints("vision.intermediate_layers_indices"),
},
}
}

View File

@@ -0,0 +1,240 @@
package mllama
import (
"image"
"image/color"
"math"
"slices"
"golang.org/x/image/draw"
"github.com/ollama/ollama/ml"
)
type ImageProcessor struct {
imageSize, numChannels, maxNumTiles int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size")),
numChannels: int(c.Uint("vision.num_channels")),
maxNumTiles: int(c.Uint("vision.max_num_tiles")),
}
}
func (p *ImageProcessor) supportedAspectRatios(maxTiles int) []image.Point {
ratios := []image.Point{}
for w := range maxTiles {
for h := range maxTiles {
if (w+1)*(h+1) <= maxTiles {
ratios = append(ratios, image.Point{w + 1, h + 1})
}
}
}
return ratios
}
func (p *ImageProcessor) clip(a, a_min, a_max int) int {
if a < a_min {
return a_min
} else if a > a_max {
return a_max
}
return a
}
func (p *ImageProcessor) fitToCanvas(imageSize, canvasSize image.Point, tileSize int) image.Point {
targetWidth := p.clip(imageSize.X, tileSize, canvasSize.X)
targetHeight := p.clip(imageSize.Y, tileSize, canvasSize.Y)
scaleWidth := float64(targetWidth) / float64(imageSize.X)
scaleHeight := float64(targetHeight) / float64(imageSize.Y)
var w, h int
if scaleWidth < scaleHeight {
w = targetWidth
h = min(int(math.Floor(float64(imageSize.Y)*scaleWidth)), targetHeight)
} else {
w = min(int(math.Floor(float64(imageSize.X)*scaleHeight)), targetWidth)
h = targetHeight
}
return image.Point{w, h}
}
func (p *ImageProcessor) optimalTiledCanvas(imageSize image.Point, maxImageTiles, tileSize int) image.Point {
possibleTileArrangements := p.supportedAspectRatios(maxImageTiles)
possibleCanvasSizes := []image.Point{}
for _, pta := range possibleTileArrangements {
possibleCanvasSizes = append(possibleCanvasSizes, image.Point{pta.X * tileSize, pta.Y * tileSize})
}
scales := []float64{}
for _, pcs := range possibleCanvasSizes {
scaleHeight := float64(pcs.Y) / float64(imageSize.Y)
scaleWidth := float64(pcs.X) / float64(imageSize.X)
if scaleWidth > scaleHeight {
scales = append(scales, scaleHeight)
} else {
scales = append(scales, scaleWidth)
}
}
var minUpscale float64
var maxDownscale float64
var upscale bool
for _, s := range scales {
if s > 1.0 {
upscale = true
if minUpscale == 0 {
minUpscale = s
} else {
minUpscale = math.Min(minUpscale, s)
}
} else {
maxDownscale = math.Max(maxDownscale, s)
}
}
selectedScale := maxDownscale
if upscale {
selectedScale = minUpscale
}
var selectedCanvas image.Point
for n, pcs := range possibleCanvasSizes {
if scales[n] == selectedScale {
// choose the smallest possible canvas
if selectedCanvas.X == 0 && selectedCanvas.Y == 0 {
selectedCanvas = pcs
} else if pcs.X*pcs.Y < selectedCanvas.X*selectedCanvas.Y {
selectedCanvas = pcs
}
}
}
return selectedCanvas
}
func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point) []image.Image {
b := img.Bounds()
width := b.Max.X - b.Min.X
height := b.Max.Y - b.Min.Y
tileHeight := height / numTilesSize.Y
tileWidth := width / numTilesSize.X
images := []image.Image{}
for h := range numTilesSize.Y {
for w := range numTilesSize.X {
rect := image.Rect(tileWidth*w, tileHeight*h, tileWidth*(w+1), tileHeight*(h+1))
images = append(images, img.(interface {
SubImage(image.Rectangle) image.Image
}).SubImage(rect))
}
}
return images
}
// remove the "alpha" channel by drawing over a prefilled image
//
// remove the "alpha" channel by drawing over a prefilled image
//
//nolint:unused
func (p *ImageProcessor) compositeImage(img image.Image) image.Image {
dst := image.NewRGBA(img.Bounds())
white := color.RGBA{255, 255, 255, 255}
draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return dst
}
func (p *ImageProcessor) resize(img image.Image, outputSize image.Point, maxImageTiles int) (image.Image, image.Point) {
b := img.Bounds()
tileSize := outputSize.Y
canvasSize := p.optimalTiledCanvas(b.Max, maxImageTiles, tileSize)
aspectRatio := image.Point{canvasSize.X / tileSize, canvasSize.Y / tileSize}
newSize := p.fitToCanvas(b.Max, canvasSize, tileSize)
dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
// scaling choices:
// NearestNeighbor fast, blocky output
// ApproxBiLinear fast, medium quality
// BiLinear slow, high quality
// CatmullRom very slow, very high quality
draw.BiLinear.Scale(dst, dst.Rect, img, b, draw.Over, nil)
return dst, aspectRatio
}
func (p *ImageProcessor) pad(img image.Image, outputSize, aspectRatio image.Point) image.Image {
paddedSize := image.Point{
X: outputSize.X * aspectRatio.X,
Y: outputSize.Y * aspectRatio.Y,
}
dst := image.NewRGBA(image.Rect(0, 0, paddedSize.X, paddedSize.Y))
draw.Draw(dst, img.Bounds(), img, image.Point{0, 0}, draw.Over)
return dst
}
func (p *ImageProcessor) pack(img image.Image, aspectRatio image.Point, mean, std [3]float32) []float32 {
subImages := p.splitToTiles(img, aspectRatio)
var pixelVals []float32
for _, subImg := range subImages {
bounds := subImg.Bounds()
var rVals, gVals, bVals []float32
for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
for x := bounds.Min.X; x < bounds.Max.X; x++ {
c := subImg.At(x, y)
r, g, b, _ := c.RGBA()
rVal := float32(r>>8) / 255.0
gVal := float32(g>>8) / 255.0
bVal := float32(b>>8) / 255.0
rVal = (rVal - mean[0]) / std[0]
gVal = (gVal - mean[1]) / std[1]
bVal = (bVal - mean[2]) / std[2]
rVals = append(rVals, rVal)
gVals = append(gVals, gVal)
bVals = append(bVals, bVal)
}
}
pixelVals = append(pixelVals, rVals...)
pixelVals = append(pixelVals, gVals...)
pixelVals = append(pixelVals, bVals...)
}
return pixelVals
}
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, int, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
// clip values
mean := [3]float32{0.48145466, 0.4578275, 0.40821073}
std := [3]float32{0.26862954, 0.26130258, 0.27577711}
newImage, aspectRatio := p.resize(img, outputSize, p.maxNumTiles)
newImage = p.pad(newImage, outputSize, aspectRatio)
data := p.pack(newImage, aspectRatio, mean, std)
aspectRatioIndex := slices.Index(p.supportedAspectRatios(p.maxNumTiles), aspectRatio) + 1
return data, aspectRatioIndex, nil
}

View File

@@ -0,0 +1,25 @@
package mllama
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
)
type TextProcessor struct {
model.BytePairEncoding
}
func newTextProcessor(c ml.Config) TextProcessor {
return TextProcessor{
BytePairEncoding: model.BytePairEncoding{
Pretokenizer: c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
Vocabulary: &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: c.Uint("tokenizer.ggml.bos_token_id"),
EOS: c.Uint("tokenizer.ggml.eos_token_id"),
},
},
}
}

View File

@@ -0,0 +1,87 @@
package mllama
import (
"encoding/json"
"errors"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/model"
)
func TestProcessText(t *testing.T) {
ours, err := model.New(filepath.Join("testdata", "model.bin"))
if errors.Is(err, os.ErrNotExist) {
t.Skip("no model.bin")
} else if err != nil {
t.Fatal(err)
}
t.Run("decode", func(t *testing.T) {
f, err := os.Open(filepath.Join("testdata", "theirs.json"))
if errors.Is(err, os.ErrNotExist) {
t.Skip("no theirs.json")
} else if err != nil {
t.Fatal(err)
}
defer f.Close()
var theirs [][]byte
if err := json.NewDecoder(f).Decode(&theirs); err != nil {
t.Fatal(err)
}
for id := range theirs {
ids := []int32{int32(id)}
s, err := ours.(model.TextProcessor).Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(string(theirs[id]), s); diff != "" {
t.Errorf("%d no match (-theirs +ours):\n%s", id, diff)
}
}
})
t.Run("encode", func(t *testing.T) {
f, err := os.Open(filepath.Join("..", "testdata", "inputs.json"))
if errors.Is(err, os.ErrNotExist) {
t.Skip("no inputs.json")
} else if err != nil {
t.Fatal(err)
}
defer f.Close()
var inputs []struct {
Values []byte `json:"base64"`
IDs []int32 `json:"ids"`
}
if err := json.NewDecoder(f).Decode(&inputs); err != nil {
t.Fatal(err)
}
for i, input := range inputs {
if i == 45 {
t.Skip("skip 45")
}
t.Run(strconv.Itoa(i), func(t *testing.T) {
ids, err := ours.(model.TextProcessor).Encode(string(input.Values))
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(input.IDs, ids, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("%s: no match (-theirs +ours):\n%s", input.Values, diff)
}
})
}
})
}

1
model/mllama/testdata/model.bin vendored Symbolic link
View File

@@ -0,0 +1 @@
/Users/michaelyang/git/ollama/library/nltpt/Llama-3.2-11B-Vision-Instruct/merged.gguf

1
model/mllama/testdata/theirs.json vendored Normal file

File diff suppressed because one or more lines are too long

279
model/model.go Normal file
View File

@@ -0,0 +1,279 @@
package model
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"log/slog"
"os"
"reflect"
"strconv"
"strings"
_ "golang.org/x/image/bmp"
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"
"github.com/ollama/ollama/cache"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
)
type Cache struct {
cache.Cache
cache.Options
}
func (c Cache) Sub(i int) Cache {
if c.Cache != nil {
return Cache{
Cache: c.Cache.Sub(i),
Options: c.Options,
}
}
return c
}
func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) {
if c.Cache != nil {
return c.Cache.Put(ctx, key, value, opts)
}
return key, value
}
type Options struct {
inputs []int32
Offset int
Images []image.Image
Cache
}
func (opts Options) Inputs() []int32 {
return opts.inputs[opts.Offset:]
}
func (opts Options) Positions() []int32 {
positions := make([]int32, len(opts.inputs)-opts.Offset)
for i := range positions {
positions[i] = int32(opts.Offset + i)
}
return positions
}
type OptionsFunc func(Model, *Options)
func WithInputIDs(ids []int32) OptionsFunc {
return func(m Model, opts *Options) {
opts.inputs = ids
}
}
func WithOffset(offset int) OptionsFunc {
return func(m Model, opts *Options) {
opts.Offset = offset
opts.Cache.Position = offset
}
}
func WithImage(img image.Image) OptionsFunc {
return func(m Model, opts *Options) {
opts.Images = append(opts.Images, img)
}
}
func WithCache(c cache.Cache) OptionsFunc {
return func(m Model, opts *Options) {
opts.Cache = Cache{
Cache: c,
Options: cache.Options{
Position: opts.Offset,
},
}
}
}
type Base struct {
b ml.Backend
}
func (m *Base) Backend() ml.Backend {
return m.b
}
type Model interface {
Forward(ml.Context, Options) (ml.Tensor, error)
Backend() ml.Backend
}
var models = make(map[string]func(ml.Config) (Model, error))
func Register(name string, f func(ml.Config) (Model, error)) {
if _, ok := models[name]; ok {
panic("model: model already registered")
}
models[name] = f
}
func New(s string) (Model, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(r)
if err != nil {
return nil, err
}
arch := b.Config().Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(b.Config())
if err != nil {
return nil, err
}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(b, v))
return m, nil
}
func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()
if t.Kind() == reflect.Pointer {
t, v = t.Elem(), v.Elem()
}
if t.Kind() == reflect.Struct {
allNil := true
for i := range t.NumField() {
tt := t.Field(i).Type
vv := v.Field(i)
if !vv.CanSet() {
continue
}
// make a copy
tagsCopy := tags
if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
tagsCopy = append(tagsCopy, ParseTags(tag))
}
if tt == reflect.TypeOf((*Base)(nil)).Elem() {
vv.Set(reflect.ValueOf(Base{b: b}))
} else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
var fn func([]Tag) [][]string
fn = func(tags []Tag) (values [][]string) {
if len(tags) < 1 {
return nil
}
values = [][]string{{tags[0].Name}}
for _, alt := range tags[0].Alternate {
values = append(values, []string{alt})
}
for i, value := range values {
for _, rest := range fn(tags[1:]) {
value = append(value, rest...)
}
values[i] = value
}
return values
}
names := fn(tagsCopy)
for _, name := range names {
if tensor := b.Get(strings.Join(name, ".")); tensor != nil {
slog.Debug("found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor))
break
}
}
} else if tt.Kind() == reflect.Pointer {
vvv := vv.Elem()
if vv.IsNil() {
vvv = reflect.New(tt.Elem())
}
if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
vv.Set(f.Addr())
}
} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
for i := range vv.Len() {
vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
}
}
if !canNil(tt) || !vv.IsNil() {
allNil = false
}
}
if allNil {
return reflect.Zero(t)
}
}
return v
}
type Tag struct {
Name string
Alternate []string
}
func ParseTags(s string) (tag Tag) {
parts := strings.Split(s, ",")
if len(parts) > 0 {
tag.Name = parts[0]
for _, part := range parts[1:] {
if value, ok := strings.CutPrefix(part, "alt:"); ok {
tag.Alternate = append(tag.Alternate, value)
}
}
}
return
}
func canNil(t reflect.Type) bool {
return t.Kind() == reflect.Chan ||
t.Kind() == reflect.Func ||
t.Kind() == reflect.Interface ||
t.Kind() == reflect.Map ||
t.Kind() == reflect.Pointer ||
t.Kind() == reflect.Slice
}
func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) {
var opts Options
for _, optsFunc := range optsFuncs {
optsFunc(m, &opts)
}
ctx := m.Backend().NewContext()
t, err := m.Forward(ctx, opts)
if err != nil {
return nil, err
}
defer ctx.Close()
return ctx.Compute(t), nil
}

136
model/model_test.go Normal file
View File

@@ -0,0 +1,136 @@
package model
import (
"reflect"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
)
func TestParseTags(t *testing.T) {
cases := []struct {
value string
want Tag
}{
{
value: "output",
want: Tag{
Name: "output",
},
},
{
value: "output,alt:token_embd",
want: Tag{
Name: "output",
Alternate: []string{
"token_embd",
},
},
},
}
for _, tt := range cases {
t.Run(tt.value, func(t *testing.T) {
got := ParseTags(tt.value)
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff)
}
})
}
}
type fakeBackend struct {
*ggml.Backend
names []string
}
type fakeTensor struct {
*ggml.Tensor
Name string
}
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {
return &fakeTensor{Name: name}
}
return nil
}
func TestPopulateFields(t *testing.T) {
type fakeLayer struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_o"`
}
type fakeModel struct {
Input *nn.Embedding `gguf:"input"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output"`
Layers [2]fakeLayer `gguf:"blk"`
}
var m fakeModel
v := reflect.ValueOf(&m)
v.Elem().Set(populateFields(&fakeBackend{
names: []string{
"input.weight",
"blk.0.attn_q.weight",
"blk.0.attn_k.weight",
"blk.0.attn_v.weight",
"blk.1.attn_q.weight",
"blk.1.attn_k.weight",
"blk.1.attn_v.weight",
"output_norm.weight",
"output.weight",
},
}, v))
if diff := cmp.Diff(fakeModel{
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
OutputNorm: &nn.RMSNorm{Weight: &fakeTensor{Name: "output_norm.weight"}},
Output: &nn.Linear{Weight: &fakeTensor{Name: "output.weight"}},
Layers: [2]fakeLayer{
{
Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_q.weight"}},
Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_k.weight"}},
Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.attn_v.weight"}},
},
{
Query: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_q.weight"}},
Key: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_k.weight"}},
Value: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.attn_v.weight"}},
},
},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}
func TestPopulateFieldsAlternateName(t *testing.T) {
type fakeModel struct {
Input *nn.Embedding `gguf:"input"`
Output *nn.Linear `gguf:"output,alt:input"`
}
m := fakeModel{}
v := reflect.ValueOf(&m)
v.Elem().Set(populateFields(&fakeBackend{
names: []string{
"input.weight",
},
}, v))
if diff := cmp.Diff(fakeModel{
Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
Output: &nn.Linear{Weight: &fakeTensor{Name: "input.weight"}},
}, m); diff != "" {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}

312
model/process_text.go Normal file
View File

@@ -0,0 +1,312 @@
package model
import (
"cmp"
"log/slog"
"strings"
"sync"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
)
type Special int32
const (
SpecialBOS Special = iota
SpecialEOS
)
type TextProcessor interface {
Encode(string) ([]int32, error)
Decode([]int32) (string, error)
Is(uint32, Special) bool
}
type Vocabulary struct {
Values []string
Types []uint32
Scores []uint32
Merges []string
BOS, EOS uint32
specialOnce sync.Once
special []string
valuesOnce sync.Once
values map[string]int32
mergeOnce sync.Once
merge map[string]int32
}
func (v *Vocabulary) Is(id uint32, special Special) bool {
switch special {
case SpecialBOS:
return id == v.BOS
case SpecialEOS:
return id == v.EOS
default:
return false
}
}
func (v *Vocabulary) Encode(s string) int32 {
v.valuesOnce.Do(func() {
v.values = make(map[string]int32, len(v.Values))
for i, value := range v.Values {
v.values[value] = int32(i)
}
})
if id, ok := v.values[s]; ok {
return id
}
return -1
}
func (v *Vocabulary) Decode(id int32) string {
return v.Values[id]
}
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == 3 {
v.special = append(v.special, v.Values[i])
}
}
})
return v.special
}
func (v *Vocabulary) Merge(left, right string) int {
v.mergeOnce.Do(func() {
v.merge = make(map[string]int32, len(v.Merges))
for i, merge := range v.Merges {
v.merge[merge] = int32(i)
}
})
if id, ok := v.merge[left+" "+right]; ok {
return int(id)
}
return -1
}
type BytePairEncoding struct {
Pretokenizer string
*Vocabulary
}
func (bpe BytePairEncoding) split(s string) ([]string, error) {
re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
if err != nil {
return nil, err
}
var matches []string
for m, _ := re.FindStringMatch(s); m != nil; m, _ = re.FindNextMatch(m) {
matches = append(matches, m.String())
}
return matches, nil
}
// fragment is a string fragment and their corresponding token IDs
type fragment struct {
value string
ids []int32
}
// pair is a pair of runes and its rank
type pair struct {
a, b int
rank int
value string
}
type merge struct {
p, n int
runes []rune
}
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range bpe.Vocabulary.SpecialVocabulary() {
// TODO: process special tokens concurrently
id := bpe.Vocabulary.Encode(special)
for i := 0; i < len(fragments); i++ {
frag := fragments[i]
if len(frag.ids) > 0 {
continue
}
var middle []fragment
switch i := strings.Index(frag.value, special); {
case i < 0:
middle = append(middle, frag)
case i > 0:
middle = append(middle, fragment{value: frag.value[:i]})
fallthrough
default:
middle = append(middle, fragment{value: special, ids: []int32{id}})
if rest := frag.value[i+len(special):]; rest != "" {
middle = append(middle, fragment{value: rest})
}
}
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
}
}
var ids []int32
for _, frag := range fragments {
if len(frag.ids) > 0 {
ids = append(ids, frag.ids...)
slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
continue
}
// split fragment using pretokenizer
splits, err := bpe.split(frag.value)
if err != nil {
return nil, err
}
for _, split := range splits {
// TODO: process splits concurrently
var sb strings.Builder
for _, b := range []byte(split) {
r := rune(b)
switch {
case r == 0x00ad:
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007e && r <= 0x00a0:
r = r + 0x00a2
}
sb.WriteRune(r)
}
// short circuit if the fragment is in the vocabulary
if id := bpe.Vocabulary.Encode(sb.String()); id >= 0 {
ids = append(ids, id)
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
continue
}
runes := []rune(sb.String())
merges := make([]merge, len(runes))
for r := range runes {
merges[r] = merge{
p: r - 1,
n: r + 1,
runes: []rune{runes[r]},
}
}
pairwise := func(a, b int) *pair {
if a < 0 || b >= len(runes) {
return nil
}
left, right := string(merges[a].runes), string(merges[b].runes)
rank := bpe.Vocabulary.Merge(left, right)
if rank < 0 {
return nil
}
return &pair{
a: a,
b: b,
rank: rank,
value: left + right,
}
}
pairs := heap.NewWith(func(i, j *pair) int {
return cmp.Compare(i.rank, j.rank)
})
for i := range len(runes) - 1 {
if pair := pairwise(i, i+1); pair != nil {
pairs.Push(pair)
}
}
for !pairs.Empty() {
pair, _ := pairs.Pop()
left, right := merges[pair.a], merges[pair.b]
if len(left.runes) == 0 || len(right.runes) == 0 ||
string(left.runes)+string(right.runes) != pair.value {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil
merges[pair.a].n = right.n
if right.n < len(merges) {
merges[right.n].p = pair.a
}
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
pairs.Push(pair)
}
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
pairs.Push(pair)
}
}
for _, merge := range merges {
if len(merge.runes) > 0 {
// TODO: handle the edge case where the rune isn't in the vocabulary
if id := bpe.Vocabulary.Encode(string(merge.runes)); id >= 0 {
ids = append(ids, id)
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
}
}
}
}
}
return ids, nil
}
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
for _, r := range bpe.Vocabulary.Decode(id) {
switch {
case r == 0x0100:
// this produces 0x00 aka NULL
continue
case r == 0x0143:
r = 0x00ad
case r > 0x0100 && r <= 0x0120:
r = r - 0x0100
case r > 0x0120 && r <= 0x0142:
r = r - 0x00a2
}
// NOTE: not using WriteRune here because it writes the UTF-8
// encoding of the rune which is _not_ what we want
if err := sb.WriteByte(byte(r)); err != nil {
return "", err
}
}
}
slog.Debug("decoded", "ids", ids, "text", sb.String())
return sb.String(), nil
}

586
model/testdata/inputs.json vendored Normal file
View File

@@ -0,0 +1,586 @@
[
{
"base64": "aWVkIDQgwr0gbW9udGhz",
"ids": [
1142,
220,
19,
220,
27154,
4038
]
},
{
"base64": "RsO8aHJlcg==",
"ids": [
37,
51853,
261
]
},
{
"base64": "",
"ids": []
},
{
"base64": "IA==",
"ids": [
220
]
},
{
"base64": "ICA=",
"ids": [
256
]
},
{
"base64": "ICAg",
"ids": [
262
]
},
{
"base64": "CQ==",
"ids": [
197
]
},
{
"base64": "Cg==",
"ids": [
198
]
},
{
"base64": "Cgo=",
"ids": [
271
]
},
{
"base64": "CgoK",
"ids": [
1432
]
},
{
"base64": "CQo=",
"ids": [
1602
]
},
{
"base64": "SGVsbG8gd29ybGQ=",
"ids": [
9906,
1917
]
},
{
"base64": "IEhlbGxvIHdvcmxk",
"ids": [
22691,
1917
]
},
{
"base64": "SGVsbG8gV29ybGQ=",
"ids": [
9906,
4435
]
},
{
"base64": "IEhlbGxvIFdvcmxk",
"ids": [
22691,
4435
]
},
{
"base64": "IEhlbGxvIFdvcmxkIQ==",
"ids": [
22691,
4435,
0
]
},
{
"base64": "SGVsbG8sIHdvcmxkIQ==",
"ids": [
9906,
11,
1917,
0
]
},
{
"base64": "IEhlbGxvLCB3b3JsZCE=",
"ids": [
22691,
11,
1917,
0
]
},
{
"base64": "IHRoaXMgaXMg8J+mmS5jcHA=",
"ids": [
420,
374,
11410,
99,
247,
13,
11055
]
},
{
"base64": "dzA0OCA3dHVpamsgZHNkZmh1",
"ids": [
86,
23904,
220,
22,
83,
2005,
42908,
11729,
3013,
17156
]
},
{
"base64": "0L3QtdGJ0L4g0L3QsCDQkdGK0LvQs9Cw0YDRgdC60Lg=",
"ids": [
79862,
102118,
13373,
64571,
34694,
3114,
112203,
80112
]
},
{
"base64": "4Z6A4Z624Z6T4Z+L4Z6P4Z+C4Z6W4Z634Z6f4Z+B4Z6f4Z6i4Z624Z6F4Z6B4Z6b4Z6F4Z+B4Z6J",
"ids": [
21549,
222,
98629,
241,
45358,
233,
21549,
237,
45358,
224,
21549,
244,
21549,
115,
21549,
253,
45358,
223,
21549,
253,
21549,
95,
98629,
227,
21549,
223,
21549,
249,
21549,
227,
45358,
223,
21549,
231
]
},
{
"base64": "8J+agCAobm9ybWFsKSDwn5i24oCN8J+Mq++4jyAobXVsdGlwbGUgZW1vamlzIGNvbmNhdGVuYXRlZCkg4pyFIChvbmx5IGVtb2ppIHRoYXQgaGFzIGl0cyBvd24gdG9rZW4p",
"ids": [
9468,
248,
222,
320,
8416,
8,
27623,
114,
102470,
9468,
234,
104,
31643,
320,
36773,
100166,
98634,
8,
26602,
227,
320,
3323,
43465,
430,
706,
1202,
1866,
4037,
8
]
},
{
"base64": "SGVsbG8=",
"ids": [
9906
]
},
{
"base64": "IEhlbGxv",
"ids": [
22691
]
},
{
"base64": "ICBIZWxsbw==",
"ids": [
220,
22691
]
},
{
"base64": "ICAgSGVsbG8=",
"ids": [
256,
22691
]
},
{
"base64": "ICAgIEhlbGxv",
"ids": [
262,
22691
]
},
{
"base64": "ICAgIEhlbGxvCiAgICBIZWxsbw==",
"ids": [
262,
22691,
198,
262,
22691
]
},
{
"base64": "ICg=",
"ids": [
320
]
},
{
"base64": "CiA9",
"ids": [
198,
284
]
},
{
"base64": "JyBlcmE=",
"ids": [
6,
11639
]
},
{
"base64": "SGVsbG8sIHknYWxsISBIb3cgYXJlIHlvdSDwn5iBID/miJHmg7PlnKhhcHBsZeW3peS9nDEzMTQxNTHlpKnvvZ4=",
"ids": [
9906,
11,
379,
65948,
0,
2650,
527,
499,
27623,
223,
949,
37046,
101067,
19000,
23182,
102301,
9263,
18136,
16,
36827,
21909
]
},
{
"base64": "ISEhISEh",
"ids": [
17523,
3001
]
},
{
"base64": "Mw==",
"ids": [
18
]
},
{
"base64": "MzM=",
"ids": [
1644
]
},
{
"base64": "MzMz",
"ids": [
8765
]
},
{
"base64": "MzMzMw==",
"ids": [
8765,
18
]
},
{
"base64": "MzMzMzM=",
"ids": [
8765,
1644
]
},
{
"base64": "MzMzMzMz",
"ids": [
8765,
8765
]
},
{
"base64": "MzMzMzMzMw==",
"ids": [
8765,
8765,
18
]
},
{
"base64": "MzMzMzMzMzM=",
"ids": [
8765,
8765,
1644
]
},
{
"base64": "MzMzMzMzMzMz",
"ids": [
8765,
8765,
8765
]
},
{
"base64": "Q+G7rWEgVmnhu4d0",
"ids": [
34,
91163,
101798
]
},
{
"base64": "IGRpc2NhcmRz",
"ids": [
2624,
2402
]
},
{
"base64": "CiAKCiAKCgogCSAJCSAJCiAgCiAgIAogICAgCiAgICAgCvCfmoAgKG5vcm1hbCkg8J+YtuKAjfCfjKvvuI8gKG11bHRpcGxlIGVtb2ppcyBjb25jYXRlbmF0ZWQpIOKchSDwn6aZ8J+mmSAzIDMzIDMzMyAzMzMzIDMzMzMzIDMzMzMzMyAzMzMzMzMzIDMzMzMzMzMzIDMuMyAzLi4zIDMuLi4zIOGegOGetuGek+Gfi+Gej+GfguGeluGet+Gen+GfgeGen+GeouGetuGehfCfmIEgP+aIkeaDs+WcqGFwcGxl5bel5L2cMTMxNDE1MeWkqe+9niAtLS0tLS09PT09PT09INC90LXRidC+INC90LAg0JHRitC70LPQsNGA0YHQutC4ICcnJycnJ2BgYGBgYGAiIiIiLi4uLi4uISEhISEhPz8/Pz8/IEkndmUgYmVlbiAndG9sZCBoZSdzIHRoZXJlLCAnUkUgeW91IHN1cmU/ICdNIG5vdCBzdXJlIEknbGwgbWFrZSBpdCwgJ0QgeW91IGxpa2Ugc29tZSB0ZWE/IFdlJ1ZlIGEnbEw=",
"ids": [
198,
4815,
15073,
66597,
8004,
1602,
2355,
79772,
11187,
9468,
248,
222,
320,
8416,
8,
27623,
114,
102470,
9468,
234,
104,
31643,
320,
36773,
100166,
98634,
8,
26602,
227,
11410,
99,
247,
9468,
99,
247,
220,
18,
220,
1644,
220,
8765,
220,
8765,
18,
220,
8765,
1644,
220,
8765,
8765,
220,
8765,
8765,
18,
220,
8765,
8765,
1644,
220,
18,
13,
18,
220,
18,
497,
18,
220,
18,
1131,
18,
220,
21549,
222,
98629,
241,
45358,
233,
21549,
237,
45358,
224,
21549,
244,
21549,
115,
21549,
253,
45358,
223,
21549,
253,
21549,
95,
98629,
227,
76460,
223,
949,
37046,
101067,
19000,
23182,
102301,
9263,
18136,
16,
36827,
21909,
56560,
54337,
19175,
102118,
13373,
64571,
34694,
3114,
112203,
80112,
3436,
106451,
14196,
14196,
74694,
3089,
3089,
29249,
17523,
3001,
27708,
7801,
358,
3077,
1027,
364,
83,
820,
568,
596,
1070,
11,
364,
793,
499,
2771,
30,
364,
44,
539,
2771,
358,
3358,
1304,
433,
11,
364,
35,
499,
1093,
1063,
15600,
30,
1226,
6,
43712,
264,
64966,
43
]
}
]