mirror of
https://github.com/ollama/ollama.git
synced 2025-08-26 20:11:59 +02:00
arange
This commit is contained in:
committed by
Michael Yang
parent
1d99451ad7
commit
40b8fdbdca
@@ -696,6 +696,32 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
// ggml_arange creates a float32 tensor
|
||||
return &Tensor{
|
||||
b: c.b,
|
||||
t: C.ggml_arange(c.ctx, C.float(start), C.float(stop), C.float(step)),
|
||||
}
|
||||
case ml.DTypeI32:
|
||||
// ggml_cast does not support float32 to int32 conversion
|
||||
arange := make([]int32, 0, int((stop-start)/step))
|
||||
for i := start; i < stop; i += step {
|
||||
arange = append(arange, int32(i))
|
||||
}
|
||||
|
||||
t, err := c.Input().FromIntSlice(arange, len(arange))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
default:
|
||||
panic("unsupported dtype for arange")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) Close() {
|
||||
if c != nil {
|
||||
for _, b := range *c.allocatedBuffers {
|
||||
|
Reference in New Issue
Block a user