diff --git a/pool.go b/pool.go index 514f7d9..13399f4 100644 --- a/pool.go +++ b/pool.go @@ -743,8 +743,12 @@ func (pool *SimplePool) BatchedSubManyEose( pool.duplicateMiddleware(relay, id) } return exists - }), seenAlready, opts...) { - res <- ie + }), seenAlready, opts..., + ) { + select { + case res <- ie: + case <-ctx.Done(): + } } wg.Done() diff --git a/sdk/dataloader/dataloader.go b/sdk/dataloader/dataloader.go index 2c4d7cd..bc3434e 100644 --- a/sdk/dataloader/dataloader.go +++ b/sdk/dataloader/dataloader.go @@ -27,10 +27,7 @@ type Loader[K comparable, V any] struct { batchFn BatchFunc[K, V] // the maximum batch size. Set to 0 if you want it to be unbounded. - batchCap uint - - // count of queued up items - count uint + batchCap int // the amount of time to wait before triggering a batch wait time.Duration @@ -40,9 +37,6 @@ type Loader[K comparable, V any] struct { // current batcher curBatcher *batcher[K, V] - - // used to close the sleeper of the current batcher - thresholdReached chan bool } // type used to on input channel @@ -61,11 +55,12 @@ type Options struct { func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts Options) *Loader[K, V] { loader := &Loader[K, V]{ batchFn: batchFn, - batchCap: opts.MaxThreshold, - count: 0, + batchCap: int(opts.MaxThreshold), wait: opts.Wait, } + loader.curBatcher = loader.newBatcher() + return loader } @@ -73,36 +68,23 @@ func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts Options // The first context passed to this function within a given batch window will be provided to // the registered BatchFunc. func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) { - c := make(chan Result[V], 1) + c := make(chan Result[V]) // this is sent to batch fn. It contains the key and the channel to return // the result on req := batchRequest[K, V]{ctx, key, c} l.batchLock.Lock() - // start the batch window if it hasn't already started. - if l.curBatcher == nil { - l.curBatcher = l.newBatcher() - // start a sleeper for the current batcher - l.thresholdReached = make(chan bool) - - // unlock either here or on the else condition - l.batchLock.Unlock() - - // we will run the batch function either after some time or after a threshold has been reached - b := l.curBatcher - go func() { + // we will run the batch function either after some time or after a threshold has been reached + if len(l.curBatcher.requests) == 0 { + go func(b *batcher[K, V]) { select { - case <-l.thresholdReached: + case <-b.thresholdReached: case <-time.After(l.wait): - } - - // We can end here also if the batcher has already been closed and a - // new one has been created. So reset the loader state only if the batcher - // is the current one - if l.curBatcher == b { - l.reset() + l.batchLock.Lock() + l.curBatcher = l.newBatcher() + l.batchLock.Unlock() } var ( @@ -124,22 +106,17 @@ func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) { } close(req.channel) } - }() - } else { - l.batchLock.Unlock() + }(l.curBatcher) } l.curBatcher.requests = append(l.curBatcher.requests, req) - - l.count++ - if l.count == l.batchCap { - close(l.thresholdReached) - - // end the batcher synchronously here because another call to Load - // may concurrently happen and needs to go to a new batcher. - l.reset() + if len(l.curBatcher.requests) == l.batchCap { + close(l.curBatcher.thresholdReached) + l.curBatcher = l.newBatcher() } + l.batchLock.Unlock() + if v, ok := <-c; ok { return v.Data, v.Error } @@ -147,22 +124,17 @@ func (l *Loader[K, V]) Load(ctx context.Context, key K) (value V, err error) { return value, NoValueError } -func (l *Loader[K, V]) reset() { - l.batchLock.Lock() - defer l.batchLock.Unlock() - l.count = 0 - l.curBatcher = nil -} - type batcher[K comparable, V any] struct { - requests []batchRequest[K, V] - batchFn BatchFunc[K, V] + thresholdReached chan struct{} + requests []batchRequest[K, V] + batchFn BatchFunc[K, V] } // newBatcher returns a batcher for the current requests func (l *Loader[K, V]) newBatcher() *batcher[K, V] { return &batcher[K, V]{ - requests: make([]batchRequest[K, V], 0, l.batchCap), - batchFn: l.batchFn, + thresholdReached: make(chan struct{}), + requests: make([]batchRequest[K, V], 0, l.batchCap), + batchFn: l.batchFn, } } diff --git a/sdk/replaceable_loader.go b/sdk/replaceable_loader.go index f1e7a7b..4ef502c 100644 --- a/sdk/replaceable_loader.go +++ b/sdk/replaceable_loader.go @@ -6,6 +6,7 @@ import ( "slices" "strconv" "sync" + "sync/atomic" "time" "github.com/nbd-wtf/go-nostr" @@ -74,7 +75,8 @@ func (sys *System) batchLoadReplaceableEvents( cm := sync.Mutex{} aggregatedContext, aggregatedCancel := context.WithCancel(context.Background()) - waiting := len(pubkeys) + waiting := atomic.Int32{} + waiting.Add(int32(len(pubkeys))) for i, pubkey := range pubkeys { ctx, cancel := context.WithCancel(ctxs[i]) @@ -111,8 +113,7 @@ func (sys *System) batchLoadReplaceableEvents( wg.Done() <-ctx.Done() - waiting-- - if waiting == 0 { + if waiting.Add(-1) == 0 { aggregatedCancel() } }(i, pubkey)