diff --git a/go.mod b/go.mod index bb90aaf..5bbbd90 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.21.0 require ( github.com/fasthttp/websocket v1.5.3 github.com/fiatjaf/eventstore v0.1.0 - github.com/nbd-wtf/go-nostr v0.25.7 + github.com/nbd-wtf/go-nostr v0.26.0 github.com/puzpuzpuz/xsync/v2 v2.5.1 github.com/rs/cors v1.7.0 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 diff --git a/handlers.go b/handlers.go index d4150c3..c257f02 100644 --- a/handlers.go +++ b/handlers.go @@ -96,135 +96,93 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { go func(message []byte) { ctx = context.Background() - var request []json.RawMessage - if err := json.Unmarshal(message, &request); err != nil { + envelope := nostr.ParseMessage(message) + if envelope == nil { // stop silently return } - if len(request) < 2 { - ws.WriteJSON(nostr.NoticeEnvelope("request has less than 2 parameters")) - return - } - - var typ string - json.Unmarshal(request[0], &typ) - - switch typ { - case "EVENT": - // it's a new event - var evt nostr.Event - if err := json.Unmarshal(request[1], &evt); err != nil { - ws.WriteJSON(nostr.NoticeEnvelope("failed to decode event: " + err.Error())) - return - } - + switch env := envelope.(type) { + case *nostr.EventEnvelope: // check id - hash := sha256.Sum256(evt.Serialize()) + hash := sha256.Sum256(env.Event.Serialize()) id := hex.EncodeToString(hash[:]) - if id != evt.ID { - ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "invalid: id is computed incorrectly"}) + if id != env.Event.ID { + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "invalid: id is computed incorrectly"}) return } // check signature - if ok, err := evt.CheckSignature(); err != nil { - ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "error: failed to verify signature"}) + if ok, err := env.Event.CheckSignature(); err != nil { + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to verify signature"}) return } else if !ok { - ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "invalid: signature is invalid"}) + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "invalid: signature is invalid"}) return } var ok bool - if evt.Kind == 5 { - err = rl.handleDeleteRequest(ctx, &evt) + if env.Event.Kind == 5 { + err = rl.handleDeleteRequest(ctx, &env.Event) } else { - err = rl.AddEvent(ctx, &evt) + err = rl.AddEvent(ctx, &env.Event) } var reason string if err == nil { ok = true } else { - reason = err.Error() + reason = nostr.NormalizeOKMessage(err.Error(), "blocked") } - ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: ok, Reason: reason}) - case "COUNT": + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: ok, Reason: reason}) + case *nostr.CountEnvelope: if rl.CountEvents == nil { - ws.WriteJSON(nostr.NoticeEnvelope("this relay does not support NIP-45")) + ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: "unsupported: this relay does not support NIP-45"}) return } - - var id string - json.Unmarshal(request[1], &id) - if id == "" { - ws.WriteJSON(nostr.NoticeEnvelope("COUNT has no ")) - return - } - var total int64 - filters := make(nostr.Filters, len(request)-2) - for i, filterReq := range request[2:] { - if err := json.Unmarshal(filterReq, &filters[i]); err != nil { - ws.WriteJSON(nostr.NoticeEnvelope("failed to decode filter")) - continue - } - total += rl.handleCountRequest(ctx, ws, filters[i]) + for _, filter := range env.Filters { + total += rl.handleCountRequest(ctx, ws, filter) } - - ws.WriteJSON([]interface{}{"COUNT", id, map[string]int64{"count": total}}) - case "REQ": - var id string - json.Unmarshal(request[1], &id) - if id == "" { - ws.WriteJSON(nostr.NoticeEnvelope("REQ has no ")) - return - } - - filters := make(nostr.Filters, len(request)-2) + ws.WriteJSON(nostr.CountEnvelope{SubscriptionID: env.SubscriptionID, Count: &total}) + case *nostr.ReqEnvelope: eose := sync.WaitGroup{} - eose.Add(len(request[2:])) + eose.Add(len(env.Filters)) - for i, filterReq := range request[2:] { - if err := json.Unmarshal(filterReq, &filters[i]); err != nil { - ws.WriteJSON(nostr.NoticeEnvelope("failed to decode filter")) - eose.Done() - continue + isFullyRejected := true + var reason string + for _, filter := range env.Filters { + err := rl.handleRequest(ctx, env.SubscriptionID, &eose, ws, filter) + if err == nil { + isFullyRejected = false + } else { + reason = err.Error() } - - go rl.handleRequest(ctx, id, &eose, ws, filters[i]) + } + if isFullyRejected { + // this will be called only if all the filters were invalidated + reason = nostr.NormalizeOKMessage(reason, "blocked") + ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason}) + return } go func() { eose.Wait() - ws.WriteJSON(nostr.EOSEEnvelope(id)) + ws.WriteJSON(nostr.EOSEEnvelope(env.SubscriptionID)) }() - setListener(id, ws, filters) - case "CLOSE": - var id string - json.Unmarshal(request[1], &id) - if id == "" { - ws.WriteJSON(nostr.NoticeEnvelope("CLOSE has no ")) - return - } - - removeListenerId(ws, id) - case "AUTH": + setListener(env.SubscriptionID, ws, env.Filters) + case *nostr.CloseEnvelope: + removeListenerId(ws, string(*env)) + case *nostr.AuthEnvelope: if rl.ServiceURL != "" { - var evt nostr.Event - if err := json.Unmarshal(request[1], &evt); err != nil { - ws.WriteJSON(nostr.NoticeEnvelope("failed to decode auth event: " + err.Error())) - return - } - if pubkey, ok := nip42.ValidateAuthEvent(&evt, ws.Challenge, rl.ServiceURL); ok { + if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, rl.ServiceURL); ok { ws.Authed = pubkey close(ws.WaitingForAuth) ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey) - ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: true}) + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true}) } else { - ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "error: failed to authenticate"}) + ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"}) } } } diff --git a/serve-req.go b/serve-req.go index 00a44bc..ee63fbc 100644 --- a/serve-req.go +++ b/serve-req.go @@ -2,12 +2,13 @@ package khatru import ( "context" + "fmt" "sync" "github.com/nbd-wtf/go-nostr" ) -func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGroup, ws *WebSocket, filter nostr.Filter) { +func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGroup, ws *WebSocket, filter nostr.Filter) error { defer eose.Done() // overwrite the filter (for example, to eliminate some kinds or @@ -17,7 +18,7 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr } if filter.Limit < 0 { - return + return fmt.Errorf("filter invalidated") } // then check if we'll reject this filter (we apply this after overwriting @@ -27,7 +28,7 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr for _, reject := range rl.RejectFilter { if reject, msg := reject(ctx, filter); reject { ws.WriteJSON(nostr.NoticeEnvelope(msg)) - return + return fmt.Errorf(msg) } } @@ -52,6 +53,8 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr eose.Done() }(ch) } + + return nil } func (rl *Relay) handleCountRequest(ctx context.Context, ws *WebSocket, filter nostr.Filter) int64 {