additional review comments

This commit is contained in:
Jesse Gross 2025-03-07 11:19:03 -08:00 committed by Michael Yang
parent b27e8f3f10
commit 98272fbd58
2 changed files with 32 additions and 16 deletions

View File

@ -402,7 +402,10 @@ func (b *Backend) NewContext() ml.Context {
}
func (b *Backend) NewContextSize(n int) ml.Context {
n = min(n, b.maxGraphNodes)
if n > b.maxGraphNodes {
panic(fmt.Errorf("requested number of graph nodes (%v) for new context exceeds maximum (%v)", n, b.maxGraphNodes))
}
return &Context{
b: b,
maxGraphNodes: n,
@ -534,7 +537,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
panic("unsupported dtype")
}
if len(shape) < 1 {
if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
} else if len(shape) > 4 {
@ -565,6 +568,11 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func checkShape[S ~[]E, E any](s S, shape ...int) error {
n := len(s)
if n == 0 {
return nil
}
for _, v := range shape {
n /= v
}
@ -577,22 +585,28 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
}
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
if err := checkShape(s, shape...); err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeF32, shape)
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil
}
func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
if err := checkShape(s, shape...); err != nil && len(shape) > 0 {
if err := checkShape(s, shape...); err != nil {
return nil, err
}
t := c.newTensor(ml.DTypeI32, shape)
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
if len(s) > 0 {
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
}
return t, nil
}

View File

@ -10,10 +10,11 @@ import (
)
type TextSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
@ -22,11 +23,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@ -39,8 +40,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the causal cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
return key, nil
}
type TextMLP struct {
@ -191,8 +195,6 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
}
type TextModelOptions struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32