sdk/dataloader simplify and fix lock issues, now it will work.

This commit is contained in:
fiatjaf 2025-03-26 00:58:19 -03:00
parent d1fca24cc3
commit a60e225a5f
3 changed files with 34 additions and 57 deletions

View File

@ -743,8 +743,12 @@ func (pool *SimplePool) BatchedSubManyEose(
pool.duplicateMiddleware(relay, id) pool.duplicateMiddleware(relay, id)
} }
return exists return exists
}), seenAlready, opts...) { }), seenAlready, opts...,
res <- ie ) {
select {
case res <- ie:
case <-ctx.Done():
}
} }
wg.Done() wg.Done()

View File

@ -27,10 +27,7 @@ type Loader[K comparable, V any] struct {
batchFn BatchFunc[K, V] batchFn BatchFunc[K, V]
// the maximum batch size. Set to 0 if you want it to be unbounded. // the maximum batch size. Set to 0 if you want it to be unbounded.
batchCap uint batchCap int
// count of queued up items
count uint
// the amount of time to wait before triggering a batch // the amount of time to wait before triggering a batch
wait time.Duration wait time.Duration
@ -40,9 +37,6 @@ type Loader[K comparable, V any] struct {
// current batcher // current batcher
curBatcher *batcher[K, V] curBatcher *batcher[K, V]
// used to close the sleeper of the current batcher
thresholdReached chan bool
} }
// type used to on input channel // type used to on input channel
@ -61,11 +55,12 @@ type Options struct {
func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts Options) *Loader[K, V] { func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts Options) *Loader[K, V] {
loader := &Loader[K, V]{ loader := &Loader[K, V]{
batchFn: batchFn, batchFn: batchFn,
batchCap: opts.MaxThreshold, batchCap: int(opts.MaxThreshold),
count: 0,
wait: opts.Wait, wait: opts.Wait,
} }
loader.curBatcher = loader.newBatcher()
return loader return loader
} }
@ -73,36 +68,23 @@ func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts Options
// The first context passed to this function within a given batch window will be provided to // The first context passed to this function within a given batch window will be provided to
// the registered BatchFunc. // the registered BatchFunc.
func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) { func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) {
c := make(chan Result[V], 1) c := make(chan Result[V])
// this is sent to batch fn. It contains the key and the channel to return // this is sent to batch fn. It contains the key and the channel to return
// the result on // the result on
req := batchRequest[K, V]{ctx, key, c} req := batchRequest[K, V]{ctx, key, c}
l.batchLock.Lock() l.batchLock.Lock()
// start the batch window if it hasn't already started.
if l.curBatcher == nil {
l.curBatcher = l.newBatcher()
// start a sleeper for the current batcher // we will run the batch function either after some time or after a threshold has been reached
l.thresholdReached = make(chan bool) if len(l.curBatcher.requests) == 0 {
go func(b *batcher[K, V]) {
// unlock either here or on the else condition
l.batchLock.Unlock()
// we will run the batch function either after some time or after a threshold has been reached
b := l.curBatcher
go func() {
select { select {
case <-l.thresholdReached: case <-b.thresholdReached:
case <-time.After(l.wait): case <-time.After(l.wait):
} l.batchLock.Lock()
l.curBatcher = l.newBatcher()
// We can end here also if the batcher has already been closed and a l.batchLock.Unlock()
// new one has been created. So reset the loader state only if the batcher
// is the current one
if l.curBatcher == b {
l.reset()
} }
var ( var (
@ -124,22 +106,17 @@ func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) {
} }
close(req.channel) close(req.channel)
} }
}() }(l.curBatcher)
} else {
l.batchLock.Unlock()
} }
l.curBatcher.requests = append(l.curBatcher.requests, req) l.curBatcher.requests = append(l.curBatcher.requests, req)
if len(l.curBatcher.requests) == l.batchCap {
l.count++ close(l.curBatcher.thresholdReached)
if l.count == l.batchCap { l.curBatcher = l.newBatcher()
close(l.thresholdReached)
// end the batcher synchronously here because another call to Load
// may concurrently happen and needs to go to a new batcher.
l.reset()
} }
l.batchLock.Unlock()
if v, ok := <-c; ok { if v, ok := <-c; ok {
return v.Data, v.Error return v.Data, v.Error
} }
@ -147,22 +124,17 @@ func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) {
return value, NoValueError return value, NoValueError
} }
func (l *Loader[K, V]) reset() {
l.batchLock.Lock()
defer l.batchLock.Unlock()
l.count = 0
l.curBatcher = nil
}
type batcher[K comparable, V any] struct { type batcher[K comparable, V any] struct {
requests []batchRequest[K, V] thresholdReached chan struct{}
batchFn BatchFunc[K, V] requests []batchRequest[K, V]
batchFn BatchFunc[K, V]
} }
// newBatcher returns a batcher for the current requests // newBatcher returns a batcher for the current requests
func (l *Loader[K, V]) newBatcher() *batcher[K, V] { func (l *Loader[K, V]) newBatcher() *batcher[K, V] {
return &batcher[K, V]{ return &batcher[K, V]{
requests: make([]batchRequest[K, V], 0, l.batchCap), thresholdReached: make(chan struct{}),
batchFn: l.batchFn, requests: make([]batchRequest[K, V], 0, l.batchCap),
batchFn: l.batchFn,
} }
} }

View File

@ -6,6 +6,7 @@ import (
"slices" "slices"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
@ -74,7 +75,8 @@ func (sys *System) batchLoadReplaceableEvents(
cm := sync.Mutex{} cm := sync.Mutex{}
aggregatedContext, aggregatedCancel := context.WithCancel(context.Background()) aggregatedContext, aggregatedCancel := context.WithCancel(context.Background())
waiting := len(pubkeys) waiting := atomic.Int32{}
waiting.Add(int32(len(pubkeys)))
for i, pubkey := range pubkeys { for i, pubkey := range pubkeys {
ctx, cancel := context.WithCancel(ctxs[i]) ctx, cancel := context.WithCancel(ctxs[i])
@ -111,8 +113,7 @@ func (sys *System) batchLoadReplaceableEvents(
wg.Done() wg.Done()
<-ctx.Done() <-ctx.Done()
waiting-- if waiting.Add(-1) == 0 {
if waiting == 0 {
aggregatedCancel() aggregatedCancel()
} }
}(i, pubkey) }(i, pubkey)