diff --git a/kvcache/causal.go b/kvcache/causal.go index c7b3595ec6..d804f3bf04 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -393,7 +393,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { mask[i] = float32(math.Inf(-1)) } - maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize) + maskTensor := ctx.Input().FromFloats(mask, length, batchSize) if c.config.MaskDType != ml.DTypeF32 { maskTensor = maskTensor.Cast(ctx, c.config.MaskDType) @@ -725,7 +725,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { offsets = offsets[batchFirst : batchLast+1] ctx := c.backend.NewContext() - kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) + kShift := ctx.Input().FromInts(offsets, len(offsets)) for i, key := range c.keys { if key == nil { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 7e4fc3b109..dd0c044272 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -477,7 +477,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } cache.SetLayer(0) - tensor := context.FromFloatSlice(test.in, test.inShape...) + tensor := context.FromFloats(test.in, test.inShape...) cache.Put(context, tensor, tensor) out, _, mask := cache.Get(context) @@ -519,7 +519,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5) + tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5) cache.Put(context, tensor, tensor) // with window size 4, nothing has slid out of the window yet @@ -549,7 +549,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1) + tensor = context.FromFloats([]float32{6}, 1, 1, 1) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows @@ -594,7 +594,7 @@ func TestCanResumeSWAMem(t *testing.T) { } cache.SetLayer(0) - tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) + tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7) cache.Put(context, tensor, tensor) // shift window by adding position 7 @@ -607,7 +607,7 @@ func TestCanResumeSWAMem(t *testing.T) { } cache.SetLayer(0) - tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1) + tensor = context.FromFloats([]float32{8}, 1, 1, 1) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows @@ -670,7 +670,7 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return c.Empty(dtype, shape...) } -func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor { +func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor { t := c.Empty(ml.DTypeF32, shape...).(*testTensor) copy(t.data, s) @@ -678,13 +678,13 @@ func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor { return t } -func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor { +func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor { f := make([]float32, len(s)) for i := range f { f[i] = float32(s[i]) } - out := c.FromFloatSlice(f, shape...) + out := c.FromFloats(f, shape...) 
out.(*testTensor).dtype = ml.DTypeI32 return out @@ -696,7 +696,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso s = append(s, i) } - out := c.FromFloatSlice(s, len(s)) + out := c.FromFloats(s, len(s)) out.(*testTensor).dtype = dtype return out } diff --git a/ml/backend.go b/ml/backend.go index 351942d54a..764ff0854b 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -98,8 +98,9 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) { type Context interface { Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor - FromFloatSlice(s []float32, shape ...int) Tensor - FromIntSlice(s []int32, shape ...int) Tensor + FromBytes(dtype DType, s []byte, shape ...int) Tensor + FromFloats(s []float32, shape ...int) Tensor + FromInts(s []int32, shape ...int) Tensor // Arange creates a 1D tensor with values within an interval (start, stop] increased by step. Arange(start, stop, step float32, dtype DType) Tensor @@ -136,7 +137,9 @@ type Tensor interface { Bytes() []byte Floats() []float32 - SetValueFromIntSlice(s []int32) + FromBytes([]byte) + FromFloats([]float32) + FromInts([]int32) Neg(ctx Context) Tensor Add(ctx Context, t2 Tensor) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 88078d7798..64aae14121 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -12,6 +12,7 @@ import "C" import ( "context" + "encoding/binary" "errors" "fmt" "io" @@ -871,7 +872,7 @@ func pad(length, pad C.size_t) C.size_t { return ((length + pad - 1) / pad) * pad } -func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { +func (c *Context) newTensor(dtype ml.DType, shape []int) *Tensor { if c.buft == nil { panic("set Input or Layer before creating tensors") } @@ -915,7 +916,7 @@ func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { t := c.newTensor(dtype, shape) if c.b.allocMemory { - C.ggml_set_zero(t.(*Tensor).t) + C.ggml_set_zero(t.t) } return t } @@ -936,25 +937,34 @@ func checkShape[S ~[]E, E any](s S, shape ...int) { } } -func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor { - checkShape(s, shape...) - - t := c.newTensor(ml.DTypeF32, shape) - - if c.b.allocMemory && len(s) > 0 { - C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) +func (c Context) FromBytes(dtype ml.DType, s []uint8, shape ...int) ml.Tensor { + // Unchecked to handle quantized types + t := c.newTensor(dtype, shape) + if c.b.allocMemory { + t.FromBytes(s) } return t } -func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor { +func (c *Context) FromFloats(s []float32, shape ...int) ml.Tensor { + checkShape(s, shape...) + + t := c.newTensor(ml.DTypeF32, shape) + + if c.b.allocMemory { + t.FromFloats(s) + } + + return t +} + +func (c *Context) FromInts(s []int32, shape ...int) ml.Tensor { checkShape(s, shape...) 
t := c.newTensor(ml.DTypeI32, shape) - - if c.b.allocMemory && len(s) > 0 { - C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) + if c.b.allocMemory { + t.FromInts(s) } return t @@ -975,7 +985,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { arange = append(arange, int32(i)) } - return c.Input().FromIntSlice(arange, len(arange)) + return c.Input().FromInts(arange, len(arange)) default: panic("unsupported dtype for arange") } @@ -1045,10 +1055,26 @@ func (t *Tensor) Floats() (data []float32) { return } -func (t *Tensor) SetValueFromIntSlice(s []int32) { - if len(s) > 0 { - C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t)) +func tensorSet[S ~[]E, E byte | float32 | int32](t *Tensor, s S) { + if len(s) == 0 { + return } + if int(C.ggml_nbytes(t.t)) != len(s)*binary.Size(s[0]) { + panic("data size does not match tensor size") + } + C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t)) +} + +func (t *Tensor) FromBytes(s []byte) { + tensorSet(t, s) +} + +func (t *Tensor) FromFloats(s []float32) { + tensorSet(t, s) +} + +func (t *Tensor) FromInts(s []int32) { + tensorSet(t, s) } func (t *Tensor) DType() ml.DType { @@ -1622,13 +1648,3 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor { t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)), } } - -func (c Context) FromBytes(dtype ml.DType, s []uint8, shape ...int) ml.Tensor { - // Unchecked to handle quantized types - t := c.newTensor(dtype, shape) - if c.b.allocMemory && len(s) > 0 { - C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) - } - - return t -} diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go index 166c11e139..2d78710f79 100644 --- a/model/models/bert/embed.go +++ b/model/models/bert/embed.go @@ -30,7 +30,7 @@ type Model struct { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize)) - hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)))) + hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromInts(batch.Positions, len(batch.Positions)))) hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps) for _, layer := range m.Layers { diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go index 7e57f72dd0..cfd579ca50 100644 --- a/model/models/deepseek2/model.go +++ b/model/models/deepseek2/model.go @@ -302,7 +302,7 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 2b16dc62e4..06c71fc3b1 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -175,7 +175,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, 
len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 27da889e40..62f51074a2 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -101,7 +101,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues := ctx.Input().FromFloatSlice(f32s, + pixelValues := ctx.Input().FromFloats(f32s, m.ImageProcessor.imageSize, m.ImageProcessor.imageSize, m.ImageProcessor.numChannels, diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index d5bdd410df..8d1a1be6a3 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -163,7 +163,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, } func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index 1333151b3e..ec038a287f 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -29,9 +29,9 @@ type TextModel struct { } func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) // Create a tensor of a single float32 value of 1.0 to use for altup correction - one := ctx.Input().FromFloatSlice([]float32{1.0}, 1) + one := ctx.Input().FromFloats([]float32{1.0}, 1) inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize))) inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions) diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 6a3270651b..08bf753d52 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -30,9 +30,9 @@ type Transformer struct { // Forward implements model.Model. 
func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) - one := ctx.Input().FromFloatSlice([]float32{1}, 1) + one := ctx.Input().FromFloats([]float32{1}, 1) for i, block := range m.TransformerBlocks { m.Cache.SetLayer(i) if c, ok := m.Cache.(*kvcache.WrapperCache); ok { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index c03f04a0d8..52c66ba570 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -179,7 +179,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index e80fbaed63..5eeac07c2b 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -76,7 +76,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) + tilesLocal := ctx.Input().FromFloats(pixelsLocal, size.X, size.Y, m.numChannels) ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize @@ -87,7 +87,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input pixelValues := tilesLocal if len(pixelsGlobal) > 0 { - tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) + tilesGlobal := ctx.Input().FromFloats(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3) } @@ -174,7 +174,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index e056391f56..96b5d24d87 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -211,7 +211,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0) } - attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales)) + attentionScales = ctx.Input().FromFloats(scales, 1, 1, len(scales)) } for i, layer := range m.Layers { diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go index dc6f82b84c..1aa50aec46 100644 --- a/model/models/llama4/model_vision.go +++ b/model/models/llama4/model_vision.go @@ -245,7 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) { } } - ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) + ropeFreqs := ctx.Input().FromFloats(freqs, freqDim/2, numPatches, 2) ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 
3).Contiguous(ctx) ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches) diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 5c46615e92..e071d71a89 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -114,7 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) + pixelValues := ctx.Input().FromFloats(f32s, size.X, size.Y, m.ImageProcessor.numChannels) visionOutputs := m.VisionModel.Forward(ctx, pixelValues) features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) @@ -158,7 +158,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 3bfb8c90a3..ce3110c7c4 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -110,8 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) } } - h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) - w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) + h := ctx.Input().FromFloats(frequenciesHeight, maxPatchesPerSide, frequencies/2) + w := ctx.Input().FromFloats(frequenciesWidth, maxPatchesPerSide, frequencies/2) h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) @@ -144,7 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { } } - positionIDs := ctx.Input().FromIntSlice(positions, len(positions)) + positionIDs := ctx.Input().FromInts(positions, len(positions)) positionEmbedding := m.positionalEmbedding(ctx, positionIDs) cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 769743694c..58fd5adcfc 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -80,8 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles] } - pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) - aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1) + pixelValues := ctx.Input().FromFloats(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) + aspectRatio := ctx.Input().FromInts([]int32{int32(ratio.rank)}, 1) positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32) crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) @@ -106,7 +106,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor } - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) // TODO: attention mask, cross attention mask return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, 
crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 2e2347102e..10a1e65cf1 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -102,7 +102,7 @@ type Model struct { // Forward implements model.Model. func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 6898e38cac..13fa3fee1e 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -69,7 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, * m.ImageProcessor.patchSize * m.ImageProcessor.patchSize numPatches := grid.Temporal * grid.Height * grid.Width - pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) + pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches) return pixelValues, grid, nil } @@ -139,7 +139,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache) } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 3dd60e3baf..88b2c005c9 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -43,7 +43,7 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int } } - mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) + mask := ctx.Input().FromFloats(flat, seqLength, seqLength) // Reshape to match [seqLength, seqLength, 1] for broadcasting mask = mask.Reshape(ctx, seqLength, seqLength, 1) @@ -299,7 +299,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) } } - t := ctx.Input().FromIntSlice(index, len(index)) + t := ctx.Input().FromInts(index, len(index)) return t, bounds } @@ -319,7 +319,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) } } - freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) + freqs := ctx.Input().FromFloats(freqVals, freq, maxGridSize) // Create position coordinates (y,x pairs) for the grid // In PyTorch: Equivalent to generating position ids with torch.arange() @@ -329,7 +329,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor coords = append(coords, int32(y), int32(x)) } } - pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) + pos := ctx.Input().FromInts(coords, 2, grid.Width, grid.Height) // Reshape and permute positions to match spatial merging pattern pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 9fd6e313d9..72ce36ed94 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -181,7 +181,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { // Forward implements model.Model. 
func (m *Model) forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go index fbdc7d72a8..78ceb771c0 100644 --- a/runner/ollamarunner/multimodal.go +++ b/runner/ollamarunner/multimodal.go @@ -102,7 +102,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten for i, t := range entry.mm { if in == t.Tensor { if !reserve { - return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil + return ctx.Input().FromFloats(entry.data[i], t.Tensor.Shape()...), nil } else { return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index b0cf637303..ea039157f3 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -599,7 +599,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er // Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs)) - batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs)) + batch.Outputs = nextBatch.ctx.Input().FromInts(batchOutputs, len(batchOutputs)) nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch) if err != nil { err = fmt.Errorf("failed to build graph: %w", err) @@ -692,7 +692,7 @@ func (s *Server) computeBatch(activeBatch batchState) { // At this point the seqs are ready for forwardBatch to move forward so unblock s.mu.Unlock() - activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs) + activeBatch.batch.Inputs.FromInts(batchInputs) activeBatch.ctx.ComputeWithNotify( func() { logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id) @@ -1090,7 +1090,7 @@ func (s *Server) reserveWorstCaseGraph() error { batch.Positions[i] = int32(i) } - batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) + batch.Inputs = ctx.Input().FromInts(batchInputs, len(batchInputs)) batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel) cache := s.model.Config().Cache
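// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the patch above): how call sites use
// the renamed helpers after this change. The import path, the ml.Context
// parameter, and the shapes/data below are assumptions for illustration only.
package example

import "github.com/ollama/ollama/ml"

func usageSketch(ctx ml.Context, packed []byte) {
	// Context-level constructors, formerly FromIntSlice / FromFloatSlice.
	positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, 4)
	scales := ctx.Input().FromFloats([]float32{1.0, 0.5}, 2)

	// FromBytes skips the element-count check so it can carry quantized or
	// packed layouts; len(packed) must still match the tensor's byte size,
	// otherwise tensorSet panics when memory is allocated.
	raw := ctx.Input().FromBytes(ml.DTypeF32, packed, 8, 8)

	// Tensor-level setters (replacing SetValueFromIntSlice) fill an already
	// allocated tensor, as computeBatch now does with batch.Inputs.
	inputs := ctx.Input().Empty(ml.DTypeI32, 4)
	inputs.FromInts([]int32{10, 11, 12, 13})

	_, _, _, _ = positions, scales, raw, inputs
}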