pool: support CountMany() using hyperloglog.

This commit is contained in:
fiatjaf 2024-11-16 16:59:24 -03:00
parent 99e45035d5
commit 0d40b40c9c
7 changed files with 72 additions and 19 deletions

View File

@ -2,6 +2,7 @@ package nostr
import (
"bytes"
"encoding/hex"
"encoding/json"
"fmt"
"strconv"
@ -142,7 +143,8 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) {
// CountEnvelope is the envelope whose Label is "COUNT".
// It carries a subscription id together with either filters or a count result.
type CountEnvelope struct {
SubscriptionID string
Filters
// Count is filled from the "count" key of the result object when decoding;
// when it is nil, marshaling writes the filters instead of a result object.
Count *int64
Count *int64
// HyperLogLog holds the bytes hex-decoded from the "hll" key of the result
// object; a decoding error is discarded by UnmarshalJSON.
HyperLogLog []byte
}
// Label returns the fixed label of this envelope kind, "COUNT".
func (CountEnvelope) Label() string {
	const label = "COUNT"
	return label
}
@ -161,9 +163,11 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {
var countResult struct {
Count *int64 `json:"count"`
HLL string `json:"hll"`
}
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
v.Count = countResult.Count
v.HyperLogLog, _ = hex.DecodeString(countResult.HLL)
return nil
}
@ -189,6 +193,13 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
if v.Count != nil {
w.RawString(`,{"count":`)
w.RawString(strconv.FormatInt(*v.Count, 10))
if v.HyperLogLog != nil {
w.RawString(`,"hll":"`)
hllHex := make([]byte, 0, 512)
hex.Encode(hllHex, v.HyperLogLog)
w.Buffer.AppendBytes(hllHex)
w.RawString(`"`)
}
w.RawString(`}`)
} else {
for _, filter := range v.Filters {

View File

@ -1,4 +1,4 @@
package nip45
package hyperloglog
import (
"math"

View File

@ -1,4 +1,4 @@
package nip45
package hyperloglog
import (
"encoding/binary"
@ -18,13 +18,14 @@ func New() *HyperLogLog {
return hll
}
// Encode returns the registers rendered as a hexadecimal string.
func (hll *HyperLogLog) Encode() string {
	encoded := hex.EncodeToString(hll.registers)
	return encoded
}
func (hll *HyperLogLog) Decode(enc string) error {
_, err := hex.Decode(hll.registers, []byte(enc))
return err
// GetRegisters returns the register slice held by hll. The slice is returned
// as-is, not copied, so writes through it are visible to hll.
func (hll *HyperLogLog) GetRegisters() []byte {
	return hll.registers
}
// SetRegisters makes enc the register slice of hll. The slice is stored
// without copying, so hll and the caller share it afterwards.
func (hll *HyperLogLog) SetRegisters(enc []byte) {
	hll.registers = enc
}
// MergeRegisters folds a raw register slice into hll, keeping the larger of
// the two values at every position.
//
// Entries of other beyond len(hll.registers) are ignored, so a slice longer
// than the sketch's own registers cannot cause an index-out-of-range panic.
// A shorter slice only updates the positions it covers.
func (hll *HyperLogLog) MergeRegisters(other []byte) {
	// Clamp once up front; every index used below is then in range.
	if len(other) > len(hll.registers) {
		other = other[:len(hll.registers)]
	}
	for i, v := range other {
		if v > hll.registers[i] {
			hll.registers[i] = v
		}
	}
}
func (hll *HyperLogLog) Clear() {
@ -45,13 +46,12 @@ func (hll *HyperLogLog) Add(id string) {
}
}
// Merge folds the registers of other into hll, keeping the larger value at
// each index.
// NOTE(review): hll.registers is indexed by positions of other.registers, so
// this assumes other has no more registers than hll — confirm with callers.
func (hll *HyperLogLog) Merge(other *HyperLogLog) error {
func (hll *HyperLogLog) Merge(other *HyperLogLog) {
for i, v := range other.registers {
if v > hll.registers[i] {
hll.registers[i] = v
}
}
return nil
}
func (hll *HyperLogLog) Count() uint64 {

View File

@ -1,4 +1,4 @@
package nip45
package hyperloglog
import (
"encoding/hex"

34
pool.go
View File

@ -10,6 +10,7 @@ import (
"sync"
"time"
"github.com/nbd-wtf/go-nostr/nip45/hyperloglog"
"github.com/puzpuzpuz/xsync/v3"
)
@ -468,6 +469,39 @@ func (pool *SimplePool) subManyEose(
return events
}
// CountMany asks every relay in urls for a COUNT of filter and merges the
// HyperLogLog registers they return into a single sketch, whose estimate is
// returned.
//
// All relays are queried concurrently and the call returns once every query
// has finished. A relay is skipped silently when it cannot be obtained from
// the pool, when its count request fails, or when it does not return exactly
// 256 bytes of registers (the only length this function accepts). The result
// is therefore a best-effort estimate over the relays that did contribute.
func (pool *SimplePool) CountMany(
	ctx context.Context,
	urls []string,
	filter Filter,
	opts []SubscriptionOption,
) int {
	hll := hyperloglog.New()

	var (
		wg sync.WaitGroup
		// MergeRegisters mutates hll in place, so merges coming from
		// different goroutines must be serialized.
		mu sync.Mutex
	)

	wg.Add(len(urls))
	for _, url := range urls {
		go func(nm string) {
			defer wg.Done()

			// Use the goroutine's own argument: it is the normalized URL and,
			// unlike the loop variable, is never shared between iterations.
			relay, err := pool.EnsureRelay(nm)
			if err != nil {
				return
			}

			ce, err := relay.countInternal(ctx, Filters{filter}, opts...)
			if err != nil {
				return
			}
			if len(ce.HyperLogLog) != 256 {
				return
			}

			mu.Lock()
			hll.MergeRegisters(ce.HyperLogLog)
			mu.Unlock()
		}(NormalizeURL(url))
	}

	wg.Wait()
	return int(hll.Count())
}
// QuerySingle returns the first event returned by the first relay, cancels everything else.
func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent {
ctx, cancel := context.WithCancel(ctx)

View File

@ -273,7 +273,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
}
case *CountEnvelope:
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
subscription.countResult <- *env.Count
subscription.countResult <- *env
}
case *OKEnvelope:
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
@ -478,11 +478,19 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
}
// Count runs a count request through countInternal and returns the number
// carried in the resulting envelope. If countInternal fails, the error is
// returned together with a zero count.
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
	envelope, err := r.countInternal(ctx, filters, opts...)
	if err == nil {
		return *envelope.Count, nil
	}
	return 0, err
}
func (r *Relay) countInternal(ctx context.Context, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
sub := r.PrepareSubscription(ctx, filters, opts...)
sub.countResult = make(chan int64)
sub.countResult = make(chan CountEnvelope)
if err := sub.Fire(); err != nil {
return 0, err
return CountEnvelope{}, err
}
defer sub.Unsub()
@ -499,7 +507,7 @@ func (r *Relay) Count(ctx context.Context, filters Filters, opts ...Subscription
case count := <-sub.countResult:
return count, nil
case <-ctx.Done():
return 0, ctx.Err()
return CountEnvelope{}, ctx.Err()
}
}
}

View File

@ -15,7 +15,7 @@ type Subscription struct {
Filters Filters
// for this to be treated as a COUNT and not a REQ this must be set
countResult chan int64
countResult chan CountEnvelope
// the Events channel emits all EVENTs that come in a Subscription
// will be closed when the subscription ends
@ -152,7 +152,7 @@ func (sub *Subscription) Fire() error {
if sub.countResult == nil {
reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON()
} else {
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil}.MarshalJSON()
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil, nil}.MarshalJSON()
}
sub.live.Store(true)