diff --git a/handlers.go b/handlers.go index 81dc11a..e100516 100644 --- a/handlers.go +++ b/handlers.go @@ -168,12 +168,12 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { filter := filters[i] - for _, reject := range rl.RejectFilter { - if rejecting, msg := reject(ctx, filter); rejecting { - ws.WriteJSON(nostr.NoticeEnvelope(msg)) - continue - } + // overwrite the filter (for example, to eliminate some kinds or tags that we know we don't support) + for _, ovw := range rl.OverwriteCountFilter { + ovw(ctx, &filter) } + + // then check if we'll reject this filter for _, reject := range rl.RejectCountFilter { if rejecting, msg := reject(ctx, filter); rejecting { ws.WriteJSON(nostr.NoticeEnvelope(msg)) @@ -181,6 +181,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { } } + // run the functions to count (generally it will be just one) for _, count := range rl.CountEvents { res, err := count(ctx, filter) if err != nil { @@ -215,14 +216,25 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { filter := filters[i] - for _, reject := range rl.RejectCountFilter { + // overwrite the filter (for example, to eliminate some kinds or + // that we know we don't support) + for _, ovw := range rl.OverwriteFilter { + ovw(ctx, &filter) + } + + // then check if we'll reject this filter (we apply this after overwriting + // because we may, for example, remove some things from the incoming filters + // that we know we don't support, and then if the end result is an empty + // filter we can just reject it) + for _, reject := range rl.RejectFilter { if rejecting, msg := reject(ctx, filter); rejecting { ws.WriteJSON(nostr.NoticeEnvelope(msg)) - eose.Done() continue } } + // run the functions to query events (generally just one, + // but we might be fetching stuff from multiple places) eose.Add(len(rl.QueryEvents)) for _, query := range rl.QueryEvents { ch, err := query(ctx, filter) diff --git a/plugins/filters.go b/plugins/filters.go index c74a4b6..676c022 100644 --- a/plugins/filters.go +++ b/plugins/filters.go @@ -2,26 +2,11 @@ package plugins import ( "context" - "fmt" "github.com/nbd-wtf/go-nostr" + "golang.org/x/exp/slices" ) -func NoPrefixFilters(ctx context.Context, filter nostr.Filter) (reject bool, msg string) { - for _, id := range filter.IDs { - if len(id) != 64 { - return true, fmt.Sprintf("filters can only contain full ids") - } - } - for _, pk := range filter.Authors { - if len(pk) != 64 { - return true, fmt.Sprintf("filters can only contain full pubkeys") - } - } - - return false, "" -} - func NoComplexFilters(ctx context.Context, filter nostr.Filter) (reject bool, msg string) { items := len(filter.Tags) + len(filter.Kinds) @@ -31,3 +16,49 @@ func NoComplexFilters(ctx context.Context, filter nostr.Filter) (reject bool, ms return false, "" } + +func NoEmptyFilters(ctx context.Context, filter nostr.Filter) (reject bool, msg string) { + c := len(filter.Kinds) + len(filter.IDs) + len(filter.Authors) + for _, tagItems := range filter.Tags { + c += len(tagItems) + } + if c == 0 { + return true, "can't handle empty filters" + } + return false, "" +} + +func NoSearchQueries(ctx context.Context, filter nostr.Filter) (reject bool, msg string) { + if filter.Search != "" { + return true, "search is not supported" + } + return false, "" +} + +func RemoveSearchQueries(ctx context.Context, filter *nostr.Filter) { + filter.Search = "" +} + +func RemoveKinds(kinds ...int) func(context.Context, *nostr.Filter) { + return func(ctx context.Context, filter *nostr.Filter) { + if n := len(filter.Kinds); n > 0 { + newKinds := make([]int, 0, n) + for i := 0; i < n; i++ { + if k := filter.Kinds[i]; !slices.Contains(kinds, k) { + newKinds = append(newKinds, k) + } + } + filter.Kinds = newKinds + } + } +} + +func RemoveTags(tagNames ...string) func(context.Context, *nostr.Filter) { + return func(ctx context.Context, filter *nostr.Filter) { + for tagName := range filter.Tags { + if slices.Contains(tagNames, tagName) { + delete(filter.Tags, tagName) + } + } + } +} diff --git a/relay.go b/relay.go index e8be0dc..76e44f2 100644 --- a/relay.go +++ b/relay.go @@ -46,6 +46,8 @@ type Relay struct { RejectCountFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string) OverwriteDeletionOutcome []func(ctx context.Context, target *nostr.Event, deletion *nostr.Event) (acceptDeletion bool, msg string) OverwriteResponseEvent []func(ctx context.Context, event *nostr.Event) + OverwriteFilter []func(ctx context.Context, filter *nostr.Filter) + OverwriteCountFilter []func(ctx context.Context, filter *nostr.Filter) StoreEvent []func(ctx context.Context, event *nostr.Event) error DeleteEvent []func(ctx context.Context, event *nostr.Event) error QueryEvents []func(ctx context.Context, filter nostr.Filter) (chan *nostr.Event, error)