fix tests

This commit is contained in:
Michael Yang 2025-04-03 09:04:47 -07:00
parent 7946618dc3
commit 7632c4cb3c
3 changed files with 17 additions and 45 deletions

View File

@ -413,6 +413,14 @@ func (t *testTensor) Floats() []float32 {
return out
}
func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = -t.data[i]
}
return out
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
@ -467,25 +475,15 @@ func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, di
panic("not implemented")
}
func (t *testTensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int, ropeType uint32, base, scale float32) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
panic("not implemented")
@ -537,6 +535,8 @@ func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
panic("not implemented")
}
func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") }
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
panic("not implemented")
}
@ -549,3 +549,5 @@ func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
copy(t2.(*testTensor).data, t.data)
return nil
}
func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") }

View File

@ -145,8 +145,6 @@ type Tensor interface {
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
RoPEMulti(ctx Context, positionIDs, ropeFactors Tensor, ropeDim uint32, sections [4]int, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor

View File

@ -1000,34 +1000,6 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
}
func (t *Tensor) RoPEMulti(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, sections [4]int, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
if C.ggml_is_quantized(t.t._type) {
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
return &Tensor{
b: t.b,
t: C.ggml_rope_multi(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
(*C.int)(unsafe.Pointer(&sections[0])),
C.int(ropeType),
131072, // YaRN n_ctx_train
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
),
}
}
func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,