mirror of
https://github.com/ollama/ollama.git
synced 2025-04-07 03:18:24 +02:00
add gemma vision encoder
This commit is contained in:
parent
5f74d1fd47
commit
4b037a97dc
@ -13,13 +13,13 @@ import (
|
||||
)
|
||||
|
||||
type ModelParameters struct {
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
Architectures []string `json:"architectures"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
TextModel TextParameters `json:"text_config"`
|
||||
}
|
||||
|
||||
type TextParameters struct {
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
}
|
||||
|
||||
type AdapterParameters struct {
|
||||
|
@ -4,8 +4,17 @@ import "github.com/ollama/ollama/fs/ggml"
|
||||
|
||||
type gemma3Model struct {
|
||||
gemmaModel
|
||||
TextModel gemma3TextModel `json:"text_config"`
|
||||
VisionModel gemma3VisionModel `json:"vision_config"`
|
||||
TextModel gemma3TextModel `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"` // attention.head_count 16
|
||||
LayerNormEpsilon float32 `json:"layer_norm_eps"` // attention.layer_norm_epsilon 1e-05
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"` // block_count 32
|
||||
HiddenSize uint32 `json:"hidden_size"` // embedding_length 1280
|
||||
IntermediateSize uint32 `json:"intermediate_size"` // feed_forward_length 5120
|
||||
ImageSize uint32 `json:"image_size"` // image_size 560
|
||||
NumChannels uint32 `json:"num_channels"` // num_channels 3
|
||||
PatchSize uint32 `json:"patch_size"` // patch_size 14
|
||||
} `json:"vision_config"`
|
||||
}
|
||||
|
||||
type gemma3TextModel struct {
|
||||
@ -24,12 +33,6 @@ type gemma3TextModel struct {
|
||||
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
|
||||
}
|
||||
|
||||
type gemma3VisionModel struct {
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
}
|
||||
|
||||
func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3"
|
||||
@ -46,11 +49,18 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["gemma3.text.final_logit_softcapping"] = p.TextModel.FinalLogitSoftcap
|
||||
kv["gemma3.text.rope.local.freq_base"] = p.TextModel.RopeLocalTheta
|
||||
kv["gemma3.text.rope.global.freq_base"] = p.TextModel.RopeGlobalTheta
|
||||
|
||||
kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||
kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
|
||||
kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
|
||||
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
|
||||
kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
|
||||
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||
kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
|
||||
kv["gemma3.vision.attention.layer_norm_epsilon"] = p.VisionModel.LayerNormEpsilon
|
||||
|
||||
kv["tokenizer.ggml.bos_token_id"] = uint32(2)
|
||||
kv["tokenizer.ggml.eot_token_id"] = uint32(1)
|
||||
kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
|
||||
kv["gemma3.vision.num_channels"] = p.VisionModel.NumChannels
|
||||
kv["gemma3.vision.block_count"] = p.VisionModel.HiddenLayers
|
||||
return kv
|
||||
}
|
||||
|
||||
@ -59,11 +69,11 @@ func (p *gemma3Model) Replacements() []string {
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.norm", "output_norm",
|
||||
"vision_model.vision_model", "v",
|
||||
"vision_tower.vision_model.embeddings", "v",
|
||||
"vision_tower.vision_model", "v",
|
||||
"language_model.", "",
|
||||
"model.layers", "blk",
|
||||
"encoder.layers", "blk",
|
||||
"vision_tower.vision_model.embeddings", "v",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
@ -71,11 +81,14 @@ func (p *gemma3Model) Replacements() []string {
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"self_attn.out_proj", "attn_output",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"input_projection_weight", "input_projection.weight",
|
||||
"multi_modal_projector", "mm",
|
||||
}
|
||||
}
|
||||
|
@ -135,7 +135,9 @@ type Tensor interface {
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
|
||||
AvgPool1D(ctx Context, k, s, p int) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
||||
|
||||
Tanh(ctx Context) Tensor
|
||||
|
@ -947,6 +947,13 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pool_1d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(s), C.int(p)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
||||
var kqMask *C.struct_ggml_tensor
|
||||
if mask != nil {
|
||||
|
@ -1,10 +1,15 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
@ -13,19 +18,30 @@ type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
//*VisionModel `gguf:"v,vision"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*TextModel
|
||||
|
||||
//Projector *nn.Linear `gguf:"mm.0"`
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
type MultiModalProjector struct {
|
||||
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
|
||||
InputProjection *nn.Linear `gguf:"mm_input_projection"`
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
|
||||
|
||||
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
|
||||
visionOutputs = p.InputProjection.Weight.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mulmat(ctx, visionOutputs)
|
||||
return visionOutputs
|
||||
}
|
||||
|
||||
func New(c ml.Config) (model.Model, error) {
|
||||
// Verify unified config
|
||||
if c.Uint("vision.block_count") == 0 {
|
||||
return nil, fmt.Errorf("non-unified vision model not supported")
|
||||
}
|
||||
m := Model{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
@ -40,8 +56,8 @@ func New(c ml.Config) (model.Model, error) {
|
||||
},
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
//VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
}
|
||||
|
||||
slidingWindowLen := int32(c.Uint("text.attention.sliding_window"))
|
||||
@ -50,7 +66,78 @@ func New(c ml.Config) (model.Model, error) {
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f32s, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positionIDs, err := ctx.FromIntSlice([]int32{0}, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues, positionIDs)
|
||||
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
|
||||
kernelSize := patchesPerImage * patchesPerImage / 256
|
||||
visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
|
||||
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
|
||||
return visionOutputs, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
|
||||
var images []input.Input
|
||||
fnvHash := fnv.New64a()
|
||||
|
||||
for i := range inputs {
|
||||
if inputs[i].Multimodal == nil {
|
||||
if len(images) > 0 {
|
||||
inputs[i].Multimodal = images[0].Multimodal
|
||||
inputs[i].MultimodalHash = images[0].MultimodalHash
|
||||
for j := 1; j < len(images); j++ {
|
||||
inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
|
||||
fnvHash.Reset()
|
||||
binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
|
||||
binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
|
||||
inputs[i].MultimodalHash = fnvHash.Sum64()
|
||||
}
|
||||
images = nil
|
||||
}
|
||||
} else {
|
||||
images = append(images, inputs[i])
|
||||
inputs[i].Token = -1
|
||||
}
|
||||
}
|
||||
|
||||
inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
|
||||
|
||||
return inputs, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
var embeddings ml.Tensor
|
||||
if opts.Multimodal != nil {
|
||||
embeddings = opts.Multimodal[0].Multimodal.(ml.Tensor)
|
||||
}
|
||||
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -66,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, inputs, positions, embeddings, outputs, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -160,9 +160,12 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, embeddings, outputs ml.Tensor, cache kvcache.Cache) ml.Tensor {
|
||||
if embeddings == nil {
|
||||
embeddings = m.TokenEmbedding.Forward(ctx, inputs)
|
||||
}
|
||||
|
||||
hiddenState := embeddings.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||
|
||||
if len(m.Layers) == gemma27BLayerCount {
|
||||
m.TextOptions.largeModelScaling = true
|
||||
|
171
model/models/gemma3/model_vision.go
Normal file
171
model/models/gemma3/model_vision.go
Normal file
@ -0,0 +1,171 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
var batchSize int = 1
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize).Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize).Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scores := key.Mulmat(ctx, query)
|
||||
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
||||
scores = scores.Softmax(ctx)
|
||||
|
||||
attention := value.Mulmat(ctx, scores)
|
||||
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
|
||||
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
FC1 *nn.Linear `gguf:"fc1"`
|
||||
FC2 *nn.Linear `gguf:"fc2"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
|
||||
hiddenState = mlp.FC2.Forward(ctx, hiddenState)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
|
||||
LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
|
||||
MLP *VisionMLP `gguf:"mlp"`
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
// self attention
|
||||
hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// feed forward
|
||||
hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type VisionEncoder struct {
|
||||
Layers []VisionEncoderLayer
|
||||
}
|
||||
|
||||
func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) {
|
||||
var intermediateHiddenStates []ml.Tensor
|
||||
for i, layer := range e.Layers {
|
||||
if slices.Contains(intermediateLayersIndices, uint32(i)) {
|
||||
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, opts)
|
||||
}
|
||||
|
||||
return hiddenState, intermediateHiddenStates
|
||||
}
|
||||
|
||||
type PrecomputedAspectRatioEmbedding struct {
|
||||
Embedding *nn.Embedding
|
||||
Gate ml.Tensor `gguf:"gate"`
|
||||
}
|
||||
|
||||
func (e *PrecomputedAspectRatioEmbedding) Forward(ctx ml.Context, hiddenState ml.Tensor, aspectRatioIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
embeddings := e.Embedding.Forward(ctx, aspectRatioIDs)
|
||||
embeddings = embeddings.Reshape(ctx, opts.hiddenSize, 1, opts.numTiles)
|
||||
if e.Gate != nil {
|
||||
embeddings = embeddings.Mul(ctx, e.Gate)
|
||||
}
|
||||
|
||||
return hiddenState.Add(ctx, embeddings)
|
||||
}
|
||||
|
||||
type PrecomputedPositionEmbedding struct {
|
||||
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
|
||||
PositionEmbeddingGate ml.Tensor `gguf:"position_embd.gate"`
|
||||
}
|
||||
|
||||
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
|
||||
positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
|
||||
if e.PositionEmbeddingGate != nil {
|
||||
positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
|
||||
}
|
||||
|
||||
return hiddenState.Add(ctx, positionEmbedding)
|
||||
}
|
||||
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize, numHeads, numTiles int
|
||||
imageSize, patchSize int
|
||||
eps float32
|
||||
}
|
||||
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *nn.Conv2D `gguf:"patch_embedding"`
|
||||
PositionEmbedding *nn.Embedding `gguf:"position_embedding"`
|
||||
PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
|
||||
|
||||
Encoder *VisionEncoder `gguf:"blk"`
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs ml.Tensor) ml.Tensor {
|
||||
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
|
||||
|
||||
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
positions := m.PositionEmbedding.Forward(ctx, positionIDs)
|
||||
hiddenState = hiddenState.Add(ctx, positions)
|
||||
|
||||
for _, layer := range m.Encoder.Layers {
|
||||
hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
|
||||
}
|
||||
|
||||
hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
func newVisionModel(c ml.Config) *VisionModel {
|
||||
return &VisionModel{
|
||||
Encoder: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count"))},
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: int(c.Uint("vision.embedding_length")),
|
||||
numHeads: int(c.Uint("vision.attention.head_count")),
|
||||
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
|
||||
eps: c.Float("vision.attention.layer_norm_epsilon"),
|
||||
},
|
||||
}
|
||||
}
|
@ -8,12 +8,13 @@ import (
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize, numChannels int
|
||||
imageSize, patchSize, numChannels int
|
||||
}
|
||||
|
||||
func newImageProcessor(c ml.Config) ImageProcessor {
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
numChannels: int(c.Uint("vision.num_channels")),
|
||||
}
|
||||
}
|
||||
|
@ -144,8 +144,6 @@ func (p *ImageProcessor) splitToTiles(img image.Image, numTilesSize image.Point)
|
||||
return images
|
||||
}
|
||||
|
||||
// remove the "alpha" channel by drawing over a prefilled image
|
||||
//
|
||||
// remove the "alpha" channel by drawing over a prefilled image
|
||||
//
|
||||
//nolint:unused
|
||||
|
@ -21,6 +21,8 @@ type SentencePieceModel struct {
|
||||
vocab *Vocabulary
|
||||
}
|
||||
|
||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||
|
||||
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
||||
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
@ -61,7 +63,7 @@ func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
||||
}
|
||||
}
|
||||
|
||||
func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
|
||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
fragments := []fragment{{value: s}}
|
||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||
// TODO: process special tokens concurrently
|
||||
@ -196,7 +198,26 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Debug("encoded", "ids", ids)
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
if spm.vocab.AddBOS {
|
||||
if ids[0] == spm.vocab.BOS {
|
||||
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
|
||||
ids = append([]int32{spm.vocab.BOS}, ids...)
|
||||
}
|
||||
|
||||
if spm.vocab.AddEOS {
|
||||
if ids[len(ids)-1] == spm.vocab.EOS {
|
||||
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
|
||||
ids = append(ids, spm.vocab.EOS)
|
||||
}
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user