backend: Consistently use int (vs. int64) for tensor shapes

Currently there is a mixture of int and int64 used when dealing with
tensor dimensions and shapes, which causes unnecessary conversions -
they all should be the same type.

In general, most interfaces (such as Pytorch) use int64 for
generality but most implementations (such as CUDA) use int32 for
performance. There isn't much benefit to us to being more flexible
than the implementations we are likely to run on.

In addition, as a practical matter, a model with a tensor with a single
dimension larger than 32 bits is unlikely to run on a 32-bit machine.
This commit is contained in:
Jesse Gross 2025-02-03 17:21:57 -08:00 committed by Jesse Gross
parent 7e13f568dc
commit 0e38297f87
6 changed files with 59 additions and 50 deletions

18
cache/cache.go vendored
View File

@ -36,24 +36,24 @@ func (c *Simple) Sub(i int) Cache {
func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) {
if c.keys[0] == nil || c.values[0] == nil {
c.keys[0] = ctx.Zeros(c.DType, int(key.Dim(0)*key.Dim(1))*c.Capacity)
c.values[0] = ctx.Zeros(c.DType, int(value.Dim(0)*value.Dim(1))*c.Capacity)
c.keys[0] = ctx.Zeros(c.DType, key.Dim(0)*key.Dim(1)*c.Capacity)
c.values[0] = ctx.Zeros(c.DType, value.Dim(0)*value.Dim(1)*c.Capacity)
}
ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, int(key.Stride(2))*opts.Position, int(key.Dim(0)*key.Dim(1)*key.Dim(2)))))
ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, int(value.Stride(2))*opts.Position, int(value.Dim(0)*value.Dim(1)*value.Dim(2)))))
ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, key.Stride(2)*opts.Position, key.Dim(0)*key.Dim(1)*key.Dim(2))))
ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, value.Stride(2)*opts.Position, value.Dim(0)*value.Dim(1)*value.Dim(2))))
n := min(c.Capacity, int(key.Dim(2))+opts.Position)
n := min(c.Capacity, key.Dim(2)+opts.Position)
key = c.keys[0].View(ctx, 0,
int(key.Dim(0)), int(key.Stride(1)),
int(key.Dim(1)), int(key.Stride(2)),
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
n,
)
value = c.values[0].View(ctx, 0,
int(value.Dim(0)), int(value.Stride(1)),
int(value.Dim(1)), int(value.Stride(2)),
value.Dim(0), value.Stride(1),
value.Dim(1), value.Stride(2),
n,
)

View File

