ParseMessage() returns an Envelope, use that on the main relay handler loop.

This commit is contained in:
fiatjaf 2023-05-09 16:55:21 -03:00
parent f7ce78d7f8
commit d36fbb95b9
No known key found for this signature in database
GPG Key ID: BAD43C4BE5C1A3A1
2 changed files with 94 additions and 75 deletions

View File

@ -1,6 +1,7 @@
package nostr package nostr
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -9,11 +10,58 @@ import (
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
func ParseMessage(message []byte) Envelope {
firstComma := bytes.Index(message, []byte{','})
if firstComma == -1 {
return nil
}
label := message[0:firstComma]
var v Envelope
switch {
case bytes.Contains(label, []byte("EVENT")):
v = &EventEnvelope{}
case bytes.Contains(label, []byte("REQ")):
v = &ReqEnvelope{}
case bytes.Contains(label, []byte("NOTICE")):
x := NoticeEnvelope("")
v = &x
case bytes.Contains(label, []byte("EOSE")):
x := EOSEEnvelope("")
v = &x
case bytes.Contains(label, []byte("OK")):
v = &OKEnvelope{}
case bytes.Contains(label, []byte("AUTH")):
v = &AuthEnvelope{}
}
if err := v.UnmarshalJSON(message); err != nil {
return nil
}
return v
}
type Envelope interface {
Label() string
UnmarshalJSON([]byte) error
MarshalJSON() ([]byte, error)
}
type EventEnvelope struct { type EventEnvelope struct {
SubscriptionID *string SubscriptionID *string
Event Event
} }
var (
_ Envelope = (*EventEnvelope)(nil)
_ Envelope = (*ReqEnvelope)(nil)
_ Envelope = (*NoticeEnvelope)(nil)
_ Envelope = (*EOSEEnvelope)(nil)
_ Envelope = (*OKEnvelope)(nil)
_ Envelope = (*AuthEnvelope)(nil)
)
func (_ EventEnvelope) Label() string { return "EVENT" }
func (v *EventEnvelope) UnmarshalJSON(data []byte) error { func (v *EventEnvelope) UnmarshalJSON(data []byte) error {
r := gjson.ParseBytes(data) r := gjson.ParseBytes(data)
arr := r.Array() arr := r.Array()
@ -44,6 +92,8 @@ type ReqEnvelope struct {
Filters Filters
} }
func (_ ReqEnvelope) Label() string { return "REQ" }
func (v *ReqEnvelope) UnmarshalJSON(data []byte) error { func (v *ReqEnvelope) UnmarshalJSON(data []byte) error {
r := gjson.ParseBytes(data) r := gjson.ParseBytes(data)
arr := r.Array() arr := r.Array()
@ -77,6 +127,8 @@ func (v ReqEnvelope) MarshalJSON() ([]byte, error) {
type NoticeEnvelope string type NoticeEnvelope string
func (_ NoticeEnvelope) Label() string { return "NOTICE" }
func (v *NoticeEnvelope) UnmarshalJSON(data []byte) error { func (v *NoticeEnvelope) UnmarshalJSON(data []byte) error {
r := gjson.ParseBytes(data) r := gjson.ParseBytes(data)
arr := r.Array() arr := r.Array()
@ -99,6 +151,8 @@ func (v NoticeEnvelope) MarshalJSON() ([]byte, error) {
type EOSEEnvelope string type EOSEEnvelope string
func (_ EOSEEnvelope) Label() string { return "EOSE" }
func (v *EOSEEnvelope) UnmarshalJSON(data []byte) error { func (v *EOSEEnvelope) UnmarshalJSON(data []byte) error {
r := gjson.ParseBytes(data) r := gjson.ParseBytes(data)
arr := r.Array() arr := r.Array()
@ -125,6 +179,8 @@ type OKEnvelope struct {
Reason *string Reason *string
} }
func (_ OKEnvelope) Label() string { return "OK" }
func (v *OKEnvelope) UnmarshalJSON(data []byte) error { func (v *OKEnvelope) UnmarshalJSON(data []byte) error {
r := gjson.ParseBytes(data) r := gjson.ParseBytes(data)
arr := r.Array() arr := r.Array()
@ -163,6 +219,8 @@ type AuthEnvelope struct {
Event Event Event Event
} }
func (_ AuthEnvelope) Label() string { return "AUTH" }
func (v *AuthEnvelope) UnmarshalJSON(data []byte) error { func (v *AuthEnvelope) UnmarshalJSON(data []byte) error {
r := gjson.ParseBytes(data) r := gjson.ParseBytes(data)
arr := r.Array() arr := r.Array()

View File

@ -2,7 +2,6 @@ package nostr
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@ -135,66 +134,49 @@ func (r *Relay) Connect(ctx context.Context) error {
break break
} }
if len(message) == 0 || message[0] != '[' { envelope := ParseMessage(message)
if envelope == nil {
continue continue
} }
var jsonMessage []json.RawMessage switch env := envelope.(type) {
err = json.Unmarshal(message, &jsonMessage) case *NoticeEnvelope:
if err != nil { debugLog("{%s} %v\n", r.URL, string(message))
continue // TODO: improve this, otherwise if the application doesn't read the notices
} // we'll consume ever more memory with each new notice
if len(jsonMessage) < 2 {
continue
}
var command string
json.Unmarshal(jsonMessage[0], &command)
switch command {
case "NOTICE":
debugLog("{%s} %v\n", r.URL, jsonMessage)
var content string
json.Unmarshal(jsonMessage[1], &content)
go func() { go func() {
r.mutex.RLock() r.mutex.RLock()
if r.ConnectionContext.Err() == nil { if r.ConnectionContext.Err() == nil {
r.Notices <- content r.Notices <- string(*env)
} }
r.mutex.RUnlock() r.mutex.RUnlock()
}() }()
case "AUTH": case *AuthEnvelope:
debugLog("{%s} %v\n", r.URL, jsonMessage) debugLog("{%s} %v\n", r.URL, string(message))
var challenge string if env.Challenge == nil {
json.Unmarshal(jsonMessage[1], &challenge) continue
}
// TODO: same as with NoticeEnvelope
go func() { go func() {
r.mutex.RLock() r.mutex.RLock()
if r.ConnectionContext.Err() == nil { if r.ConnectionContext.Err() == nil {
r.Challenges <- challenge r.Challenges <- *env.Challenge
} }
r.mutex.RUnlock() r.mutex.RUnlock()
}() }()
case "EVENT": case *EventEnvelope:
debugLog("{%s} %v\n", r.URL, jsonMessage) debugLog("{%s} %v\n", r.URL, string(message))
if len(jsonMessage) < 3 { if env.SubscriptionID == nil {
continue continue
} }
if subscription, ok := r.subscriptions.Load(*env.SubscriptionID); !ok {
var subId string InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID)
json.Unmarshal(jsonMessage[1], &subId)
if subscription, ok := r.subscriptions.Load(subId); !ok {
InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, subId)
continue continue
} else { } else {
func() { func() {
// decode event
var event Event
json.Unmarshal(jsonMessage[2], &event)
// check if the event matches the desired filter, ignore otherwise // check if the event matches the desired filter, ignore otherwise
if !subscription.Filters.Match(&event) { if !subscription.Filters.Match(&env.Event) {
InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters[0], event) InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, subscription.Filters[0], env.Event)
return return
} }
@ -206,7 +188,7 @@ func (r *Relay) Connect(ctx context.Context) error {
// check signature, ignore invalid, except from trusted (AssumeValid) relays // check signature, ignore invalid, except from trusted (AssumeValid) relays
if !r.AssumeValid { if !r.AssumeValid {
if ok, err := event.CheckSignature(); !ok { if ok, err := env.Event.CheckSignature(); !ok {
errmsg := "" errmsg := ""
if err != nil { if err != nil {
errmsg = err.Error() errmsg = err.Error()
@ -216,40 +198,19 @@ func (r *Relay) Connect(ctx context.Context) error {
} }
} }
subscription.Events <- &event subscription.Events <- &env.Event
}() }()
} }
case "EOSE": case *EOSEEnvelope:
if len(jsonMessage) < 2 { debugLog("{%s} %v\n", r.URL, string(message))
continue if subscription, ok := r.subscriptions.Load(string(*env)); ok {
}
debugLog("{%s} %v\n", r.URL, jsonMessage)
var subId string
json.Unmarshal(jsonMessage[1], &subId)
if subscription, ok := r.subscriptions.Load(subId); ok {
subscription.emitEose.Do(func() { subscription.emitEose.Do(func() {
subscription.EndOfStoredEvents <- struct{}{} subscription.EndOfStoredEvents <- struct{}{}
}) })
} }
case "OK": case *OKEnvelope:
if len(jsonMessage) < 3 { if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
continue okCallback(env.OK, *env.Reason)
}
debugLog("{%s} %v\n", r.URL, jsonMessage)
var (
eventId string
ok bool
msg string
)
json.Unmarshal(jsonMessage[1], &eventId)
json.Unmarshal(jsonMessage[2], &ok)
if len(jsonMessage) > 3 {
json.Unmarshal(jsonMessage[3], &msg)
}
if okCallback, exist := r.okCallbacks.Load(eventId); exist {
okCallback(ok, msg)
} }
} }
} }