diff --git a/kvcache/causal.go b/kvcache/causal.go index 9bc1d5da27..f6bacaaf86 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -211,10 +211,9 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e c.curCellRange.max = len(c.cells) - 1 } - var err error - c.curMask, err = c.buildMask(ctx) + c.curMask = c.buildMask(ctx) - return err + return nil } func newRange() cellRange { @@ -297,7 +296,7 @@ func roundUp(length, pad int) int { // Builds a mask of history x batch indicating whether for each token in the batch the // token in the history should apply. This is based on both the sequence and causality (the // position of the history is not ahead of the token in the batch). -func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { +func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { // Align and pad the two dimensions as required by the backend batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) @@ -325,10 +324,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { mask[i] = float32(math.Inf(-1)) } - maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize) - if err != nil { - return nil, err - } + maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize) if c.config.MaskDType != ml.DTypeF32 { out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) @@ -336,7 +332,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { maskTensor = out } - return maskTensor, nil + return maskTensor } func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { @@ -491,12 +487,7 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) { if !slices.Equal(c.opts.Except, opts.Except) { c.opts = opts if ctx != nil { - var err error - c.curMask, err = c.buildMask(ctx) - if err != nil { - // This error should never occur because we have previously built a mask with the same shape - panic(fmt.Errorf("SetCausal: %w", err)) - } + c.curMask = c.buildMask(ctx) } } } @@ -652,10 +643,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { } } - kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets)) - if err != nil { - return err - } + kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) for i, key := range c.keys { if key == nil { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 820d496d1d..5b1dbe868f 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) } cache.SetLayer(0) - tensor, _ := context.FromFloatSlice(test.in, test.inShape...) + tensor := context.FromFloatSlice(test.in, test.inShape...) cache.Put(context, tensor, tensor) out, _, mask := cache.Get(context) @@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) + tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) cache.Put(context, tensor, tensor) // with window size 4, nothing has slid out of the window yet @@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) { } cache.SetLayer(0) - tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) + tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) cache.Put(context, tensor, tensor) // only the latest position has overlapping windows @@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return c.Empty(dtype, shape...) } -func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { +func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor { t := c.Empty(ml.DTypeF32, shape...).(*testTensor) copy(t.data, s) - return t, nil + return t } -func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { +func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor { f := make([]float32, len(s)) for i := range f { f[i] = float32(s[i]) } - out, _ := c.FromFloatSlice(f, shape...) + out := c.FromFloatSlice(f, shape...) out.(*testTensor).dtype = ml.DTypeI32 - return out, nil + return out } func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { @@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso s = append(s, i) } - out, _ := c.FromFloatSlice(s, len(s)) + out := c.FromFloatSlice(s, len(s)) out.(*testTensor).dtype = dtype return out } diff --git a/ml/backend.go b/ml/backend.go index 7c9b9e3139..6beb7d2bc0 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -171,8 +171,8 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) { type Context interface { Empty(dtype DType, shape ...int) Tensor Zeros(dtype DType, shape ...int) Tensor - FromFloatSlice(s []float32, shape ...int) (Tensor, error) - FromIntSlice(s []int32, shape ...int) (Tensor, error) + FromFloatSlice(s []float32, shape ...int) Tensor + FromIntSlice(s []int32, shape ...int) Tensor // Arange creates a 1D tensor with values within an interval (start, stop] increased by step. Arange(start, stop, step float32, dtype DType) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 496ba8a607..76172ae1ab 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -729,11 +729,11 @@ func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { return t } -func checkShape[S ~[]E, E any](s S, shape ...int) error { +func checkShape[S ~[]E, E any](s S, shape ...int) { n := len(s) if n == 0 { - return nil + return } for _, v := range shape { @@ -741,16 +741,12 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error { } if n != 1 { - return fmt.Errorf("invalid shape: %v", shape) + panic(fmt.Errorf("invalid shape: %v", shape)) } - - return nil } -func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { - if err := checkShape(s, shape...); err != nil { - return nil, err - } +func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor { + checkShape(s, shape...) t := c.newTensor(ml.DTypeF32, shape) @@ -758,13 +754,11 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } - return t, nil + return t } -func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { - if err := checkShape(s, shape...); err != nil { - return nil, err - } +func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor { + checkShape(s, shape...) t := c.newTensor(ml.DTypeI32, shape) @@ -772,7 +766,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } - return t, nil + return t } func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { @@ -790,12 +784,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { arange = append(arange, int32(i)) } - t, err := c.Input().FromIntSlice(arange, len(arange)) - if err != nil { - panic(err) - } - - return t + return c.Input().FromIntSlice(arange, len(arange)) default: panic("unsupported dtype for arange") } diff --git a/model/model.go b/model/model.go index 39b68db155..25097e0103 100644 --- a/model/model.go +++ b/model/model.go @@ -287,11 +287,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten return nil, errors.New("batch size cannot be less than 1") } - var err error - batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs)) - if err != nil { - return nil, err - } + batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs)) cache := m.Config().Cache if cache != nil { diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 3c5a7ea5ba..e621d03ae2 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -175,15 +175,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 89d1788ef1..53bf827587 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -101,14 +101,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, + pixelValues := ctx.Input().FromFloatSlice(f32s, m.ImageProcessor.imageSize, m.ImageProcessor.imageSize, m.ImageProcessor.numChannels, ) - if err != nil { - return nil, err - } visionOutputs := m.VisionModel.Forward(ctx, pixelValues) visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps) @@ -144,15 +141,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 507f1ebc20..3cf782d00f 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -142,10 +142,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) @@ -154,10 +151,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) } hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index af5173a16e..8084760b0c 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -77,10 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) - if err != nil { - return nil, err - } + tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels) ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize @@ -91,11 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input pixelValues := tilesLocal if len(pixelsGlobal) > 0 { - tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) - if err != nil { - return nil, err - } - + tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels) pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3) } @@ -182,15 +175,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index ff9f5e20d3..27935f4012 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -223,11 +223,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0) } - var err error - attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales)) - if err != nil { - panic(err) - } + attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales)) } for i, layer := range m.Layers { diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go index e6b1afef6c..dc6f82b84c 100644 --- a/model/models/llama4/model_vision.go +++ b/model/models/llama4/model_vision.go @@ -245,10 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) { } } - ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) - if err != nil { - panic(err) - } + ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches) diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index dd01a587b1..9d662fc110 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -114,10 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input return nil, err } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) - if err != nil { - return nil, err - } + pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels) visionOutputs := m.VisionModel.Forward(ctx, pixelValues) features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size) @@ -161,15 +158,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 2454100460..65bdcff2ae 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -110,15 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) } } - h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) - if err != nil { - panic(err) - } - - w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) - if err != nil { - panic(err) - } + h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2) + w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2) h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) @@ -151,10 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { } } - positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions)) - if err != nil { - panic(err) - } + positionIDs := ctx.Input().FromIntSlice(positions, len(positions)) positionEmbedding := m.positionalEmbedding(ctx, positionIDs) cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 547e2cb328..45cb3e02c2 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -80,15 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles] } - pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) - if err != nil { - return nil, err - } - - aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1) - if err != nil { - return nil, err - } + pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles) + aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1) positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32) crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio) @@ -113,15 +106,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor } - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) // TODO: attention mask, cross attention mask return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 3c3d81aa50..42338d0d69 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -100,10 +100,7 @@ type Model struct { // Forward implements model.Model. func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) @@ -112,10 +109,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 32cca56072..ee38cad924 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -69,10 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, * m.ImageProcessor.patchSize * m.ImageProcessor.patchSize numPatches := grid.Temporal * grid.Height * grid.Width - pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) - if err != nil { - return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err) - } + pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches) return pixelValues, grid, nil } @@ -142,15 +139,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } - - outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 01eef392b2..4d7afaa144 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -1,7 +1,6 @@ package qwen25vl import ( - "fmt" "math" "slices" @@ -44,10 +43,8 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int } } - mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) - if err != nil { - panic(err) - } + mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength) + // Reshape to match [seqLength, seqLength, 1] for broadcasting mask = mask.Reshape(ctx, seqLength, seqLength, 1) @@ -303,10 +300,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) } } - t, err := ctx.Input().FromIntSlice(index, len(index)) - if err != nil { - panic(err) - } + t := ctx.Input().FromIntSlice(index, len(index)) return t, bounds } @@ -326,10 +320,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim))) } } - freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) - if err != nil { - panic(fmt.Errorf("failed to create tensor from frequencies: %w", err)) - } + freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize) // Create position coordinates (y,x pairs) for the grid // In PyTorch: Equivalent to generating position ids with torch.arange() @@ -339,10 +330,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor coords = append(coords, int32(y), int32(x)) } } - pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) - if err != nil { - panic(fmt.Errorf("failed to create tensor from positions: %w", err)) - } + pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height) // Reshape and permute positions to match spatial merging pattern pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 44c32f9e3b..1930da7e20 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -156,10 +156,7 @@ type Model struct { // Forward implements model.Model. func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - if err != nil { - return nil, err - } + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) @@ -168,10 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - if err != nil { - return nil, err - } + outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go index dbe6bba106..fbdc7d72a8 100644 --- a/runner/ollamarunner/multimodal.go +++ b/runner/ollamarunner/multimodal.go @@ -102,7 +102,7 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten for i, t := range entry.mm { if in == t.Tensor { if !reserve { - return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...) + return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil } else { return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 99bee1061a..a7a889f1fd 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -808,10 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error { batch.Outputs[i] = int32(i) } - batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) - if err != nil { - return err - } + batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs)) cache := s.model.Config().Cache if cache != nil { @@ -876,7 +873,8 @@ func (s *Server) load( parallel int, kvCacheType string, kvSize int, - multiUserCache bool) { + multiUserCache bool, +) { err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache) if err != nil { panic(err)