From 178761aef3d5c66fecd5bb394b38475c1543a4b5 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 16 Apr 2025 15:25:34 -0700 Subject: [PATCH] image processing Co-authored-by: Patrick Devine --- model/models/llama4/model.go | 31 ++- model/models/llama4/process_image.go | 167 ++++++++++++ model/models/llama4/process_image_test.go | 300 ++++++++++++++++++++++ 3 files changed, 494 insertions(+), 4 deletions(-) create mode 100644 model/models/llama4/process_image.go create mode 100644 model/models/llama4/process_image_test.go diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index bbc64e74b..9a1d06ebd 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -15,6 +15,7 @@ import ( type Model struct { model.Base model.BytePairEncoding + ImageProcessor *VisionModel `gguf:"v,vision"` *Projector `gguf:"mm"` @@ -43,8 +44,9 @@ func New(c fs.Config) (model.Model, error) { AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), }, ), - VisionModel: newVisionModel(c), - TextModel: newTextModel(c), + ImageProcessor: newImageProcessor(c), + VisionModel: newVisionModel(c), + TextModel: newTextModel(c), } m.Cache = kvcache.NewWrapperCache( @@ -66,21 +68,42 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return nil, err } - f32s, aspectRatio, err := m.ProcessImage(ctx, img) + pixelsLocal, pixelsGlobal, size, err := m.ProcessImage(img) if err != nil { return nil, err } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s)) + tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) if err != nil { return nil, err } + ratioW, ratioH := int(size.X/m.imageSize), int(size.Y/m.imageSize) + + tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx) + tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, size.Y/ratioH, m.numChannels, ratioH*ratioW) + + pixelValues := tilesLocal + + if len(pixelsGlobal) > 0 { + tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) + if err != nil { + return nil, err + } + + pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3) + } + visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3)) return m.Projector.Forward(ctx, visionOutputs), nil } +func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { + return inputs, nil +} + func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { diff --git a/model/models/llama4/process_image.go b/model/models/llama4/process_image.go new file mode 100644 index 000000000..916f6f905 --- /dev/null +++ b/model/models/llama4/process_image.go @@ -0,0 +1,167 @@ +package llama4 + +import ( + "cmp" + "image" + "math" + "slices" + "sort" + + "golang.org/x/image/draw" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/model/imageproc" +) + +type ImageProcessor struct { + imageSize, patchSize, numChannels, maxUpscalingSize int +} + +func newImageProcessor(c fs.Config) ImageProcessor { + return ImageProcessor{ + imageSize: int(c.Uint("vision.image_size")), + patchSize: int(c.Uint("vision.patch_size")), + numChannels: int(c.Uint("vision.num_channels", 3)), + maxUpscalingSize: int(c.Uint("vision.max_upscaling_size", 448)), + } +} + +func factors(n int) []int { + var result []int + seen := make(map[int]bool) + + for i := 1; i <= n/2; i++ { + if n%i == 0 && !seen[i] { + result = append(result, i) + seen[i] = true + } + } + + result = append(result, n) + sort.Ints(result) + + return result +} + +func (p ImageProcessor) supportedResolutions() []image.Point { + var resolutions []image.Point + + aspectMap := make(map[float64][]image.Point) + for i := p.patchSize; i >= 1; i-- { + for _, f := range factors(i) { + x := f + y := i / f + k := float64(y) / float64(x) + aspectMap[k] = append(aspectMap[k], image.Point{x, y}) + } + } + + for _, v := range aspectMap { + for _, i := range v { + resolutions = append(resolutions, image.Point{i.X * p.imageSize, i.Y * p.imageSize}) + } + } + + return resolutions +} + +func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []image.Point, resizeToMaxCanvas bool) image.Point { + w, h := img.X, img.Y + + scales := make([]float64, len(possibleResolutions)) + + for i, res := range possibleResolutions { + scaleW := float64(res.X) / float64(w) + scaleH := float64(res.Y) / float64(h) + scale := math.Min(scaleW, scaleH) + + scales[i] = scale + } + + minAboveOne := func(scales []float64) (float64, bool) { + min := math.MaxFloat64 + found := false + + for _, s := range scales { + if s >= 1.0 && s < min { + min = s + found = true + } + } + + return min, found + } + + bestScale, ok := minAboveOne(scales) + if resizeToMaxCanvas || !ok { + bestScale = slices.Max(scales) + } + + var bestOptions []image.Point + for i, scale := range scales { + if math.Abs(scale-bestScale) < 1e-6 { + bestOptions = append(bestOptions, possibleResolutions[i]) + } + } + + var chosenResolution image.Point + if len(bestOptions) > 1 { + chosenResolution = slices.MinFunc(bestOptions, func(a, b image.Point) int { + return cmp.Compare(a.X*a.Y, b.X*b.Y) + }) + } else { + chosenResolution = bestOptions[0] + } + + return chosenResolution +} + +func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Point { + scaleW := float64(targetRes.X) / float64(imageRes.X) + scaleH := float64(targetRes.Y) / float64(imageRes.Y) + + var newRes image.Point + if scaleW < scaleH { + newRes = image.Point{ + targetRes.X, + int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))), + } + } else { + newRes = image.Point{ + int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))), + targetRes.Y, + } + } + + return newRes +} + +func (p ImageProcessor) pad(src image.Image, outputSize image.Point) image.Image { + dst := image.NewRGBA(image.Rect(0, 0, outputSize.X, outputSize.Y)) + draw.Draw(dst, src.Bounds(), src, image.Point{}, draw.Over) + return dst +} + +func (p ImageProcessor) ProcessImage(img image.Image) (pixelsLocal, pixelsGlobal []float32, targetSize image.Point, _ error) { + img = imageproc.Composite(img) + + targetSize = p.bestResolution(img.Bounds().Max, p.supportedResolutions(), false) + targetSizeWithoutDistortion := targetSize + if p.maxUpscalingSize > 0 { + targetSizeWithoutDistortion = p.maxResolution(img.Bounds().Max, targetSize) + targetSizeWithoutDistortion.X = min(max(img.Bounds().Max.X, p.maxUpscalingSize), targetSize.X) + targetSizeWithoutDistortion.Y = min(max(img.Bounds().Max.Y, p.maxUpscalingSize), targetSize.Y) + } + + newSizeWithoutDistortion := p.maxResolution(img.Bounds().Max, targetSizeWithoutDistortion) + + padded := p.pad(imageproc.Resize(img, newSizeWithoutDistortion, imageproc.ResizeBilinear), targetSize) + pixelsLocal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true) + + if targetSize.X/p.imageSize*targetSize.Y/p.imageSize > 1 { + padded := imageproc.Resize(img, image.Point{p.imageSize, p.imageSize}, imageproc.ResizeBilinear) + pixelsGlobal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true) + } + + return pixelsLocal, pixelsGlobal, targetSize, nil +} diff --git a/model/models/llama4/process_image_test.go b/model/models/llama4/process_image_test.go new file mode 100644 index 000000000..6dde549ea --- /dev/null +++ b/model/models/llama4/process_image_test.go @@ -0,0 +1,300 @@ +package llama4 + +import ( + "cmp" + "image" + "image/color" + "reflect" + "slices" + "testing" + + gocmp "github.com/google/go-cmp/cmp" +) + +func TestFactors(t *testing.T) { + tests := []struct { + name string + input int + expected []int + }{ + { + name: "factors of 1", + input: 1, + expected: []int{1}, + }, + { + name: "factors of 2", + input: 2, + expected: []int{1, 2}, + }, + { + name: "factors of 6", + input: 6, + expected: []int{1, 2, 3, 6}, + }, + { + name: "factors of 28", + input: 28, + expected: []int{1, 2, 4, 7, 14, 28}, + }, + { + name: "factors of 49", + input: 49, + expected: []int{1, 7, 49}, + }, + { + name: "factors of 97 (prime)", + input: 97, + expected: []int{1, 97}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := factors(tt.input) + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("factors(%d) = %v; want %v", tt.input, actual, tt.expected) + } + }) + } +} + +func TestSupportedResolutions(t *testing.T) { + expectedResolutions := []image.Point{ + {X: 3360, Y: 336}, + {X: 672, Y: 2688}, + {X: 336, Y: 1344}, + {X: 336, Y: 4032}, + {X: 1008, Y: 1344}, + {X: 1344, Y: 1008}, + {X: 336, Y: 1680}, + {X: 1680, Y: 336}, + {X: 336, Y: 5040}, + {X: 4032, Y: 336}, + {X: 2352, Y: 336}, + {X: 2688, Y: 672}, + {X: 1344, Y: 336}, + {X: 5376, Y: 336}, + {X: 2352, Y: 672}, + {X: 672, Y: 1008}, + {X: 1008, Y: 672}, + {X: 336, Y: 5376}, + {X: 1680, Y: 1008}, + {X: 5040, Y: 336}, + {X: 336, Y: 3024}, + {X: 3024, Y: 336}, + {X: 336, Y: 2688}, + {X: 672, Y: 1344}, + {X: 336, Y: 672}, + {X: 336, Y: 2352}, + {X: 2016, Y: 672}, + {X: 1008, Y: 336}, + {X: 336, Y: 3360}, + {X: 336, Y: 4368}, + {X: 1008, Y: 1680}, + {X: 336, Y: 4704}, + {X: 4704, Y: 336}, + {X: 1344, Y: 672}, + {X: 672, Y: 336}, + {X: 2688, Y: 336}, + {X: 3696, Y: 336}, + {X: 2016, Y: 336}, + {X: 1344, Y: 1344}, + {X: 1008, Y: 1008}, + {X: 672, Y: 672}, + {X: 336, Y: 336}, + {X: 4368, Y: 336}, + {X: 672, Y: 2016}, + {X: 336, Y: 1008}, + {X: 336, Y: 3696}, + {X: 672, Y: 1680}, + {X: 1680, Y: 672}, + {X: 336, Y: 2016}, + {X: 672, Y: 2352}, + } + + sortResolutionFunc := func(a, b image.Point) int { + return cmp.Or(cmp.Compare(a.X, b.X), cmp.Compare(a.Y, b.Y)) + } + + slices.SortStableFunc(expectedResolutions, sortResolutionFunc) + + imgProc := ImageProcessor{ + imageSize: 336, + patchSize: 16, + numChannels: 3, + maxUpscalingSize: 448, + } + + actualResolutions := imgProc.supportedResolutions() + slices.SortStableFunc(actualResolutions, sortResolutionFunc) + + if diff := gocmp.Diff(expectedResolutions, actualResolutions); diff != "" { + t.Errorf("supportedResolutions() mismatch (-want +got):\n%s", diff) + } +} + +func TestBestResolution(t *testing.T) { + tests := []struct { + name string + size image.Point + resolutions []image.Point + max bool + expected image.Point + }{ + { + "normal", + image.Point{800, 600}, + []image.Point{ + {300, 200}, + {640, 480}, + {800, 600}, + {1024, 768}, + {1600, 1200}, + }, + false, + image.Point{800, 600}, + }, + { + "max", + image.Point{800, 600}, + []image.Point{ + {300, 200}, + {640, 480}, + {800, 600}, + {1024, 768}, + {1600, 1200}, + }, + true, + image.Point{1600, 1200}, + }, + { + "mid", + image.Point{1000, 700}, + []image.Point{ + {300, 200}, + {640, 480}, + {800, 600}, + {1024, 768}, + {1600, 1200}, + }, + false, + image.Point{1024, 768}, + }, + { + "smol", + image.Point{100, 100}, + []image.Point{ + {300, 200}, + {640, 480}, + {800, 600}, + {1024, 768}, + {1600, 1200}, + }, + false, + image.Point{300, 200}, + }, + { + "huge", + image.Point{10000, 10000}, + []image.Point{ + {300, 200}, + {640, 480}, + {800, 600}, + {1024, 768}, + {1600, 1200}, + }, + false, + image.Point{1600, 1200}, + }, + } + + p := ImageProcessor{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := p.bestResolution(tt.size, tt.resolutions, tt.max) + if diff := gocmp.Diff(tt.expected, actual); diff != "" { + t.Errorf("best resolution mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestMaxResolution(t *testing.T) { + tests := []struct { + name string + origRes image.Point + targetRes image.Point + expected image.Point + }{ + { + "normal", + image.Point{800, 600}, + image.Point{800, 600}, + image.Point{800, 600}, + }, + { + "skew", + image.Point{800, 600}, + image.Point{1100, 700}, + image.Point{933, 700}, + }, + } + + p := ImageProcessor{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := p.maxResolution(tt.origRes, tt.targetRes) + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("max resolution; got %v want %v", actual, tt.expected) + } + }) + } +} + +func TestProcessImage(t *testing.T) { + imgProc := ImageProcessor{ + imageSize: 336, + patchSize: 16, + numChannels: 3, + maxUpscalingSize: 448, + } + + generateImage := func(seed int) image.Image { + width, height := 20, 10 + img := image.NewRGBA(image.Rect(0, 0, width, height)) + + for x := range width { + // Use the seed to vary color generation + r := uint8((seed + x*11) % 256) + g := uint8((seed + x*17) % 256) + b := uint8((seed + x*23) % 256) + + c := color.RGBA{R: r, G: g, B: b, A: 255} + for y := range height { + img.Set(x, y, c) + } + } + + return img + } + + pixelsLocal, pixelsGlobal, targetSize, err := imgProc.ProcessImage(generateImage(12)) + if err != nil { + t.Error(err) + } + + if n := len(pixelsLocal); n != 336*336*3 { + t.Errorf("unexpected size of f32s: %d", n) + } + + if n := len(pixelsGlobal); n > 0 { + t.Errorf("unexpected size of f32s: %d", n) + } + + if !targetSize.Eq(image.Point{336, 336}) { + t.Errorf("unexpected target size: %v", targetSize) + } +}