mirror of
https://github.com/nbd-wtf/go-nostr.git
synced 2025-05-03 15:20:13 +02:00
relay.Count()
This commit is contained in:
parent
6cee628149
commit
53b9dde6e0
26
count_test.go
Normal file
26
count_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package nostr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCount(t *testing.T) {
|
||||
const RELAY = "wss://relay.nostr.band"
|
||||
|
||||
rl := mustRelayConnect(RELAY)
|
||||
defer rl.Close()
|
||||
|
||||
count, err := rl.Count(context.Background(), Filters{
|
||||
{Kinds: []int{3}, Tags: TagMap{"p": []string{"3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d"}}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("count request failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if count <= 0 {
|
||||
t.Errorf("count result wrong: %v", count)
|
||||
return
|
||||
}
|
||||
}
|
29
envelopes.go
29
envelopes.go
@ -137,6 +137,7 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) {
|
||||
type CountEnvelope struct {
|
||||
SubscriptionID string
|
||||
Filters
|
||||
Count *int64
|
||||
}
|
||||
|
||||
func (_ CountEnvelope) Label() string { return "COUNT" }
|
||||
@ -148,12 +149,28 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error {
|
||||
return fmt.Errorf("failed to decode COUNT envelope: missing filters")
|
||||
}
|
||||
v.SubscriptionID = arr[1].Str
|
||||
|
||||
if len(arr) < 3 {
|
||||
return fmt.Errorf("COUNT array must have at least 3 items")
|
||||
}
|
||||
|
||||
var countResult struct {
|
||||
Count *int64 `json:"count"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(arr[2].Raw), &countResult); err == nil && countResult.Count != nil {
|
||||
v.Count = countResult.Count
|
||||
return nil
|
||||
}
|
||||
|
||||
v.Filters = make(Filters, len(arr)-2)
|
||||
f := 0
|
||||
for i := 2; i < len(arr); i++ {
|
||||
if err := easyjson.Unmarshal([]byte(arr[i].Raw), &v.Filters[f]); err != nil {
|
||||
item := []byte(arr[i].Raw)
|
||||
|
||||
if err := easyjson.Unmarshal(item, &v.Filters[f]); err != nil {
|
||||
return fmt.Errorf("%w -- on filter %d", err, f)
|
||||
}
|
||||
|
||||
f++
|
||||
}
|
||||
|
||||
@ -164,9 +181,13 @@ func (v CountEnvelope) MarshalJSON() ([]byte, error) {
|
||||
w := jwriter.Writer{}
|
||||
w.RawString(`["COUNT",`)
|
||||
w.RawString(`"` + v.SubscriptionID + `"`)
|
||||
for _, filter := range v.Filters {
|
||||
w.RawString(`,`)
|
||||
filter.MarshalEasyJSON(&w)
|
||||
if v.Count != nil {
|
||||
w.RawString(fmt.Sprintf(`{"count":%d}`, *v.Count))
|
||||
} else {
|
||||
for _, filter := range v.Filters {
|
||||
w.RawString(`,`)
|
||||
filter.MarshalEasyJSON(&w)
|
||||
}
|
||||
}
|
||||
w.RawString(`]`)
|
||||
return w.BuildBytes()
|
||||
|
33
relay.go
33
relay.go
@ -306,6 +306,10 @@ func (r *Relay) Connect(ctx context.Context) error {
|
||||
}()
|
||||
}
|
||||
}
|
||||
case *CountEnvelope:
|
||||
if subscription, ok := r.Subscriptions.Load(string(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
|
||||
subscription.countResult <- *env.Count
|
||||
}
|
||||
case *OKEnvelope:
|
||||
if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
|
||||
okCallback(env.OK, env.Reason)
|
||||
@ -512,7 +516,7 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...Subscripti
|
||||
defer sub.Unsub()
|
||||
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 3 seconds
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, 7*time.Second)
|
||||
defer cancel()
|
||||
@ -535,6 +539,33 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter, opts ...Subscripti
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) Count(ctx context.Context, filters Filters, opts ...SubscriptionOption) (int64, error) {
|
||||
sub := r.PrepareSubscription(ctx, filters, opts...)
|
||||
sub.countResult = make(chan int64)
|
||||
|
||||
if err := sub.Fire(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer sub.Unsub()
|
||||
|
||||
if _, ok := ctx.Deadline(); !ok {
|
||||
// if no timeout is set, force it to 7 seconds
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, 7*time.Second)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case count := <-sub.countResult:
|
||||
return count, nil
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Relay) Close() error {
|
||||
if r.connectionContextCancel == nil {
|
||||
return fmt.Errorf("relay not connected")
|
||||
|
@ -15,6 +15,9 @@ type Subscription struct {
|
||||
Relay *Relay
|
||||
Filters Filters
|
||||
|
||||
// for this to be treated as a COUNT and not a REQ this must be set
|
||||
countResult chan int64
|
||||
|
||||
// the Events channel emits all EVENTs that come in a Subscription
|
||||
// will be closed when the subscription ends
|
||||
Events chan *Event
|
||||
@ -123,7 +126,12 @@ func (sub *Subscription) Sub(ctx context.Context, filters Filters) {
|
||||
func (sub *Subscription) Fire() error {
|
||||
id := sub.GetID()
|
||||
|
||||
reqb, _ := ReqEnvelope{id, sub.Filters}.MarshalJSON()
|
||||
var reqb []byte
|
||||
if sub.countResult == nil {
|
||||
reqb, _ = ReqEnvelope{id, sub.Filters}.MarshalJSON()
|
||||
} else {
|
||||
reqb, _ = CountEnvelope{id, sub.Filters, nil}.MarshalJSON()
|
||||
}
|
||||
debugLogf("{%s} sending %v", sub.Relay.URL, reqb)
|
||||
|
||||
sub.live.Store(true)
|
||||
|
Loading…
x
Reference in New Issue
Block a user