diff --git a/envelopes.go b/envelopes.go index d90916c..5e3e8fa 100644 --- a/envelopes.go +++ b/envelopes.go @@ -1,6 +1,7 @@ package nostr import ( + "bytes" "encoding/json" "fmt" @@ -9,11 +10,58 @@ import ( "github.com/tidwall/gjson" ) +func ParseMessage(message []byte) Envelope { + firstComma := bytes.Index(message, []byte{','}) + if firstComma == -1 { + return nil + } + label := message[0:firstComma] + var v Envelope + switch { + case bytes.Contains(label, []byte("EVENT")): + v = &EventEnvelope{} + case bytes.Contains(label, []byte("REQ")): + v = &ReqEnvelope{} + case bytes.Contains(label, []byte("NOTICE")): + x := NoticeEnvelope("") + v = &x + case bytes.Contains(label, []byte("EOSE")): + x := EOSEEnvelope("") + v = &x + case bytes.Contains(label, []byte("OK")): + v = &OKEnvelope{} + case bytes.Contains(label, []byte("AUTH")): + v = &AuthEnvelope{} + } + + if err := v.UnmarshalJSON(message); err != nil { + return nil + } + return v +} + +type Envelope interface { + Label() string + UnmarshalJSON([]byte) error + MarshalJSON() ([]byte, error) +} + type EventEnvelope struct { SubscriptionID *string Event } +var ( + _ Envelope = (*EventEnvelope)(nil) + _ Envelope = (*ReqEnvelope)(nil) + _ Envelope = (*NoticeEnvelope)(nil) + _ Envelope = (*EOSEEnvelope)(nil) + _ Envelope = (*OKEnvelope)(nil) + _ Envelope = (*AuthEnvelope)(nil) +) + +func (_ EventEnvelope) Label() string { return "EVENT" } + func (v *EventEnvelope) UnmarshalJSON(data []byte) error { r := gjson.ParseBytes(data) arr := r.Array() @@ -44,6 +92,8 @@ type ReqEnvelope struct { Filters } +func (_ ReqEnvelope) Label() string { return "REQ" } + func (v *ReqEnvelope) UnmarshalJSON(data []byte) error { r := gjson.ParseBytes(data) arr := r.Array() @@ -77,6 +127,8 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) { type NoticeEnvelope string +func (_ NoticeEnvelope) Label() string { return "NOTICE" } + func (v *NoticeEnvelope) UnmarshalJSON(data []byte) error { r := gjson.ParseBytes(data) arr := r.Array() @@ -99,6 +151,8 @@ func (v NoticeEnvelope) MarshalJSON() ([]byte, error) { type EOSEEnvelope string +func (_ EOSEEnvelope) Label() string { return "EOSE" } + func (v *EOSEEnvelope) UnmarshalJSON(data []byte) error { r := gjson.ParseBytes(data) arr := r.Array() @@ -125,6 +179,8 @@ type OKEnvelope struct { Reason *string } +func (_ OKEnvelope) Label() string { return "OK" } + func (v *OKEnvelope) UnmarshalJSON(data []byte) error { r := gjson.ParseBytes(data) arr := r.Array() @@ -163,6 +219,8 @@ type AuthEnvelope struct { Event Event } +func (_ AuthEnvelope) Label() string { return "AUTH" } + func (v *AuthEnvelope) UnmarshalJSON(data []byte) error { r := gjson.ParseBytes(data) arr := r.Array() diff --git a/relay.go b/relay.go index a4af3d0..48d4aaf 100644 --- a/relay.go +++ b/relay.go @@ -2,7 +2,6 @@ package nostr import ( "context" - "encoding/json" "fmt" "net/http" "sync" @@ -135,66 +134,49 @@ func (r *Relay) Connect(ctx context.Context) error { break } - if len(message) == 0 || message[0] != '[' { + envelope := ParseMessage(message) + if envelope == nil { continue } - var jsonMessage []json.RawMessage - err = json.Unmarshal(message, &jsonMessage) - if err != nil { - continue - } - - if len(jsonMessage) < 2 { - continue - } - - var command string - json.Unmarshal(jsonMessage[0], &command) - - switch command { - case "NOTICE": - debugLog("{%s} %v\n", r.URL, jsonMessage) - var content string - json.Unmarshal(jsonMessage[1], &content) + switch env := envelope.(type) { + case *NoticeEnvelope: + debugLog("{%s} %v\n", r.URL, string(message)) + // TODO: improve this, otherwise if the application doesn't read the notices + // we'll consume ever more memory with each new notice go func() { r.mutex.RLock() if r.ConnectionContext.Err() == nil { - r.Notices <- content + r.Notices <- string(*env) } r.mutex.RUnlock() }() - case "AUTH": - debugLog("{%s} %v\n", r.URL, jsonMessage) - var challenge string - json.Unmarshal(jsonMessage[1], &challenge) - go func() { - r.mutex.RLock() - if r.ConnectionContext.Err() == nil { - r.Challenges <- challenge - } - r.mutex.RUnlock() - }() - case "EVENT": - debugLog("{%s} %v\n", r.URL, jsonMessage) - if len(jsonMessage) < 3 { + case *AuthEnvelope: + debugLog("{%s} %v\n", r.URL, string(message)) + if env.Challenge == nil { continue } - - var subId string - json.Unmarshal(jsonMessage[1], &subId) - if subscription, ok := r.subscriptions.Load(subId); !ok { - InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, subId) + // TODO: same as with NoticeEnvelope + go func() { + r.mutex.RLock() + if r.ConnectionContext.Err() == nil { + r.Challenges <- *env.Challenge + } + r.mutex.RUnlock() + }() + case *EventEnvelope: + debugLog("{%s} %v\n", r.URL, string(message)) + if env.SubscriptionID == nil { + continue + } + if subscription, ok := r.subscriptions.Load(*env.SubscriptionID); !ok { + InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID) continue } else { func() { - // decode event - var event Event - json.Unmarshal(jsonMessage[2], &event) - // check if the event matches the desired filter, ignore otherwise - if !subscription.Filters.Match(&event) { - InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters[0], event) + if !subscription.Filters.Match(&env.Event) { + InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters[0], env.Event) return } @@ -206,7 +188,7 @@ func (r *Relay) Connect(ctx context.Context) error { // check signature, ignore invalid, except from trusted (AssumeValid) relays if !r.AssumeValid { - if ok, err := event.CheckSignature(); !ok { + if ok, err := env.Event.CheckSignature(); !ok { errmsg := "" if err != nil { errmsg = err.Error() @@ -216,40 +198,19 @@ func (r *Relay) Connect(ctx context.Context) error { } } - subscription.Events <- &event + subscription.Events <- &env.Event }() } - case "EOSE": - if len(jsonMessage) < 2 { - continue - } - debugLog("{%s} %v\n", r.URL, jsonMessage) - var subId string - json.Unmarshal(jsonMessage[1], &subId) - if subscription, ok := r.subscriptions.Load(subId); ok { + case *EOSEEnvelope: + debugLog("{%s} %v\n", r.URL, string(message)) + if subscription, ok := r.subscriptions.Load(string(*env)); ok { subscription.emitEose.Do(func() { subscription.EndOfStoredEvents <- struct{}{} }) } - case "OK": - if len(jsonMessage) < 3 { - continue - } - debugLog("{%s} %v\n", r.URL, jsonMessage) - var ( - eventId string - ok bool - msg string - ) - json.Unmarshal(jsonMessage[1], &eventId) - json.Unmarshal(jsonMessage[2], &ok) - - if len(jsonMessage) > 3 { - json.Unmarshal(jsonMessage[3], &msg) - } - - if okCallback, exist := r.okCallbacks.Load(eventId); exist { - okCallback(ok, msg) + case *OKEnvelope: + if okCallback, exist := r.okCallbacks.Load(env.EventID); exist { + okCallback(env.OK, *env.Reason) } } }