subscription labels: GetID() and SetLabel().

This commit is contained in:
fiatjaf
2023-03-18 08:39:31 -03:00
parent 661e299981
commit 3f66c60b5f
3 changed files with 38 additions and 26 deletions

View File

@@ -41,11 +41,11 @@ type Relay struct {
RequestHeader http.Header // e.g. for origin header RequestHeader http.Header // e.g. for origin header
Connection *Connection Connection *Connection
subscriptions s.MapOf[int, *Subscription] subscriptions s.MapOf[string, *Subscription]
Challenges chan string // NIP-42 Challenges Challenges chan string // NIP-42 Challenges
Notices chan string Notices chan string
ConnectionError chan error ConnectionError error
ConnectionContext context.Context // will be canceled when the connection closes ConnectionContext context.Context // will be canceled when the connection closes
okCallbacks s.MapOf[string, func(bool, string)] okCallbacks s.MapOf[string, func(bool, string)]
@@ -73,8 +73,8 @@ func (r *Relay) String() string {
// Once successfully connected, context expiration has no effect: call r.Close // Once successfully connected, context expiration has no effect: call r.Close
// to close the connection. // to close the connection.
func (r *Relay) Connect(ctx context.Context) error { func (r *Relay) Connect(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx) connectionContext, cancel := context.WithCancel(context.Background())
r.ConnectionContext = ctx r.ConnectionContext = connectionContext
if r.URL == "" { if r.URL == "" {
cancel() cancel()
@@ -96,7 +96,6 @@ func (r *Relay) Connect(ctx context.Context) error {
r.Challenges = make(chan string) r.Challenges = make(chan string)
r.Notices = make(chan string) r.Notices = make(chan string)
r.ConnectionError = make(chan error)
conn := NewConnection(socket) conn := NewConnection(socket)
r.Connection = conn r.Connection = conn
@@ -105,7 +104,7 @@ func (r *Relay) Connect(ctx context.Context) error {
for { for {
typ, message, err := conn.socket.ReadMessage() typ, message, err := conn.socket.ReadMessage()
if err != nil { if err != nil {
r.ConnectionError <- err r.ConnectionError = err
break break
} }
@@ -128,10 +127,10 @@ func (r *Relay) Connect(ctx context.Context) error {
continue continue
} }
var label string var command string
json.Unmarshal(jsonMessage[0], &label) json.Unmarshal(jsonMessage[0], &command)
switch label { switch command {
case "NOTICE": case "NOTICE":
var content string var content string
json.Unmarshal(jsonMessage[1], &content) json.Unmarshal(jsonMessage[1], &content)
@@ -149,9 +148,9 @@ func (r *Relay) Connect(ctx context.Context) error {
continue continue
} }
var channel int var subId string
json.Unmarshal(jsonMessage[1], &channel) json.Unmarshal(jsonMessage[1], &subId)
if subscription, ok := r.subscriptions.Load(channel); ok { if subscription, ok := r.subscriptions.Load(subId); ok {
var event Event var event Event
json.Unmarshal(jsonMessage[2], &event) json.Unmarshal(jsonMessage[2], &event)
@@ -162,7 +161,7 @@ func (r *Relay) Connect(ctx context.Context) error {
if err != nil { if err != nil {
errmsg = err.Error() errmsg = err.Error()
} }
log.Printf("bad signature: %s", errmsg) log.Printf("bad signature: %s\n", errmsg)
continue continue
} }
} }
@@ -181,9 +180,9 @@ func (r *Relay) Connect(ctx context.Context) error {
if len(jsonMessage) < 2 { if len(jsonMessage) < 2 {
continue continue
} }
var channel int var subId string
json.Unmarshal(jsonMessage[1], &channel) json.Unmarshal(jsonMessage[1], &subId)
if subscription, ok := r.subscriptions.Load(channel); ok { if subscription, ok := r.subscriptions.Load(subId); ok {
subscription.emitEose.Do(func() { subscription.emitEose.Do(func() {
subscription.EndOfStoredEvents <- struct{}{} subscription.EndOfStoredEvents <- struct{}{}
}) })
@@ -364,7 +363,7 @@ func (r *Relay) QuerySync(ctx context.Context, filter Filter) []*Event {
if _, ok := ctx.Deadline(); !ok { if _, ok := ctx.Deadline(); !ok {
// if no timeout is set, force it to 3 seconds // if no timeout is set, force it to 3 seconds
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 3*time.Second) ctx, cancel = context.WithTimeout(ctx, 7*time.Second)
defer cancel() defer cancel()
} }
@@ -397,7 +396,7 @@ func (r *Relay) PrepareSubscription() *Subscription {
EndOfStoredEvents: make(chan struct{}, 1), EndOfStoredEvents: make(chan struct{}, 1),
} }
r.subscriptions.Store(sub.id, sub) r.subscriptions.Store(sub.GetID(), sub)
return sub return sub
} }

View File

@@ -55,7 +55,7 @@ func TestPublish(t *testing.T) {
// connect a client and send the text note // connect a client and send the text note
rl := mustRelayConnect(ws.URL) rl := mustRelayConnect(ws.URL)
status := rl.Publish(context.Background(), textNote) status, _ := rl.Publish(context.Background(), textNote)
if status != PublishStatusSucceeded { if status != PublishStatusSucceeded {
t.Errorf("published status is %d, not %d", status, PublishStatusSucceeded) t.Errorf("published status is %d, not %d", status, PublishStatusSucceeded)
} }
@@ -85,7 +85,7 @@ func TestPublishBlocked(t *testing.T) {
// connect a client and send a text note // connect a client and send a text note
rl := mustRelayConnect(ws.URL) rl := mustRelayConnect(ws.URL)
status := rl.Publish(context.Background(), textNote) status, _ := rl.Publish(context.Background(), textNote)
if status != PublishStatusFailed { if status != PublishStatusFailed {
t.Errorf("published status is %d, not %d", status, PublishStatusSucceeded) t.Errorf("published status is %d, not %d", status, PublishStatusSucceeded)
} }

View File

@@ -7,6 +7,7 @@ import (
) )
type Subscription struct { type Subscription struct {
label string
id int id int
conn *Connection conn *Connection
mutex sync.Mutex mutex sync.Mutex
@@ -26,9 +27,15 @@ type EventMessage struct {
Relay string Relay string
} }
// GetID return the Nostr subscription ID as given to the relay, it will be a sequential number, stringified // SetLabel puts a label on the subscription that is prepended to the id that is sent to relays,
// it's only useful for debugging and sanity purposes.
func (sub *Subscription) SetLabel(label string) {
sub.label = label
}
// GetID return the Nostr subscription ID as given to the relay, it will be a sequential number, stringified.
func (sub *Subscription) GetID() string { func (sub *Subscription) GetID() string {
return strconv.Itoa(sub.id) return sub.label + ":" + strconv.Itoa(sub.id)
} }
// Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01. // Unsub closes the subscription, sending "CLOSE" to relay as in NIP-01.
@@ -37,7 +44,7 @@ func (sub *Subscription) Unsub() {
sub.mutex.Lock() sub.mutex.Lock()
defer sub.mutex.Unlock() defer sub.mutex.Unlock()
sub.conn.WriteJSON([]interface{}{"CLOSE", strconv.Itoa(sub.id)}) sub.conn.WriteJSON([]interface{}{"CLOSE", sub.GetID()})
if sub.stopped == false && sub.Events != nil { if sub.stopped == false && sub.Events != nil {
close(sub.Events) close(sub.Events)
} }
@@ -52,16 +59,20 @@ func (sub *Subscription) Sub(ctx context.Context, filters Filters) {
// Fire sends the "REQ" command to the relay. // Fire sends the "REQ" command to the relay.
// When ctx is cancelled, sub.Unsub() is called, closing the subscription. // When ctx is cancelled, sub.Unsub() is called, closing the subscription.
func (sub *Subscription) Fire(ctx context.Context) { func (sub *Subscription) Fire(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
sub.Context = ctx sub.Context = ctx
message := []interface{}{"REQ", strconv.Itoa(sub.id)} message := []interface{}{"REQ", sub.GetID()}
for _, filter := range sub.Filters { for _, filter := range sub.Filters {
message = append(message, filter) message = append(message, filter)
} }
sub.conn.WriteJSON(message) err := sub.conn.WriteJSON(message)
if err != nil {
cancel()
return err
}
// the subscription ends once the context is canceled // the subscription ends once the context is canceled
go func() { go func() {
@@ -80,4 +91,6 @@ func (sub *Subscription) Fire(ctx context.Context) {
// we also cancel the context // we also cancel the context
cancel() cancel()
}() }()
return nil
} }