diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 617f53635..d221e21da 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -467,6 +467,14 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, di panic("not implemented") } +func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int, ropeType uint32, base, scale float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + panic("not implemented") +} + func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") } diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index d68aab6cd..e4c093f9c 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -2186,10 +2186,6 @@ static void ggml_metal_encode_node( } break; case GGML_OP_MUL_MAT: { - if (ne00 != ne10) { - printf("mul_mat, ne00: %d, ne01: %d, ne02: %d, ne03: %d, ne10: %d, ne11: %d, ne12: %d, ne13: %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13); - } - GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne12 % ne02 == 0); diff --git a/model/models/mistral3/imageproc.go b/model/models/mistral3/imageproc.go index 907c580d0..373ce47ae 100644 --- a/model/models/mistral3/imageproc.go +++ b/model/models/mistral3/imageproc.go @@ -1,72 +1,15 @@ package mistral3 import ( - "fmt" "image" _ "image/jpeg" _ "image/png" - "io" "math" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model/imageproc" ) -func getNumImageTokens(imageSize, patchSize image.Point) image.Point { - return image.Point{ - (imageSize.X-1)/patchSize.X + 1, - (imageSize.Y-1)/patchSize.Y + 1, - } -} - -func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point { - b := img.Bounds() - ratio := math.Max(float64(b.Max.Y)/float64(longestEdge), float64(b.Max.X)/float64(longestEdge)) - - newSize := img.Bounds().Max - - if ratio > 1.0 { - newSize = image.Point{ - int(math.Floor(float64(b.Max.X) / ratio)), - int(math.Floor(float64(b.Max.Y) / ratio)), - } - } - - tokens := getNumImageTokens(newSize, patchSize) - return image.Point{ - tokens.X * patchSize.X, - tokens.Y * patchSize.Y, - } -} - -func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image { - if format == "png" { - img = imageproc.Composite(img) - } - - newSize := getResizeOutputImageSize(img, longestEdge, patchSize) - - // todo should be ResizeBicubic, but it doesn't exist - return imageproc.Resize(img, newSize, imageproc.ResizeBilinear) -} - -func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) { - img, format, err := image.Decode(imageData) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode image: %w", err) - } - - longestEdge := 1024 - patchSize := image.Point{16, 16} - - img = resizeImage(img, format, longestEdge, patchSize) - - data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) - - opts := map[string]any{} - return data, opts, nil -} - type ImageProcessor struct { imageSize int patchSize int @@ -83,10 +26,31 @@ func newImageProcessor(c ml.Config) ImageProcessor { } } -func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) { - outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize}) - newImage := imageproc.Composite(img) - newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear) - data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) - return data, nil +// ProcessImage prepares an image for the vision model by: +// 1. Compositing transparent images +// 2. Resizing to fit model constraints while preserving aspect ratio +// 3. Normalizing pixel values +// Returns normalized image data and the final size in pixels +func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, image.Point, error) { + img = imageproc.Composite(img) + + size := img.Bounds().Size() + ratio := max(float64(size.Y)/float64(p.longestEdge), float64(size.X)/float64(p.longestEdge)) + if ratio > 1.0 { + size = image.Point{ + int(math.Floor(float64(size.X) / ratio)), + int(math.Floor(float64(size.Y) / ratio)), + } + } + + patchesX := (size.X-1)/p.patchSize + 1 + patchesY := (size.Y-1)/p.patchSize + 1 + size = image.Point{ + patchesX * p.patchSize, + patchesY * p.patchSize, + } + + img = imageproc.Resize(img, size, imageproc.ResizeBilinear) + data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true) + return data, size, nil } diff --git a/model/models/mistral3/imageproc_test.go b/model/models/mistral3/imageproc_test.go deleted file mode 100644 index 2ec634132..000000000 --- a/model/models/mistral3/imageproc_test.go +++ /dev/null @@ -1,219 +0,0 @@ -package mistral3 - -import ( - "bytes" - "encoding/binary" - "image" - "image/png" - "math" - "os" - "testing" - - "github.com/google/go-cmp/cmp" -) - -func TestGetNumImageTokens(t *testing.T) { - type numImageTokensCase struct { - ImageSize image.Point - PatchSize image.Point - Expected image.Point - } - - cases := []numImageTokensCase{ - { - ImageSize: image.Point{1024, 764}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{64, 48}, - }, - { - ImageSize: image.Point{800, 600}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{50, 38}, - }, - { - ImageSize: image.Point{640, 480}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{40, 30}, - }, - { - ImageSize: image.Point{320, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{20, 13}, - }, - { - ImageSize: image.Point{1320, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{83, 13}, - }, - { - ImageSize: image.Point{2000, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{125, 13}, - }, - { - ImageSize: image.Point{10000, 200}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{625, 13}, - }, - { - ImageSize: image.Point{1131, 577}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{71, 37}, - }, - { - ImageSize: image.Point{16, 16}, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1, 1}, - }, - } - - for _, c := range cases { - actual := getNumImageTokens(c.ImageSize, c.PatchSize) - - if diff := cmp.Diff(actual, c.Expected); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } -} - -func TestGetResizeOutputImageSize(t *testing.T) { - type resizeCase struct { - Image image.Image - LongestEdge int - PatchSize image.Point - Expected image.Point - } - - cases := []resizeCase{ - { - Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1024, 768}, - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1024, 624}, - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 300, 200)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{304, 208}, - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.Point{1024, 288}, - }, - } - - for _, c := range cases { - actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize) - - if diff := cmp.Diff(actual, c.Expected); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - } -} - -func TestResize(t *testing.T) { - type resizeCase struct { - Image image.Image - LongestEdge int - PatchSize image.Point - Expected image.Image - } - - cases := []resizeCase{ - { - Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)), - }, - { - Image: image.NewRGBA(image.Rect(0, 0, 10, 10)), - LongestEdge: 1024, - PatchSize: image.Point{16, 16}, - Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)), - }, - } - - for _, c := range cases { - actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize) - - if actual.Bounds() != c.Expected.Bounds() { - t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds()) - } - } -} - -func TestPreprocess(t *testing.T) { - type preprocessCase struct { - TestImage image.Image - ExpectedLen int - } - - cases := []preprocessCase{ - { - TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)), - ExpectedLen: 16 * 16 * 3 * 1, - }, - { - TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)), - ExpectedLen: 1024 * 1024 * 3 * 1, - }, - } - - for _, c := range cases { - var buf bytes.Buffer - err := png.Encode(&buf, c.TestImage) - if err != nil { - t.Fatal(err) - } - - imgData, _, err := Preprocess(&buf) - if err != nil { - t.Fatalf("error processing: %q", err) - } - - switch len(imgData) { - case 0: - t.Errorf("no image data returned") - case c.ExpectedLen: - // ok - default: - t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen) - } - } -} - -func TestPreprocessImages(t *testing.T) { - for _, testFile := range []string{"flight.png", "sportsball.png"} { - f, err := os.Open(testFile) - if err != nil { - t.Skipf("skipping test, no test image found at %s", testFile) - } - defer f.Close() - - imgData, _, err := Preprocess(f) - if err != nil { - t.Fatalf("error processing: %q", err) - } - - byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes - for i, f := range imgData { - binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f)) - } - - outputPath := "processed_" + testFile + ".bin" - err = os.WriteFile(outputPath, byteData, 0o644) - if err != nil { - t.Fatalf("error writing processed image: %q", err) - } - } -} diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 2592e1516..c79470b77 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -51,36 +51,34 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return nil, err } - f32s, err := m.ImageProcessor.ProcessImage(image) + f32s, size, err := m.ImageProcessor.ProcessImage(image) if err != nil { return nil, err } // Create tensor from image data - pixelValues, err := ctx.Input().FromFloatSlice(f32s, - m.ImageProcessor.imageSize, - 1036, // TODO (jmorganca): this should be returned from ProcessImage - m.ImageProcessor.numChannels, - ) + pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) if err != nil { return nil, err } - // fmt.Println("pixelValues", "shape", pixelValues.Shape(), "data", ml.Dump(ctx, pixelValues)) - - // Forward pass through vision model visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) - // fmt.Println("visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) + // split into patches to be sent to the text transformer + var rows []ml.Tensor + for i := 0; i < size.Y; i++ { + view := features.View(ctx, features.Dim(0)*i, features.Dim(0), features.Dim(0)*4, size.X) + rows = append(rows, view) + } - // Project to text embedding space - visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps) - - // fmt.Println("visionOutputs after projector", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) - - return visionOutputs, nil + return rows, nil } +// PostTokenize arranges Mistral 3's inputs for the forward pass +// In Mistral 3 and Pixtral, the input patches are arranged as follows: +// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END] +// Each sequence of [IMG]...[IMG] is a single patch or "row" of vision embeddings func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { var result []input.Input @@ -88,13 +86,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { if inp.Multimodal == nil { result = append(result, inp) } else { - inputMultimodal := inp.Multimodal.(ml.Tensor) - - // Add special image tokens - using the imageTokenIndex from config - result = append(result, input.Input{Token: 10}) // [IMG] - result = append(result, input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}) // image data - result = append(result, slices.Repeat([]input.Input{{Token: 10}}, inputMultimodal.Dim(1)-1)...) // [IMG] placeholders - result = append(result, input.Input{Token: 13}) // [IMG_END] + inputMultimodal := inp.Multimodal.([]ml.Tensor) + for i, row := range inputMultimodal { + result = append(result, input.Input{Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.Dim(1)}) // Image data + result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Dim(1))...) // [IMG] + if i == len(inputMultimodal)-1 { + result = append(result, input.Input{Token: 13}) // [IMG_END] + } else { + result = append(result, input.Input{Token: 12}) // [IMG_BREAK] + } + } } } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 52cd50b86..498b755f9 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -41,40 +41,29 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { batchSize := hiddenState.Dim(1) ropeType := uint32(0) - // Get head dimension - use explicit value if available, otherwise calculate headDim := opts.headDim if headDim == 0 { headDim = opts.hiddenSize / opts.numHeads } - // Query projection and reshape q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) - // Key projection and reshape k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) - // Value projection and reshape v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - // Attention computation - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - - // Reshape attention output for final projection - outputDim := headDim * opts.numHeads - kqv = kqv.Reshape(ctx, outputDim, batchSize) - - // Apply output projection + kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache) + kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) return sa.Output.Forward(ctx, kqv) } func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil } type MLP struct { @@ -117,10 +106,14 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { - // Process text inputs hiddenState := m.TokenEmbedding.Forward(ctx, inputs) - // Process through text transformer layers + // image embeddings + for _, image := range batch.Multimodal { + visionOutputs := image.Multimodal.(ml.Tensor) + ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Dim(0), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + } + for i, layer := range m.Layers { cache.SetLayer(i) diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 8561ec358..ee98a2caf 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -1,7 +1,7 @@ package mistral3 import ( - "fmt" + "image" "math" "github.com/ollama/ollama/ml" @@ -14,32 +14,12 @@ type PatchMerger struct { MergingLayer *nn.Linear `gguf:"merging_layer"` } -func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { - // TODO: pass these in - w := 110 - h := 74 - // tokensPerImage := w * h +func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor { d := visionOutputs.Dim(0) - - // TODO: handle multiple images, this currently assumes one - // fmt.Println("patchmerger visionOutputs", "shape", visionOutputs.Shape(), "data", ml.Dump(ctx, visionOutputs)) - - // Reshape to [h, w, hidden_size] - imageGrid := visionOutputs.Reshape(ctx, h, w, d) - // fmt.Println("imageGrid", "shape", imageGrid.Shape(), "data", ml.Dump(ctx, imageGrid)) - - // TODO: load from config - spatialMergeSize := 2 + imageGrid := visionOutputs.Reshape(ctx, size.Y, size.X, d) kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d, 1) - // fmt.Println("kernel", "shape", kernel.Shape(), "data", ml.Dump(ctx, kernel)) - patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1) - // fmt.Println("patches", "shape", patches.Shape(), "data", ml.Dump(ctx, patches)) - - // fmt.Println("creating reshaped", d*spatialMergeSize*spatialMergeSize, "x", patches.Dim(1)*patches.Dim(2)) reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2)) - // fmt.Println("reshaped", "shape", reshaped.Shape(), "data", ml.Dump(ctx, reshaped)) - return pm.MergingLayer.Forward(ctx, reshaped) } @@ -50,23 +30,24 @@ type MultiModalProjector struct { PatchMerger *PatchMerger `gguf:"patch_merger"` spatialMergeSize int - imageTokenIndex int - hasBias bool + eps float32 + patchSize int } -func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor { - visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps) - visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs) +func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) { + visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps) + patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize} + visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize) visionOutputs = p.Linear1.Forward(ctx, visionOutputs) visionOutputs = visionOutputs.GELU(ctx) - return p.Linear2.Forward(ctx, visionOutputs) + return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize} } func newMultiModalProjector(c ml.Config) *MultiModalProjector { return &MultiModalProjector{ spatialMergeSize: int(c.Uint("spatial_merge_size", 2)), - imageTokenIndex: int(c.Uint("image_token_index", 10)), - hasBias: c.Bool("mm.projector_bias", false), + eps: c.Float("text_config.rms_norm_eps", 1e-5), + patchSize: int(c.Uint("vision.patch_size", 14)), } } @@ -115,9 +96,7 @@ type VisionEncoderLayer struct { func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor { residual := hiddenState - hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps) - fmt.Println("after attention norm", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState, ml.DumpOptions{Items: 3, Precision: 6})) hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts) hiddenState = hiddenState.Add(ctx, residual) residual = hiddenState @@ -149,22 +128,23 @@ type VisionModel struct { } func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { - numPatchesH := pixelValues.Dim(1) / m.patchSize numPatchesW := pixelValues.Dim(0) / m.patchSize - numPatches := numPatchesH * numPatchesW + numPatchesH := pixelValues.Dim(1) / m.patchSize + numPatches := numPatchesW * numPatchesH hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1) hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize) hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.VisionModelOptions.eps) - totalPositions := numPatchesH * numPatchesW - positions := make([]int32, totalPositions*4) - + // Prepare position IDs for 2D rope + positions := make([]int32, numPatches*4) for h := 0; h < numPatchesH; h++ { for w := 0; w < numPatchesW; w++ { - index := h*numPatchesW + w - positions[totalPositions+index] = int32(h) - positions[totalPositions*2+index] = int32(w) + idx := h*numPatchesW + w + positions[idx] = 0 // time (unused) + positions[numPatches+idx] = int32(h) // height + positions[numPatches*2+idx] = int32(w) // width + positions[numPatches*3+idx] = 0 // extra (unused) } } @@ -177,8 +157,6 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions) } - // fmt.Println("after layers", "shape", hiddenState.Shape(), "data", ml.Dump(ctx, hiddenState)) - return hiddenState } diff --git a/model/process_text_test.go b/model/process_text_test.go index 1a84d8f5b..f48303212 100644 --- a/model/process_text_test.go +++ b/model/process_text_test.go @@ -209,322 +209,6 @@ func TestLlama(t *testing.T) { }) } -// tekken loads the Tekken tokenizer for testing -func tekken(t testing.TB) TextProcessor { - t.Helper() - - // Load tokenizer config from mistral-small - tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json") - configFile, err := os.Open(tokenizerConfigPath) - if err != nil { - t.Fatal(err) - } - defer configFile.Close() - - var config struct { - AddBosToken bool `json:"add_bos_token"` - AddEosToken bool `json:"add_eos_token"` - BosToken string `json:"bos_token"` - EosToken string `json:"eos_token"` - } - if err := json.NewDecoder(configFile).Decode(&config); err != nil { - t.Fatal(err) - } - - // Load tokenizer.json which contains the vocabulary and other settings - tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json") - tokenizerFile, err := os.Open(tokenizerJsonPath) - if err != nil { - t.Fatal(err) - } - defer tokenizerFile.Close() - - var tokenizerData struct { - Model struct { - Type string `json:"type"` - Vocab map[string]int32 `json:"vocab"` - Merges []string `json:"merges"` - } `json:"model"` - AddedTokens []struct { - Id int32 `json:"id"` - Content string `json:"content"` - Special bool `json:"special"` - } `json:"added_tokens"` - PreTokenizer struct { - Type string `json:"type"` - Pretokenizers []struct { - Type string `json:"type"` - Pattern struct { - String string `json:"String"` - } `json:"pattern"` - Behavior string `json:"behavior"` - } `json:"pretokenizers"` - } `json:"pre_tokenizer"` - } - if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil { - t.Fatal(err) - } - - // Extract the pattern from pre_tokenizer if available - var pattern string - if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 { - pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String - } - - // Combine regular vocab and added tokens - vocab := tokenizerData.Model.Vocab - - // Add special tokens from added_tokens - for _, token := range tokenizerData.AddedTokens { - vocab[token.Content] = token.Id - } - - // Create vocabulary arrays - maxId := int32(-1) - for _, id := range vocab { - if id > maxId { - maxId = id - } - } - - vocabSize := int(maxId + 1) - types := make([]uint32, vocabSize) - tokens := make([]string, vocabSize) - scores := make([]float32, vocabSize) - - for token, id := range vocab { - tokens[id] = token - types[id] = TOKEN_TYPE_NORMAL - - // Assign appropriate token types for special tokens - if token == "" { - types[id] = TOKEN_TYPE_CONTROL - } else if token == "" { - types[id] = TOKEN_TYPE_CONTROL - } else if token == "[INST]" || token == "[/INST]" { - types[id] = TOKEN_TYPE_CONTROL - } - } - - // In Tekken, we don't need to load merges separately as they're part of the model - var merges []string - - // Create vocabulary object - vocabObj := &Vocabulary{ - Values: tokens, - Types: types, - Scores: scores, - Merges: merges, - BOS: vocab[config.BosToken], - EOS: vocab[config.EosToken], - AddBOS: config.AddBosToken, - AddEOS: config.AddEosToken, - } - - // Use pattern from tokenizer.json if available - if pattern != "" { - // Ensure pattern has proper escaping for Go regexp - pattern = strings.ReplaceAll(pattern, "p{", "\\p{") - return NewBytePairEncoding(pattern, vocabObj) - } - - // Fallback pattern if not found - return NewBytePairEncoding( - `\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`, - vocabObj, - ) -} - -func TestTekken(t *testing.T) { - // Skip if the test data isn't available - if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) { - t.Skip("Mistral-small test data not available") - } - - tokenizer := tekken(t) - - t.Run("whitespace_handling", func(t *testing.T) { - t.Parallel() - - // The key difference from SentencePiece is that Tekken doesn't prepend whitespace - cases := []struct { - input string - expected string - }{ - {" hello", " hello"}, - {"hello ", "hello "}, - {"hello world", "hello world"}, - {" hello world ", " hello world "}, - } - - for _, tc := range cases { - ids, err := tokenizer.Encode(tc.input, false) - if err != nil { - t.Errorf("Failed to encode %q: %v", tc.input, err) - continue - } - - decoded, err := tokenizer.Decode(ids) - if err != nil { - t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) - continue - } - - if decoded != tc.expected { - t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected) - } - } - }) - - t.Run("chat_templates", func(t *testing.T) { - t.Parallel() - - // Test the Tekken chat template format which doesn't have spaces after special tokens - templates := []struct { - input string - expectSpace bool // whether we expect a space after special tokens - }{ - {"[INST]user message[/INST]", false}, - {"[INST] user message[/INST]", true}, - {"[INST]user message [/INST]", true}, - } - - for _, tc := range templates { - ids, err := tokenizer.Encode(tc.input, false) - if err != nil { - t.Errorf("Failed to encode %q: %v", tc.input, err) - continue - } - - decoded, err := tokenizer.Decode(ids) - if err != nil { - t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) - continue - } - - // Check if there's a space after special tokens - hasSpaceAfterINST := strings.Contains(decoded, "[INST] ") - - if hasSpaceAfterINST != tc.expectSpace { - t.Errorf("Chat template space handling: got space=%v, want space=%v for %q", - hasSpaceAfterINST, tc.expectSpace, tc.input) - } - } - }) - - t.Run("special_tokens", func(t *testing.T) { - t.Parallel() - - // Test how Tekken handles special tokens - cases := []struct { - input string - expected []string // We'll check if these tokens are in the decoded output - }{ - {"[INST]hello[/INST]", []string{"", "[INST]", "hello", "[/INST]"}}, - {"[INST]hello[/INST]", []string{"[INST]", "hello", "[/INST]", ""}}, - {"[INST]hello[/INST][INST]again[/INST]", []string{"", "[INST]", "hello", "[/INST]", "", "[INST]", "again", "[/INST]"}}, - } - - for _, tc := range cases { - ids, err := tokenizer.Encode(tc.input, false) - if err != nil { - t.Errorf("Failed to encode %q: %v", tc.input, err) - continue - } - - decoded, err := tokenizer.Decode(ids) - if err != nil { - t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) - continue - } - - for _, expected := range tc.expected { - if !strings.Contains(decoded, expected) { - t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded) - } - } - } - }) - - t.Run("vocabulary_coverage", func(t *testing.T) { - t.Parallel() - - // Tekken has a larger vocabulary, so test coverage of various token types - samples := []string{ - "Hello world!", - "This is a test of the Tekken tokenizer.", - "It has a considerably larger vocabulary size.", - "Special characters: !@#$%^&*()", - "Numbers: 1234567890", - "Multiple languages: こんにちは 你好 안녕하세요", - "Code snippets: def function(): return True", - } - - for _, sample := range samples { - ids, err := tokenizer.Encode(sample, false) - if err != nil { - t.Errorf("Failed to encode %q: %v", sample, err) - continue - } - - decoded, err := tokenizer.Decode(ids) - if err != nil { - t.Errorf("Failed to decode tokens for %q: %v", sample, err) - continue - } - - if decoded != sample { - t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample) - } - } - }) - - t.Run("splitting_behavior", func(t *testing.T) { - t.Parallel() - - // Test the splitting behavior which might differ from SentencePiece - cases := map[string][]string{ - "Hello World!": {"Hello", " World", "!"}, - "user message": {"user", " message"}, - "[INST]hello": {"[INST]", "hello"}, - "hello[/INST]": {"hello", "[/INST]"}, - } - - for s, want := range cases { - got := slices.Collect(tokenizer.(*BytePairEncoding).split(s)) - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("Splitting behavior no match (-want +got):\n%s", diff) - } - } - }) - - t.Run("full_chat_sequence", func(t *testing.T) { - t.Parallel() - - // Test a complete chat sequence with Tekken's format - chatSequence := "[INST]user message[/INST]assistant message[INST]new user message[/INST]" - - ids, err := tokenizer.Encode(chatSequence, false) - if err != nil { - t.Fatalf("Failed to encode chat sequence: %v", err) - } - - decoded, err := tokenizer.Decode(ids) - if err != nil { - t.Fatalf("Failed to decode chat sequence tokens: %v", err) - } - - // In Tekken, the whitespace shouldn't be added after special tokens - if strings.Contains(decoded, "[INST] ") { - t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded) - } - - if strings.Contains(decoded, "[/INST] ") { - t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded) - } - }) -} - func BenchmarkBytePairEncoding(b *testing.B) { tokenizer := llama(b) bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt")) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 89d4d170c..31d20db80 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -182,10 +182,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, * return nil, nil, err } - for _, t := range tokens { - decoded, _ := s.model.(model.TextProcessor).Decode([]int32{t}) - fmt.Println("token", t, "decoded", decoded) - } for _, t := range tokens { inputs = append(inputs, input.Input{Token: t}) }