mirror of
https://github.com/nbd-wtf/go-nostr.git
synced 2025-03-17 21:32:56 +01:00
pool: support CountMany() using hyperloglog.
This commit is contained in:
parent
99e45035d5
commit
0d40b40c9c
11
envelopes.go
11
envelopes.go
@ -2,6 +2,7 @@ package nostr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@ -143,6 +144,7 @@ type CountEnvelope struct {
|
||||
SubscriptionID string
|
||||
Filters
|
||||
Count *int64
|
||||
HyperLogLog []byte
|
||||
}
|
||||
|
||||
func (_ CountEnvelope) Label() string { return "COUNT" }
|
||||
@ -161,9 +163,11 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {
|
||||
|
||||
var countResult struct {
|
||||
Count *int64 `json:"count"`
|
||||
HLL string `json:"hll"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
|
||||
v.Count = countResult.Count
|
||||
v.HyperLogLog, _ = hex.DecodeString(countResult.HLL)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -189,6 +193,13 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
|
||||
if v.Count != nil {
|
||||
w.RawString(`,{"count":`)
|
||||
w.RawString(strconv.FormatInt(*v.Count, 10))
|
||||
if v.HyperLogLog != nil {
|
||||
w.RawString(`,"hll":"`)
|
||||
hllHex := make([]byte, 0, 512)
|
||||
hex.Encode(hllHex, v.HyperLogLog)
|
||||
w.Buffer.AppendBytes(hllHex)
|
||||
w.RawString(`"`)
|
||||
}
|
||||
w.RawString(`}`)
|
||||
} else {
|
||||
for _, filter := range v.Filters {
|
||||
|
@ -1,4 +1,4 @@
|
||||
package nip45
|
||||
package hyperloglog
|
||||
|
||||
import (
|
||||
"math"
|
@ -1,4 +1,4 @@
|
||||
package nip45
|
||||
package hyperloglog
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
@ -18,13 +18,14 @@ func New() *HyperLogLog {
|
||||
return hll
|
||||
}
|
||||
|
||||
func (hll *HyperLogLog) Encode() string {
|
||||
return hex.EncodeToString(hll.registers)
|
||||
func (hll *HyperLogLog) GetRegisters() []byte { return hll.registers }
|
||||
func (hll *HyperLogLog) SetRegisters(enc []byte) { hll.registers = enc }
|
||||
func (hll *HyperLogLog) MergeRegisters(other []byte) {
|
||||
for i, v := range other {
|
||||
if v > hll.registers[i] {
|
||||
hll.registers[i] = v
|
||||
}
|
||||
}
|
||||
|
||||
func (hll *HyperLogLog) Decode(enc string) error {
|
||||
_, err := hex.Decode(hll.registers, []byte(enc))
|
||||
return err
|
||||
}
|
||||
|
||||
func (hll *HyperLogLog) Clear() {
|
||||
@ -45,13 +46,12 @@ func (hll *HyperLogLog) Add(id string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (hll *HyperLogLog) Merge(other *HyperLogLog) error {
|
||||
func (hll *HyperLogLog) Merge(other *HyperLogLog) {
|
||||
for i, v := range other.registers {
|
||||
if v > hll.registers[i] {
|
||||
hll.registers[i] = v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hll *HyperLogLog) Count() uint64 {
|
@ -1,4 +1,4 @@
|
||||
package nip45
|
||||
package hyperloglog
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
34
pool.go
34
pool.go
@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nbd-wtf/go-nostr/nip45/hyperloglog"
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
)
|
||||
|
||||
@ -468,6 +469,39 @@ func (pool *SimplePool) subManyEose(
|
||||
return events
|
||||
}
|
||||
|
||||
// CountMany aggregates count results from multiple relays using HyperLogLog
|
||||
func (pool *SimplePool) CountMany(
|
||||
ctx context.Context,
|
||||
urls []string,
|
||||
filter Filter,
|
||||
opts []SubscriptionOption,
|
||||
) int {
|
||||
hll := hyperloglog.New()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(len(urls))
|
||||
for _, url := range urls {
|
||||
go func(nm string) {
|
||||
defer wg.Done()
|
||||
relay, err := pool.EnsureRelay(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ce, err := relay.countInternal(ctx, Filters{filter}, opts...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(ce.HyperLogLog) != 256 {
|
||||
return
|
||||
}
|
||||
hll.MergeRegisters(ce.HyperLogLog)
|
||||
}(NormalizeURL(url))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return int(hll.Count())
|
||||
}
|
||||
|
||||
// QuerySingle returns the first event returned by the first relay, cancels everything else.
|
||||
func (pool *SimplePool) QuerySingle(ctx context.Context, urls []string, filter Filter) *RelayEvent {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
16
relay.go
16
relay.go
@ -273,7 +273,7 @@ func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error
|
||||
}
|
||||
case *CountEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
|
||||
subscription.countResult <- *env.Count
|
||||
subscription.countResult <- *env
|
||||
}
|
||||
case *OKEnvelope:
|
||||
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
||||
@ -478,11 +478,19 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error)
|
||||
}
|
||||
|
||||
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
|
||||
v, err := r.countInternal(ctx, filters, opts...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return *v.Count, nil
|
||||
}
|
||||
|
||||
func (r *Relay) countInternal(ctx context.Context, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
|
||||
sub := r.PrepareSubscription(ctx, filters, opts...)
|
||||
sub.countResult = make(chan int64)
|
||||
sub.countResult = make(chan CountEnvelope)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
return 0, err
|
||||
return CountEnvelope{}, err
|
||||
}
|
||||
|
||||
defer sub.Unsub()
|
||||
@ -499,7 +507,7 @@ func (r *Relay) Count(ctx context.Context, filters Filters, opts ...Subscription
|
||||
case count := <-sub.countResult:
|
||||
return count, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
return CountEnvelope{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ type Subscription struct {
|
||||
Filters Filters
|
||||
|
||||
// for this to be treated as a COUNT and not a REQ this must be set
|
||||
countResult chan int64
|
||||
countResult chan CountEnvelope
|
||||
|
||||
// the Events channel emits all EVENTs that come in a Subscription
|
||||
// will be closed when the subscription ends
|
||||
@ -152,7 +152,7 @@ func (sub *Subscription) Fire() error {
|
||||
if sub.countResult == nil {
|
||||
reqb, _ = ReqEnvelope{sub.id, sub.Filters}.MarshalJSON()
|
||||
} else {
|
||||
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil}.MarshalJSON()
|
||||
reqb, _ = CountEnvelope{sub.id, sub.Filters, nil, nil}.MarshalJSON()
|
||||
}
|
||||
|
||||
sub.live.Store(true)
|
||||
|
Loading…
x
Reference in New Issue
Block a user