mirror of
https://github.com/ollama/ollama.git
synced 2025-04-04 01:50:26 +02:00
use 2d pooling
This commit is contained in:
parent
ab39e08eb9
commit
63a394068c
@ -26,15 +26,16 @@ type gemma3Model struct {
|
||||
NumChannels uint32 `json:"num_channels"` // num_channels 3
|
||||
PatchSize uint32 `json:"patch_size"` // patch_size 14
|
||||
} `json:"vision_config"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
||||
}
|
||||
|
||||
const (
|
||||
@ -102,6 +103,10 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
|
||||
}
|
||||
|
||||
if p.MultiModalTokensPerImage > 0 {
|
||||
kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
|
@ -135,7 +135,7 @@ type Tensor interface {
|
||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||
Scale(ctx Context, s float64) Tensor
|
||||
|
||||
AvgPool1D(ctx Context, k, s, p int) Tensor
|
||||
AvgPool2D(ctx Context, k, s int, p float32) Tensor
|
||||
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
|
||||
|
||||
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
|
||||
|
@ -247,7 +247,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
createTensor(tensor{source: t}, output.bts)
|
||||
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
|
||||
// TODO: assign vision tensors to the gpu if possible
|
||||
createTensor(tensor{source: t}, input.bts)
|
||||
createTensor(tensor{source: t}, output.bts)
|
||||
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
|
||||
// these tensors should be repeated per layer
|
||||
for i, layer := range layers {
|
||||
@ -952,10 +952,10 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
|
||||
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pool_1d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(s), C.int(p)),
|
||||
t: C.ggml_pool_2d(ctx.(*Context).ctx, t.t, C.GGML_OP_POOL_AVG, C.int(k), C.int(k), C.int(s), C.int(s), C.float(p), C.float(p)),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@ import (
|
||||
"encoding/binary"
|
||||
"hash/fnv"
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@ -30,9 +31,21 @@ var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
type MultiModalProjector struct {
|
||||
SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
|
||||
InputProjection *nn.Linear `gguf:"mm_input_projection"`
|
||||
|
||||
tokensPerImage int
|
||||
}
|
||||
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
|
||||
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
|
||||
l := visionOutputs.Dim(0)
|
||||
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
patchesPerImage := imageSize / patchSize
|
||||
visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
|
||||
|
||||
kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
|
||||
visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
|
||||
|
||||
// TODO: inputProjection must be transposed since they're incompatible with visionOutputs
|
||||
@ -59,6 +72,9 @@ func New(c ml.Config) (model.Model, error) {
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
MultiModalProjector: &MultiModalProjector{
|
||||
tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
|
||||
},
|
||||
}
|
||||
|
||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
@ -88,17 +104,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
patchesPerImage := m.ImageProcessor.imageSize / m.ImageProcessor.patchSize
|
||||
|
||||
// TODO (jmorganca): read this from the model config
|
||||
// it should instead be math.Sqrt(tokens per image)
|
||||
tokensPerSide := 8
|
||||
kernelSize := patchesPerImage / tokensPerSide
|
||||
visionOutputs = visionOutputs.AvgPool1D(ctx, kernelSize, kernelSize, 0)
|
||||
|
||||
visionOutputs = visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
|
||||
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
|
||||
return visionOutputs, nil
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user