diff --git a/count_test.go b/count_test.go new file mode 100644 index 0000000..97bc418 --- /dev/null +++ b/count_test.go @@ -0,0 +1,26 @@ +package nostr + +import ( + "context" + "testing" +) + +func TestCount(t *testing.T) { + const RELAY = "wss://relay.nostr.band" + + rl := mustRelayConnect(RELAY) + defer rl.Close() + + count, err := rl.Count(context.Background(), Filters{ + {Kinds: []int{3}, Tags: TagMap{"p": []string{"3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d"}}}, + }) + if err != nil { + t.Errorf("count request failed: %v", err) + return + } + + if count <= 0 { + t.Errorf("count result wrong: %v", count) + return + } +} diff --git a/envelopes.go b/envelopes.go index 5c78fdd..aa7675a 100644 --- a/envelopes.go +++ b/envelopes.go @@ -137,6 +137,7 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) { type CountEnvelope struct { SubscriptionID string Filters + Count *int64 } func (_ CountEnvelope) Label() string { return "COUNT" } @@ -148,12 +149,28 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error { return fmt.Errorf("failed to decode COUNT envelope: missing filters") } v.SubscriptionID = arr[1].Str + + if len(arr) < 3 { + return fmt.Errorf("COUNT array must have at least 3 items") + } + + var countResult struct { + Count *int64 `json:"count"` + } + if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil { + v.Count = countResult.Count + return nil + } + v.Filters = make(Filters, len(arr)-2) f := 0 for i := 2; i < len(arr); i++ { - if err := easyjson.Unmarshal([]byte(arr[i].Raw), &v.Filters[f]); err != nil { + item := []byte(arr[i].Raw) + + if err := easyjson.Unmarshal(item, &v.Filters[f]); err != nil { return fmt.Errorf("%w -- on filter %d", err, f) } + f++ } @@ -164,9 +181,13 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{} w.RawString(`["COUNT",`) w.RawString(`"` + v.SubscriptionID + `"`) - for _, filter := range v.Filters { - w.RawString(`,`) - filter.MarshalEasyJSON(&w) + if v.Count != nil { + w.RawString(fmt.Sprintf(`{"count":%d}`, *v.Count)) + } else { + for _, filter := range v.Filters { + w.RawString(`,`) + filter.MarshalEasyJSON(&w) + } } w.RawString(`]`) return w.BuildBytes() diff --git a/relay.go b/relay.go index ac7941f..68c18ac 100644 --- a/relay.go +++ b/relay.go @@ -306,6 +306,10 @@ func (r *Relay) Connect(ctx context.Context) error { }() } } + case *CountEnvelope: + if subscription, ok := r.Subscriptions.Load(string(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil { + subscription.countResult <- *env.Count + } case *OKEnvelope: if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { okCallback(env.OK, env.Reason) @@ -512,7 +516,7 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...Subscripti defer sub.Unsub() if _, ok := ctx.Deadline(); !ok { - // if no timeout is set, force it to 3 seconds + // if no timeout is set, force it to 7 seconds var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, 7*time.Second) defer cancel() @@ -535,6 +539,33 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...Subscripti } } +func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) { + sub := r.PrepareSubscription(ctx, filters, opts...) + sub.countResult = make(chan int64) + + if err := sub.Fire(); err != nil { + return 0, err + } + + defer sub.Unsub() + + if _, ok := ctx.Deadline(); !ok { + // if no timeout is set, force it to 7 seconds + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 7*time.Second) + defer cancel() + } + + for { + select { + case count := <-sub.countResult: + return count, nil + case <-ctx.Done(): + return 0, ctx.Err() + } + } +} + func (r *Relay) Close() error { if r.connectionContextCancel == nil { return fmt.Errorf("relay not connected") diff --git a/subscription.go b/subscription.go index 49a8008..9a3f413 100644 --- a/subscription.go +++ b/subscription.go @@ -15,6 +15,9 @@ type Subscription struct { Relay *Relay Filters Filters + // for this to be treated as a COUNT and not a REQ this must be set + countResult chan int64 + // the Events channel emits all EVENTs that come in a Subscription // will be closed when the subscription ends Events chan *Event @@ -123,7 +126,12 @@ func (sub *Subscription) Sub(ctx context.Context, filters Filters) { func (sub *Subscription) Fire() error { id := sub.GetID() - reqb, _ := ReqEnvelope{id, sub.Filters}.MarshalJSON() + var reqb []byte + if sub.countResult == nil { + reqb, _ = ReqEnvelope{id, sub.Filters}.MarshalJSON() + } else { + reqb, _ = CountEnvelope{id, sub.Filters, nil}.MarshalJSON() + } debugLogf("{%s} sending %v", sub.Relay.URL, reqb) sub.live.Store(true)