From 4a62a753e67413d168a32a18ceef6fbc43844dc0 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Sun, 1 Jan 2023 20:22:40 -0300 Subject: [PATCH] contexts everywhere. --- relay.go | 130 +++++++++++++++++++++++++----------------------- relay_test.go | 39 ++++----------- subscription.go | 13 +++-- 3 files changed, 89 insertions(+), 93 deletions(-) diff --git a/relay.go b/relay.go index 067d0cd..9fec740 100644 --- a/relay.go +++ b/relay.go @@ -43,19 +43,12 @@ type Relay struct { Notices chan string ConnectionError chan error - statusChans s.MapOf[string, chan Status] + okCallbacks s.MapOf[string, func(bool)] } -// RelayConnect forwards calls to RelayConnectContext with a background context. -func RelayConnect(url string) (*Relay, error) { - return RelayConnectContext(context.Background(), url) -} - -// RelayConnectContext creates a new relay client and connects to a canonical -// URL using Relay.ConnectContext, passing ctx as is. -func RelayConnectContext(ctx context.Context, url string) (*Relay, error) { +func RelayConnect(ctx context.Context, url string) (*Relay, error) { r := &Relay{URL: NormalizeURL(url)} - err := r.ConnectContext(ctx) + err := r.Connect(ctx) return r, err } @@ -63,16 +56,11 @@ func (r *Relay) String() string { return r.URL } -// Connect calls ConnectContext with a background context. -func (r *Relay) Connect() error { - return r.ConnectContext(context.Background()) -} - -// ConnectContext tries to establish a websocket connection to r.URL. +// Connect tries to establish a websocket connection to r.URL. // If the context expires before the connection is complete, an error is returned. // Once successfully connected, context expiration has no effect: call r.Close // to close the connection. -func (r *Relay) ConnectContext(ctx context.Context) error { +func (r *Relay) Connect(ctx context.Context) error { if r.URL == "" { return fmt.Errorf("invalid relay URL '%s'", r.URL) } @@ -176,12 +164,8 @@ func (r *Relay) ConnectContext(ctx context.Context) error { json.Unmarshal(jsonMessage[1], &eventId) json.Unmarshal(jsonMessage[2], &ok) - if statusChan, exist := r.statusChans.Load(eventId); exist { - if ok { - statusChan <- PublishStatusSucceeded - } else { - statusChan <- PublishStatusFailed - } + if okCallback, exist := r.okCallbacks.Load(eventId); exist { + okCallback(ok) } } } @@ -190,60 +174,84 @@ func (r *Relay) ConnectContext(ctx context.Context) error { return nil } -func (r *Relay) Publish(event Event) chan Status { - statusChan := make(chan Status, 4) +func (r *Relay) Publish(ctx context.Context, event Event) Status { + status := PublishStatusFailed - go func() { - // we keep track of this so the OK message can be used to close it - r.statusChans.Store(event.ID, statusChan) - defer r.statusChans.Delete(event.ID) + if _, ok := ctx.Deadline(); !ok { + // if no timeout is set, force it to 3 seconds + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) + defer cancel() + } - err := r.Connection.WriteJSON([]interface{}{"EVENT", event}) - if err != nil { - statusChan <- PublishStatusFailed - close(statusChan) - return + // make it cancellable so we can stop everything upon receiving an "OK" + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + + // listen for an OK callback + okCallback := func(ok bool) { + if ok { + status = PublishStatusSucceeded + } else { + status = PublishStatusFailed } - statusChan <- PublishStatusSent + cancel() + } + r.okCallbacks.Store(event.ID, okCallback) + defer r.okCallbacks.Delete(event.ID) - // TODO: there's no reason to sub if the relay supports nip-20 (command results). - // in fact, subscribing here with nip20-enabled relays makes it send PublishStatusSucceded - // twice: once here, and the other on "OK" command result. - sub := r.Subscribe(Filters{Filter{IDs: []string{event.ID}}}) - for { - select { - case receivedEvent := <-sub.Events: - if receivedEvent.ID == event.ID { - statusChan <- PublishStatusSucceeded - close(statusChan) - break - } else { - continue - } - case <-time.After(5 * time.Second): - close(statusChan) - break + // publish event + err := r.Connection.WriteJSON([]interface{}{"EVENT", event}) + if err != nil { + return status + } + + // update status (this will be returned later) + status = PublishStatusSent + + sub := r.Subscribe(ctx, Filters{Filter{IDs: []string{event.ID}}}) + for { + select { + case receivedEvent := <-sub.Events: + if receivedEvent.ID == event.ID { + // we got a success, so update our status and proceed to return + status = PublishStatusSucceeded + return status } - break + case <-ctx.Done(): + // return status as it was + // will proceed to return status as it is + // e.g. if this happens because of the timeout then status will probably be "failed" + // but if it happens because okCallback was called then it might be "succeeded" + return status } - }() - - return statusChan + } } -func (r *Relay) Subscribe(filters Filters) *Subscription { +func (r *Relay) Subscribe(ctx context.Context, filters Filters) *Subscription { if r.Connection == nil { panic(fmt.Errorf("must call .Connect() first before calling .Subscribe()")) } sub := r.PrepareSubscription() sub.Filters = filters - sub.Fire() + sub.Fire(ctx) + return sub } -func (r *Relay) QuerySync(filter Filter, timeout time.Duration) []Event { - sub := r.Subscribe(Filters{filter}) +func (r *Relay) QuerySync(ctx context.Context, filter Filter) []Event { + sub := r.Subscribe(ctx, Filters{filter}) + defer sub.Unsub() + + if _, ok := ctx.Deadline(); !ok { + // if no timeout is set, force it to 3 seconds + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 3*time.Second) + defer cancel() + } + var events []Event for { select { @@ -251,7 +259,7 @@ func (r *Relay) QuerySync(filter Filter, timeout time.Duration) []Event { events = append(events, evt) case <-sub.EndOfStoredEvents: return events - case <-time.After(timeout): + case <-ctx.Done(): return events } } diff --git a/relay_test.go b/relay_test.go index 0801733..31168d3 100644 --- a/relay_test.go +++ b/relay_test.go @@ -55,11 +55,11 @@ func TestPublish(t *testing.T) { // connect a client and send the text note rl := mustRelayConnect(ws.URL) - want := map[Status]bool{ - PublishStatusSent: true, - PublishStatusSucceeded: true, + status := rl.Publish(context.Background(), textNote) + if status != PublishStatusSucceeded { + t.Errorf("published status is %d, not %d", status, PublishStatusSucceeded) } - testPublishStatus(t, rl.Publish(textNote), want) + if !published { t.Errorf("fake relay server saw no event") } @@ -85,11 +85,10 @@ func TestPublishBlocked(t *testing.T) { // connect a client and send a text note rl := mustRelayConnect(ws.URL) - want := map[Status]bool{ - PublishStatusSent: true, - PublishStatusFailed: true, + status := rl.Publish(context.Background(), textNote) + if status != PublishStatusFailed { + t.Errorf("published status is %d, not %d", status, PublishStatusSucceeded) } - testPublishStatus(t, rl.Publish(textNote), want) } func TestConnectContext(t *testing.T) { @@ -107,7 +106,7 @@ func TestConnectContext(t *testing.T) { // relay client ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - r, err := RelayConnectContext(ctx, ws.URL) + r, err := RelayConnect(ctx, ws.URL) if err != nil { t.Fatalf("RelayConnectContext: %v", err) } @@ -130,7 +129,7 @@ func TestConnectContextCanceled(t *testing.T) { // relay client ctx, cancel := context.WithCancel(context.Background()) cancel() // make ctx expired - _, err := RelayConnectContext(ctx, ws.URL) + _, err := RelayConnect(ctx, ws.URL) if !errors.Is(err, context.Canceled) { t.Errorf("RelayConnectContext returned %v error; want context.Canceled", err) } @@ -163,7 +162,7 @@ func makeKeyPair(t *testing.T) (priv, pub string) { func mustRelayConnect(url string) *Relay { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - rl, err := RelayConnectContext(ctx, url) + rl, err := RelayConnect(ctx, url) if err != nil { panic(err.Error()) } @@ -211,21 +210,3 @@ func parseSubscriptionMessage(t *testing.T, raw []json.RawMessage) (subid string } return id, ff } - -func testPublishStatus(t *testing.T, ch <-chan Status, want map[Status]bool) { - for stat := range ch { - if !want[stat] { - t.Errorf("client reported %q status", stat) - } - delete(want, stat) - // stop early to speed up tests - if len(want) == 0 { - break - } - } - for stat, missed := range want { - if missed { - t.Errorf("client didn't report %q", stat) - } - } -} diff --git a/subscription.go b/subscription.go index 59c22e6..8ae66e5 100644 --- a/subscription.go +++ b/subscription.go @@ -1,6 +1,7 @@ package nostr import ( + "context" "sync" ) @@ -34,16 +35,22 @@ func (sub *Subscription) Unsub() { sub.stopped = true } -func (sub *Subscription) Sub(filters Filters) { +func (sub *Subscription) Sub(ctx context.Context, filters Filters) { sub.Filters = filters - sub.Fire() + sub.Fire(ctx) } -func (sub *Subscription) Fire() { +func (sub *Subscription) Fire(ctx context.Context) { message := []interface{}{"REQ", sub.id} for _, filter := range sub.Filters { message = append(message, filter) } sub.conn.WriteJSON(message) + + // the subscription ends once the context is canceled + go func() { + <-ctx.Done() + sub.Unsub() + }() }