pool: support CountMany() using hyperloglog.

This commit is contained in:
fiatjaf 2024-11-16 16:59:24 -03:00
parent 99e45035d5
commit 0d40b40c9c
7 changed files with 72 additions and 19 deletions

View File

@ -2,6 +2,7 @@ package nostr
import ( import (
"bytes" "bytes"
"encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "strconv"
@ -142,7 +143,8 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) {
type CountEnvelope struct { type CountEnvelope struct {
SubscriptionID string SubscriptionID string
Filters Filters
Count *int64 Count *int64
HyperLogLog []byte
} }
func (_ CountEnvelope) Label() string { return "COUNT" } func (_ CountEnvelope) Label() string { return "COUNT" }
@ -161,9 +163,11 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {
var countResult struct { var countResult struct {
Count *int64 `json:"count"` Count *int64 `json:"count"`
HLL string `json:"hll"`
} }
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil { if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
v.Count = countResult.Count v.Count = countResult.Count
v.HyperLogLog, _ = hex.DecodeString(countResult.HLL)
return nil return nil
} }
@ -189,6 +193,13 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
if v.Count != nil { if v.Count != nil {
w.RawString(`,{"count":`) w.RawString(`,{"count":`)
w.RawString(strconv.FormatInt(*v.Count, 10)) w.RawString(strconv.FormatInt(*v.Count, 10))
if v.HyperLogLog != nil {
	w.RawString(`,"hll":"`)
	// hex.Encode writes into dst by index, so the destination must have the
	// full encoded *length* (not merely capacity); a zero-length slice would
	// panic on the first byte and nothing would be appended afterwards.
	hllHex := make([]byte, hex.EncodedLen(len(v.HyperLogLog)))
	hex.Encode(hllHex, v.HyperLogLog)
	w.Buffer.AppendBytes(hllHex)
	w.RawString(`"`)
}
w.RawString(`}`) w.RawString(`}`)
} else { } else {
for _, filter := range v.Filters { for _, filter := range v.Filters {

View File

@ -1,4 +1,4 @@
package nip45 package hyperloglog
import ( import (
"math" "math"

View File

@ -1,4 +1,4 @@
package nip45 package hyperloglog
import ( import (
"encoding/binary" "encoding/binary"
@ -18,13 +18,14 @@ func New() *HyperLogLog {
return hll return hll
} }
func (hll *HyperLogLog) Encode() string { func (hll *HyperLogLog) GetRegisters() []byte { return hll.registers }
return hex.EncodeToString(hll.registers) func (hll *HyperLogLog) SetRegisters(enc []byte) { hll.registers = enc }
} func (hll *HyperLogLog) MergeRegisters(other []byte) {
for i, v := range other {
func (hll *HyperLogLog) Decode(enc string) error { if v > hll.registers[i] {
_, err := hex.Decode(hll.registers, []byte(enc)) hll.registers[i] = v
return err }
}
} }
func (hll *HyperLogLog) Clear() { func (hll *HyperLogLog) Clear() {
@ -45,13 +46,12 @@ func (hll *HyperLogLog) Add(id string) {
} }
} }
func (hll *HyperLogLog) Merge(other *HyperLogLog) error { func (hll *HyperLogLog) Merge(other *HyperLogLog) {
for i, v := range other.registers { for i, v := range other.registers {
if v > hll.registers[i] { if v > hll.registers[i] {
hll.registers[i] = v hll.registers[i] = v
} }
} }
return nil
} }
func (hll *HyperLogLog) Count() uint64 { func (hll *HyperLogLog) Count() uint64 {

View File

@ -1,4 +1,4 @@
package nip45 package hyperloglog
import ( import (
"encoding/hex" "encoding/hex"

34
pool.go
View File

@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/nbd-wtf/go-nostr/nip45/hyperloglog"
"github.com/puzpuzpuz/xsync/v3" "github.com/puzpuzpuz/xsync/v3"
) )
@ -468,6 +469,39 @@ func (pool *SimplePool) subManyEose(
return events return events
} }
// CountMany aggregates count results from multiple relays using HyperLogLog.
//
// A COUNT request for filter is sent to every relay in urls concurrently. Each
// response that carries exactly 256 HyperLogLog registers is merged into a
// single estimator, and the merged estimate is returned once all relays have
// answered or failed. Relays that cannot be reached, that return an error, or
// that do not return a 256-byte register set are skipped silently, so the
// result is 0 when no relay contributes.
func (pool *SimplePool) CountMany(
	ctx context.Context,
	urls []string,
	filter Filter,
	opts []SubscriptionOption,
) int {
	hll := hyperloglog.New()

	var (
		mu sync.Mutex // guards hll: MergeRegisters mutates the shared registers
		wg sync.WaitGroup
	)
	wg.Add(len(urls))
	for _, url := range urls {
		go func(nm string) {
			defer wg.Done()

			// use the goroutine's own argument rather than the captured loop
			// variable, so every goroutine talks to its own (normalized) relay
			relay, err := pool.EnsureRelay(nm)
			if err != nil {
				return
			}
			ce, err := relay.countInternal(ctx, Filters{filter}, opts...)
			if err != nil {
				return
			}
			if len(ce.HyperLogLog) != 256 {
				return
			}

			mu.Lock()
			hll.MergeRegisters(ce.HyperLogLog)
			mu.Unlock()
		}(NormalizeURL(url))
	}

	wg.Wait()
	return int(hll.Count())
}
// QuerySingle returns the first event returned by the first relay, cancels everything else. // QuerySingle returns the first event returned by the first relay, cancels everything else.
func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent { func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)

View File

@ -273,7 +273,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
} }
case *CountEnvelope: case *CountEnvelope:
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil { if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
subscription.countResult <- *env.Count subscription.countResult <- *env
} }
case *OKEnvelope: case *OKEnvelope:
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
@ -478,11 +478,19 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
} }
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) { func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
v, err := r.countInternal(ctx, filters, opts...)
if err != nil {
return 0, err
}
return *v.Count, nil
}
func (r *Relay) countInternal(ctx context.Context, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
sub := r.PrepareSubscription(ctx, filters, opts...) sub := r.PrepareSubscription(ctx, filters, opts...)
sub.countResult = make(chan int64) sub.countResult = make(chan CountEnvelope)
if err := sub.Fire(); err != nil { if err := sub.Fire(); err != nil {
return 0, err return CountEnvelope{}, err
} }
defer sub.Unsub() defer sub.Unsub()
@ -499,7 +507,7 @@ func (r *Relay) Count(ctx context.Context, filters Filters, opts ...Subscription
case count := <-sub.countResult: case count := <-sub.countResult:
return count, nil return count, nil
case <-ctx.Done(): case <-ctx.Done():
return 0, ctx.Err() return CountEnvelope{}, ctx.Err()
} }
} }
} }

View File

@ -15,7 +15,7 @@ type Subscription struct {
Filters Filters Filters Filters
// for this to be treated as a COUNT and not a REQ this must be set // for this to be treated as a COUNT and not a REQ this must be set
countResult chan int64 countResult chan CountEnvelope
// the Events channel emits all EVENTs that come in a Subscription // the Events channel emits all EVENTs that come in a Subscription
// will be closed when the subscription ends // will be closed when the subscription ends
@ -152,7 +152,7 @@ func (sub *Subscription) Fire() error {
if sub.countResult == nil { if sub.countResult == nil {
reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON() reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON()
} else { } else {
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil}.MarshalJSON() reqb, _ = CountEnvelope{sub.id, sub.Filters, nil, nil}.MarshalJSON()
} }
sub.live.Store(true) sub.live.Store(true)