diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 38f327f82..f89a4a450 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -361,6 +361,20 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return out, nil
 }
 
+// Arange is the test double for ml.Context.Arange: it materializes the
+// half-open range [start, stop) in increments of step as float32 values
+// and tags the resulting tensor with the requested dtype.
+func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
+	s := make([]float32, int((stop-start)/step))
+	for i := range s {
+		s[i] = start + float32(i)*step
+	}
+
+	out, _ := c.FromFloatSlice(s, len(s))
+	out.(*testTensor).dtype = dtype
+	return out
+}
+
 func (c *testContext) Input() ml.Context    { return c }
 func (c *testContext) Output() ml.Context   { return c }
 func (c *testContext) Layer(int) ml.Context { return c }
diff --git a/ml/backend.go b/ml/backend.go
index 4823cd6f9..c58c2deaa 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -104,6 +104,10 @@ type Context interface {
 	Zeros(dtype DType, shape ...int) Tensor
 	FromFloatSlice(s []float32, shape ...int) (Tensor, error)
 	FromIntSlice(s []int32, shape ...int) (Tensor, error)
+
+	// Arange builds a 1-D tensor holding the half-open range [start, stop)
+	// in increments of step, in the requested dtype.
+	Arange(start, stop, step float32, dtype DType) Tensor
 	Forward(...Tensor) Context
 	Compute(...Tensor)
 
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index d5306aa98..0051a097e 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -651,6 +651,36 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 	return t, nil
 }
 
+// Arange returns a 1-D tensor containing the half-open range [start, stop)
+// in increments of step. Only DTypeF32 and DTypeI32 are supported; any
+// other dtype panics, as does a FromIntSlice failure on the I32 path.
+func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
+	switch dtype {
+	case ml.DTypeF32:
+		// ggml_arange creates a float32 tensor
+		return &Tensor{
+			b: c.b,
+			t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
+		}
+	case ml.DTypeI32:
+		// ggml_cast does not support float32 to int32 conversion
+		arange := make([]int32, int((stop-start)/step))
+		for i := range arange {
+			// offset by start so the sequence matches the F32 branch
+			arange[i] = int32(start + float32(i)*step)
+		}
+
+		t, err := c.Input().FromIntSlice(arange, len(arange))
+		if err != nil {
+			panic(err)
+		}
+
+		return t
+	default:
+		panic("unsupported dtype for arange")
+	}
+}
+
 func (c *Context) Close() {
 	if c != nil {
 		C.ggml_free(c.ctx)
diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go
index 94aa27bd7..f39cb4edd 100644
--- a/model/models/gemma3/model_vision.go
+++ b/model/models/gemma3/model_vision.go
@@ -91,16 +91,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
 	hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
 	hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
 
-	positions := make([]int32, numPatches)
-	for i := range positions {
-		positions[i] = int32(i)
-	}
-
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		panic(err)
-	}
-
+	positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
 	hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
 
 	for _, layer := range m.Layers {
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 988a189d4..0a1e33243 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -92,16 +92,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 		return nil, err
 	}
 
-	positions := make([]int32, 1601)
-	for i := range positions {
-		positions[i] = int32(i)
-	}
-
-	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
-	if err != nil {
-		return nil, err
-	}
-
+	positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
	crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }