This commit is contained in:
Michael Yang 2025-04-03 10:25:23 -07:00
parent b01b163a37
commit 7256d74d58
5 changed files with 40 additions and 20 deletions

View File

@ -361,6 +361,17 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return out, nil
}
// Arange builds a one-dimensional test tensor holding start, start+step,
// start+2*step, ... and tags it with dtype.
//
// The number of elements is (stop-start)/step rounded up, so a span that is
// not an exact multiple of step still includes its last value below stop.
// The values are always generated as float32 and handed to FromFloatSlice;
// only the dtype field of the resulting testTensor is overwritten.
//
// Arange panics if FromFloatSlice reports an error.
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	// Round the element count up instead of truncating it.
	span := (stop - start) / step
	n := int(span)
	if float32(n) < span {
		n++
	}

	s := make([]float32, n)
	for i := range s {
		s[i] = start + float32(i)*step
	}

	out, err := c.FromFloatSlice(s, len(s))
	if err != nil {
		// The signature has no error return; fail loudly rather than hit a
		// confusing nil type assertion below.
		panic(err)
	}

	out.(*testTensor).dtype = dtype
	return out
}
// Input returns the receiver itself as the ml.Context.
func (c *testContext) Input() ml.Context {
	return c
}
// Output returns the receiver itself as the ml.Context.
func (c *testContext) Output() ml.Context {
	return c
}
// Layer ignores the layer index and returns the receiver itself as the
// ml.Context.
func (c *testContext) Layer(_ int) ml.Context {
	return c
}

View File

@ -104,6 +104,7 @@ type Context interface {
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
Arange(start, stop, step float32, dtype DType) Tensor
Forward(...Tensor) Context
Compute(...Tensor)

View File

@ -651,6 +651,32 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return t, nil
}
// Arange returns a one-dimensional tensor holding the evenly spaced values
// start, start+step, start+2*step, ... of the requested dtype.
//
// For ml.DTypeF32 the tensor is created by ggml_arange. For ml.DTypeI32 the
// values are computed in Go as start + i*step, truncated to int32, and
// uploaded through c.Input().FromIntSlice; the element count is
// (stop-start)/step rounded up.
//
// Arange panics when dtype is neither ml.DTypeF32 nor ml.DTypeI32, when step
// is zero on the int32 path, or when the int32 tensor cannot be created.
func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
	switch dtype {
	case ml.DTypeF32:
		// ggml_arange creates a float32 tensor
		return &Tensor{
			b: c.b,
			t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
		}
	case ml.DTypeI32:
		// ggml_cast does not support float32 to int32 conversion, so the
		// sequence is generated on the Go side instead.
		if step == 0 {
			// A zero step makes the element count infinite or NaN.
			panic("arange step must not be zero")
		}

		// Round the element count up so both dtypes yield the same length.
		// NOTE(review): ggml_arange is believed to use ceilf here - confirm.
		span := (stop - start) / step
		n := int(span)
		if float32(n) < span {
			n++
		}

		arange := make([]int32, n)
		for i := range arange {
			// Offset by start and apply the step before truncating, so
			// non-zero starts and fractional steps are honoured.
			arange[i] = int32(start + float32(i)*step)
		}

		t, err := c.Input().FromIntSlice(arange, len(arange))
		if err != nil {
			panic(err)
		}

		return t
	default:
		panic("unsupported dtype for arange")
	}
}
func (c *Context) Close() {
if c != nil {
C.ggml_free(c.ctx)

View File

@ -91,16 +91,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
positions := make([]int32, numPatches)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeI32)
hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
for _, layer := range m.Layers {

View File

@ -92,16 +92,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, err
}
positions := make([]int32, 1601)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
return nil, err
}
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
return m.Projector.Forward(ctx, crossAttentionStates), nil
}