mirror of
https://github.com/ollama/ollama.git
synced 2025-04-13 14:19:23 +02:00
arange
This commit is contained in:
parent
b01b163a37
commit
7256d74d58
@ -361,6 +361,17 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
s := make([]float32, int((stop-start)/step))
|
||||
for i := range s {
|
||||
s[i] = float32(start + float32(i)*step)
|
||||
}
|
||||
|
||||
out, _ := c.FromFloatSlice(s, len(s))
|
||||
out.(*testTensor).dtype = dtype
|
||||
return out
|
||||
}
|
||||
|
||||
// Input returns the receiver itself: the test context does not keep a
// separate context for input tensors.
func (c *testContext) Input() ml.Context {
	return c
}
|
||||
// Output returns the receiver itself: the test context does not keep a
// separate context for output tensors.
func (c *testContext) Output() ml.Context {
	return c
}
|
||||
// Layer returns the receiver itself regardless of the layer index passed in.
func (c *testContext) Layer(int) ml.Context {
	return c
}
|
||||
|
@ -104,6 +104,7 @@ type Context interface {
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
Arange(start, stop, step float32, dtype DType) Tensor
|
||||
|
||||
Forward(...Tensor) Context
|
||||
Compute(...Tensor)
|
||||
|
@ -651,6 +651,32 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
// ggml_arange creates a float32 tensor
|
||||
return &Tensor{
|
||||
b: c.b,
|
||||
t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
|
||||
}
|
||||
case ml.DTypeI32:
|
||||
// ggml_cast does not support float32 to int32 conversion
|
||||
arange := make([]int32, int((stop-start)/step))
|
||||
for i := range arange {
|
||||
arange[i] = int32(i) * int32(step)
|
||||
}
|
||||
|
||||
t, err := c.Input().FromIntSlice(arange, len(arange))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
default:
|
||||
panic("unsupported dtype for arange")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) Close() {
|
||||
if c != nil {
|
||||
C.ggml_free(c.ctx)
|
||||
|
@ -91,16 +91,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
|
||||
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
positions := make([]int32, numPatches)
|
||||
for i := range positions {
|
||||
positions[i] = int32(i)
|
||||
}
|
||||
|
||||
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
|
||||
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
|
@ -92,16 +92,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions := make([]int32, 1601)
|
||||
for i := range positions {
|
||||
positions[i] = int32(i)
|
||||
}
|
||||
|
||||
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
|
||||
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
|
||||
return m.Projector.Forward(ctx, crossAttentionStates), nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user