@ -54,10 +54,10 @@ type Context interface {
}
type Tensor interface {
Dim(n int) int64
Stride(n int) int64
Dim(n int) int
Stride(n int) int
Shape() []int64
Shape() []int
DType() DType
Bytes() []byte
@ -79,13 +79,13 @@ type Tensor interface {
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
Reshape(ctx Context, shape ...int64) Tensor
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Pad(ctx Context, shape ...int64) Tensor
Unpad(ctx Context, shape ...int64) Tensor
Pad(ctx Context, shape ...int) Tensor
Unpad(ctx Context, shape ...int) Tensor
Stack(ctx Context, dim int, s ...Tensor) Tensor
Concat(ctx Context, t2 Tensor, dim int) Tensor
@ -111,7 +111,7 @@ func mul[T number](s ...T) T {
type DumpOptions struct {
// Items is the number of elements to print at the beginning and end of each dimension.
Items int64
Items int
// Precision is the number of decimal places to print. Applies to float32 and float64.
Precision int
@ -139,7 +139,7 @@ func Dump(t Tensor, opts ...DumpOptions) string {
}
}
func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
func dump[S ~[]E, E number](t Tensor, items int, fn func(E) string) string {
bts := t.Bytes()
if bts == nil {
return "<nil>"
@ -153,12 +153,12 @@ func dump[S ~[]E, E number](t Tensor, items int64, fn func(E) string) string {
shape := t.Shape()
var sb strings.Builder
var f func([]int64, int64)
f = func(dims []int64, stride int64) {
var f func([]int, int)
f = func(dims []int, stride int) {
prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
fmt.Fprint(&sb, "[")
defer func() { fmt.Fprint(&sb, "]") }()
for i := int64(0); i < dims[0]; i++ {
for i := 0; i < dims[0]; i++ {
if i >= items && i < dims[0]-items {
fmt.Fprint(&sb, "..., ")
// skip to next printable element

View File

@ -254,6 +254,15 @@ func (c *Context) Compute(t ml.Tensor) ml.Tensor {
return t
}
func shapeToGGML(shape []int) *C.int64_t {
sh := make([]C.int64_t, len(shape))
for i, s := range shape {
sh[i] = (C.int64_t)(s)
}
return &sh[0]
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
if len(shape) < 1 || len(shape) > 4 {
panic("unsupported number of dimensions")
@ -268,9 +277,9 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
var t *C.struct_ggml_tensor
switch dtype {
case ml.DTypeF32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeI32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
default:
panic("unsupported dtype")
}
@ -291,7 +300,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
return nil, fmt.Errorf("invalid shape %v for %d elements", shape, len(s))
}
t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), (*C.int64_t)(unsafe.Pointer(&shape[0])))
t := C.ggml_new_tensor(ctx.ctx, dtype, C.int(len(shape)), shapeToGGML(shape))
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
@ -324,16 +333,16 @@ func (t *Tensor) LogValue() slog.Value {
)
}
func (t *Tensor) Dim(n int) int64 {
return int64(t.t.ne[n])
func (t *Tensor) Dim(n int) int {
return int(t.t.ne[n])
}
func (t *Tensor) Stride(n int) int64 {
return int64(t.t.nb[n])
func (t *Tensor) Stride(n int) int {
return int(t.t.nb[n])
}
func (t *Tensor) Shape() []int64 {
shape := make([]int64, C.ggml_n_dims(t.t))
func (t *Tensor) Shape() []int {
shape := make([]int, C.ggml_n_dims(t.t))
for i := range shape {
shape[i] = t.Dim(i)
}
@ -420,7 +429,7 @@ func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int64) ml.Tensor {
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 {
panic("expected 4 dimensions")
}
@ -452,7 +461,7 @@ func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) Reshape(ctx ml.Context, shape ...int64) ml.Tensor {
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
@ -493,7 +502,7 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Unpad(ctx ml.Context, shape ...int64) ml.Tensor {
func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 {
panic("expected 4 dimensions")
}

View File

@ -10,7 +10,7 @@ import (
type Options struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int64
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
@ -41,9 +41,9 @@ func New(c ml.Config) (model.Model, error) {
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int64(c.Uint("embedding_length")),
numHeads: int64(c.Uint("attention.head_count")),
numKVHeads: int64(c.Uint("attention.head_count_kv")),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),

View File

@ -173,7 +173,7 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int64
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
@ -212,9 +212,9 @@ func newTextModel(c ml.Config) *TextModel {
return &TextModel{
Transformer: &TextDecoder{Layers: decoderLayers},
TextModelOptions: &TextModelOptions{
hiddenSize: int64(c.Uint("embedding_length")),
numHeads: int64(c.Uint("attention.head_count")),
numKVHeads: int64(c.Uint("attention.head_count_kv")),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),

View File

@ -8,7 +8,7 @@ import (
"github.com/ollama/ollama/ml/nn"
)
var batchSize int64 = 1
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
@ -99,7 +99,7 @@ func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermedi
var intermediateHiddenStates []ml.Tensor
for i, layer := range e.Layers {
if slices.Contains(intermediateLayersIndices, uint32(i)) {
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int64{1}, hiddenState.Shape()...)...))
intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...))
}
hiddenState = layer.Forward(ctx, hiddenState, opts)
@ -131,7 +131,7 @@ type PrecomputedPositionEmbedding struct {
TilePositionEmbeddingGate ml.Tensor `gguf:"tile_position_embd.gate"`
}
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs, aspectRatioIDs ml.Tensor, numPositions int64, opts *VisionModelOptions) ml.Tensor {
func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, positionIDs, aspectRatioIDs ml.Tensor, numPositions int, opts *VisionModelOptions) ml.Tensor {
positionEmbedding := e.PositionEmbedding.Forward(ctx, positionIDs)
if e.PositionEmbeddingGate != nil {
positionEmbedding = positionEmbedding.Mul(ctx, e.PositionEmbeddingGate)
@ -149,7 +149,7 @@ func (e *PrecomputedPositionEmbedding) Forward(ctx ml.Context, hiddenState, posi
}
type VisionModelOptions struct {
hiddenSize, numHeads, numTiles int64
hiddenSize, numHeads, numTiles int
imageSize, patchSize int
eps float32
@ -174,7 +174,7 @@ type VisionModel struct {
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRatioIDs ml.Tensor) ml.Tensor {
numPatches := int64((m.imageSize / m.patchSize) * (m.imageSize / m.patchSize))
numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
numPositions := numPatches
if m.ClassEmbedding != nil {
numPositions++
@ -185,7 +185,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, int(m.numTiles)-1)...).Concat(ctx, hiddenState, 1)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)
@ -205,7 +205,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
hiddenState, _ = m.GlobalTransformer.Forward(ctx, hiddenState, nil, m.VisionModelOptions)
hiddenStates := intermediateHiddenStates[0].Stack(ctx, 0, intermediateHiddenStates[1:]...)
hiddenStates = hiddenStates.Reshape(ctx, int64(len(intermediateHiddenStates))*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenStates = hiddenStates.Reshape(ctx, len(intermediateHiddenStates)*m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
hiddenStates = hiddenStates.Unpad(ctx, 0, numPaddingPatches, 0, 0)
hiddenState = hiddenState.Reshape(ctx, m.hiddenSize, numPositions+numPaddingPatches, m.numTiles, batchSize)
@ -219,9 +219,9 @@ func newVisionModel(c ml.Config) *VisionModel {
GlobalTransformer: &VisionEncoder{Layers: make([]VisionEncoderLayer, c.Uint("vision.global.block_count"))},
VisionModelOptions: &VisionModelOptions{
hiddenSize: int64(c.Uint("vision.embedding_length")),
numHeads: int64(c.Uint("vision.attention.head_count")),
numTiles: int64(c.Uint("vision.max_num_tiles")),
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
numTiles: int(c.Uint("vision.max_num_tiles")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),