relay.Count()

This commit is contained in:
fiatjaf 2023-07-18 16:17:00 -03:00
parent 6cee628149
commit 53b9dde6e0
No known key found for this signature in database
GPG Key ID: BAD43C4BE5C1A3A1
4 changed files with 92 additions and 6 deletions

26
count_test.go Normal file
View File

@ -0,0 +1,26 @@
package nostr
import (
"context"
"testing"
)
func TestCount(t *testing.T) {
const RELAY = "wss://relay.nostr.band"
rl := mustRelayConnect(RELAY)
defer rl.Close()
count, err := rl.Count(context.Background(), Filters{
{Kinds: []int{3}, Tags: TagMap{"p": []string{"3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d"}}},
})
if err != nil {
t.Errorf("count request failed: %v", err)
return
}
if count <= 0 {
t.Errorf("count result wrong: %v", count)
return
}
}

View File

@ -137,6 +137,7 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) {
type CountEnvelope struct {
SubscriptionID string
Filters
Count *int64
}
func (_ CountEnvelope) Label() string { return "COUNT" }
@ -148,12 +149,28 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {
return fmt.Errorf("failed to decode COUNT envelope: missing filters")
}
v.SubscriptionID = arr[1].Str
if len(arr) < 3 {
return fmt.Errorf("COUNT array must have at least 3 items")
}
var countResult struct {
Count *int64 `json:"count"`
}
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
v.Count = countResult.Count
return nil
}
v.Filters = make(Filters, len(arr)-2)
f := 0
for i := 2; i < len(arr); i++ {
if err := easyjson.Unmarshal([]byte(arr[i].Raw), &v.Filters[f]); err != nil {
item := []byte(arr[i].Raw)
if err := easyjson.Unmarshal(item, &v.Filters[f]); err != nil {
return fmt.Errorf("%w -- on filter %d", err, f)
}
f++
}
@ -164,10 +181,14 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
w := jwriter.Writer{}
w.RawString(`["COUNT",`)
w.RawString(`"` + v.SubscriptionID + `"`)
if v.Count != nil {
w.RawString(fmt.Sprintf(`{"count":%d}`, *v.Count))
} else {
for _, filter := range v.Filters {
w.RawString(`,`)
filter.MarshalEasyJSON(&w)
}
}
w.RawString(`]`)
return w.BuildBytes()
}

View File

@ -306,6 +306,10 @@ func (r *Relay) Connect(ctx context.Context) error {
}()
}
}
case *CountEnvelope:
if subscription, ok := r.Subscriptions.Load(string(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
subscription.countResult <- *env.Count
}
case *OKEnvelope:
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
okCallback(env.OK, env.Reason)
@ -512,7 +516,7 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...Subscripti
defer sub.Unsub()
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 3 seconds
// if no timeout is set, force it to 7 seconds
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 7*time.Second)
defer cancel()
@ -535,6 +539,33 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...Subscripti
}
}
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
sub := r.PrepareSubscription(ctx, filters, opts...)
sub.countResult = make(chan int64)
if err := sub.Fire(); err != nil {
return 0, err
}
defer sub.Unsub()
if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 7 seconds
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 7*time.Second)
defer cancel()
}
for {
select {
case count := <-sub.countResult:
return count, nil
case <-ctx.Done():
return 0, ctx.Err()
}
}
}
func (r *Relay) Close() error {
if r.connectionContextCancel == nil {
return fmt.Errorf("relay not connected")

View File

@ -15,6 +15,9 @@ type Subscription struct {
Relay *Relay
Filters Filters
// for this to be treated as a COUNT and not a REQ this must be set
countResult chan int64
// the Events channel emits all EVENTs that come in a Subscription
// will be closed when the subscription ends
Events chan *Event
@ -123,7 +126,12 @@ func (sub *Subscription) Sub(ctx context.Context, filters Filters) {
func (sub *Subscription) Fire() error {
id := sub.GetID()
reqb, _ := ReqEnvelope{id, sub.Filters}.MarshalJSON()
var reqb []byte
if sub.countResult == nil {
reqb, _ = ReqEnvelope{id, sub.Filters}.MarshalJSON()
} else {
reqb, _ = CountEnvelope{id, sub.Filters, nil}.MarshalJSON()
}
debugLogf("{%s} sending %v", sub.Relay.URL, reqb)
sub.live.Store(true)