diff --git a/event.go b/event.go index fee07c0..7d63f39 100644 --- a/event.go +++ b/event.go @@ -7,6 +7,8 @@ import ( "encoding/hex" "encoding/json" "fmt" + "strconv" + "time" "github.com/fiatjaf/bip340" ) @@ -24,7 +26,7 @@ type Event struct { ID string `json:"id"` // it's the hash of the serialized event PubKey string `json:"pubkey"` - CreatedAt uint32 `json:"created_at"` + CreatedAt Time `json:"created_at"` Kind int `json:"kind"` @@ -33,6 +35,24 @@ type Event struct { Sig string `json:"sig"` } +type Time time.Time + +func (tm *Time) UnmarshalJSON(payload []byte) error { + var unix int64 + err := json.Unmarshal(payload, &unix) + if err != nil { + return fmt.Errorf("time must be a unix timestamp as an integer, not '%s': %w", + string(payload), err) + } + t := Time(time.Unix(unix, 0)) + tm = &t + return nil +} + +func (t Time) MarshalJSON() ([]byte, error) { + return []byte(strconv.FormatInt(time.Time(t).Unix(), 10)), nil +} + // Serialize outputs a byte array that can be hashed/signed to identify/authenticate func (evt *Event) Serialize() []byte { // the serialization process is just putting everything into a JSON array @@ -46,7 +66,7 @@ func (evt *Event) Serialize() []byte { arr[1] = evt.PubKey // created_at - arr[2] = int64(evt.CreatedAt) + arr[2] = int64(time.Time(evt.CreatedAt).Unix()) // kind arr[3] = int64(evt.Kind) diff --git a/filter.go b/filter.go index 72b7acf..2507373 100644 --- a/filter.go +++ b/filter.go @@ -1,18 +1,21 @@ package nostr -type EventFilters []EventFilter +import ( + "time" +) -type EventFilter struct { - IDs StringList `json:"ids,omitempty"` - Kinds IntList `json:"kinds,omitempty"` - Authors StringList `json:"authors,omitempty"` - Since uint32 `json:"since,omitempty"` - Until uint32 `json:"until,omitempty"` - TagE StringList `json:"#e,omitempty"` - TagP StringList `json:"#p,omitempty"` +type Filters []Filter + +type Filter struct { + IDs StringList + Kinds IntList + Authors StringList + Since *time.Time + Until *time.Time + Tags map[string]StringList } -func (eff EventFilters) Match(event *Event) bool { +func (eff Filters) Match(event *Event) bool { for _, filter := range eff { if filter.Matches(event) { return true @@ -21,7 +24,7 @@ func (eff EventFilters) Match(event *Event) bool { return false } -func (ef EventFilter) Matches(event *Event) bool { +func (ef Filter) Matches(event *Event) bool { if event == nil { return false } @@ -38,26 +41,24 @@ func (ef EventFilter) Matches(event *Event) bool { return false } - if ef.TagE != nil && !event.Tags.ContainsAny("e", ef.TagE) { + for f, v := range ef.Tags { + if v != nil && !event.Tags.ContainsAny(f, v) { + return false + } + } + + if ef.Since != nil && time.Time(event.CreatedAt).Before(*ef.Since) { return false } - if ef.TagP != nil && !event.Tags.ContainsAny("p", ef.TagP) { - return false - } - - if ef.Since != 0 && event.CreatedAt < ef.Since { - return false - } - - if ef.Until != 0 && event.CreatedAt >= ef.Until { + if ef.Until != nil && time.Time(event.CreatedAt).After(*ef.Until) { return false } return true } -func FilterEqual(a EventFilter, b EventFilter) bool { +func FilterEqual(a Filter, b Filter) bool { if !a.Kinds.Equals(b.Kinds) { return false } @@ -70,12 +71,18 @@ func FilterEqual(a EventFilter, b EventFilter) bool { return false } - if !a.TagE.Equals(b.TagE) { + if len(a.Tags) != len(b.Tags) { return false } - if !a.TagP.Equals(b.TagP) { - return false + for f, av := range a.Tags { + if bv, ok := b.Tags[f]; !ok { + return false + } else { + if !av.Equals(bv) { + return false + } + } } if a.Since != b.Since { diff --git a/filter_aux.go b/filter_aux.go index b18b824..78cbe2c 100644 --- a/filter_aux.go +++ b/filter_aux.go @@ -1,5 +1,13 @@ package nostr +import ( + "fmt" + "strings" + "time" + + "github.com/valyala/fastjson" +) + type StringList []string type IntList []int @@ -62,3 +70,146 @@ func (haystack IntList) Contains(needle int) bool { } return false } + +func (f *Filter) UnmarshalJSON(payload []byte) error { + var fastjsonParser fastjson.Parser + parsed, err := fastjsonParser.ParseBytes(payload) + if err != nil { + return fmt.Errorf("failed to parse filter: %w", err) + } + + obj, err := parsed.Object() + if err != nil { + return fmt.Errorf("filter is not an object") + } + + f.Tags = make(map[string]StringList) + + var visiterr error + obj.Visit(func(k []byte, v *fastjson.Value) { + key := string(k) + switch key { + case "ids": + f.IDs, err = fastjsonArrayToStringList(v) + if err != nil { + visiterr = fmt.Errorf("invalid 'ids' field: %w", err) + } + case "kinds": + f.Kinds, err = fastjsonArrayToIntList(v) + if err != nil { + visiterr = fmt.Errorf("invalid 'kinds' field: %w", err) + } + case "authors": + f.Authors, err = fastjsonArrayToStringList(v) + if err != nil { + visiterr = fmt.Errorf("invalid 'authors' field: %w", err) + } + case "since": + val, err := v.Int64() + if err != nil { + visiterr = fmt.Errorf("invalid 'since' field: %w", err) + } + tm := time.Unix(val, 0) + f.Since = &tm + case "until": + val, err := v.Int64() + if err != nil { + visiterr = fmt.Errorf("invalid 'until' field: %w", err) + } + tm := time.Unix(val, 0) + f.Until = &tm + default: + if strings.HasPrefix(key, "#") { + f.Tags[key[1:]], err = fastjsonArrayToStringList(v) + if err != nil { + visiterr = fmt.Errorf("invalid '%s' field: %w", key, err) + } + } + } + }) + if visiterr != nil { + return visiterr + } + + return nil +} + +func (f Filter) MarshalJSON() ([]byte, error) { + var arena fastjson.Arena + + o := arena.NewObject() + + if f.IDs != nil { + o.Set("ids", stringListToFastjsonArray(&arena, f.IDs)) + } + if f.Kinds != nil { + o.Set("kinds", intListToFastjsonArray(&arena, f.Kinds)) + } + if f.Authors != nil { + o.Set("authors", stringListToFastjsonArray(&arena, f.Authors)) + } + if f.Since != nil { + o.Set("since", arena.NewNumberInt(int(f.Since.Unix()))) + } + if f.Until != nil { + o.Set("until", arena.NewNumberInt(int(f.Until.Unix()))) + } + if f.Tags != nil { + for k, v := range f.Tags { + o.Set("#"+k, stringListToFastjsonArray(&arena, v)) + } + } + + return o.MarshalTo(nil), nil +} + +func stringListToFastjsonArray(arena *fastjson.Arena, sl StringList) *fastjson.Value { + arr := arena.NewArray() + for i, v := range sl { + arr.SetArrayItem(i, arena.NewString(v)) + } + return arr +} + +func intListToFastjsonArray(arena *fastjson.Arena, il IntList) *fastjson.Value { + arr := arena.NewArray() + for i, v := range il { + arr.SetArrayItem(i, arena.NewNumberInt(v)) + } + return arr +} + +func fastjsonArrayToStringList(v *fastjson.Value) (StringList, error) { + arr, err := v.Array() + if err != nil { + return nil, err + } + + sl := make(StringList, len(arr)) + for i, v := range arr { + sb, err := v.StringBytes() + if err != nil { + return nil, err + } + sl[i] = string(sb) + } + + return sl, nil +} + +func fastjsonArrayToIntList(v *fastjson.Value) (IntList, error) { + arr, err := v.Array() + if err != nil { + return nil, err + } + + il := make(IntList, len(arr)) + for i, v := range arr { + il[i], err = v.Int() + if err != nil { + return nil, err + } + } + + return il, nil +} diff --git a/go.mod b/go.mod index b28cc64..cabf8b9 100644 --- a/go.mod +++ b/go.mod @@ -8,4 +8,5 @@ require ( github.com/gorilla/websocket v1.4.2 github.com/tyler-smith/go-bip32 v1.0.0 github.com/tyler-smith/go-bip39 v1.1.0 + github.com/valyala/fastjson v1.6.3 // indirect ) diff --git a/go.sum b/go.sum index 23c92af..d4552c7 100644 --- a/go.sum +++ b/go.sum @@ -47,6 +47,8 @@ github.com/tyler-smith/go-bip32 v1.0.0 h1:sDR9juArbUgX+bO/iblgZnMPeWY1KZMUC2AFUJ github.com/tyler-smith/go-bip32 v1.0.0/go.mod h1:onot+eHknzV4BVPwrzqY5OoVpyCvnwD7lMawL5aQupE= github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8= github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3CWg+kkNaLt55U= +github.com/valyala/fastjson v1.6.3 h1:tAKFnnwmeMGPbwJ7IwxcTPCNr3uIzoIj3/Fh90ra4xc= +github.com/valyala/fastjson v1.6.3/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= golang.org/x/crypto v0.0.0-20170613210332-850760c427c5/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20170930174604-9419663f5a44/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190103213133-ff983b9c42bc/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= diff --git a/relaypool.go b/relaypool.go index 550a7e9..3138aae 100644 --- a/relaypool.go +++ b/relaypool.go @@ -35,7 +35,7 @@ type RelayPool struct { } type RelayPoolPolicy interface { - ShouldRead(EventFilters) bool + ShouldRead(Filters) bool ShouldWrite(*Event) bool } @@ -44,7 +44,7 @@ type SimplePolicy struct { Write bool } -func (s SimplePolicy) ShouldRead(_ EventFilters) bool { +func (s SimplePolicy) ShouldRead(_ Filters) bool { return s.Read } @@ -181,7 +181,7 @@ func (r *RelayPool) Remove(url string) { delete(r.websockets, nm) } -func (r *RelayPool) Sub(filters EventFilters) *Subscription { +func (r *RelayPool) Sub(filters Filters) *Subscription { random := make([]byte, 7) rand.Read(random) @@ -237,7 +237,7 @@ func (r *RelayPool) PublishEvent(evt *Event) (*Event, chan PublishStatus, error) } status <- PublishStatus{relay, PublishStatusSent} - subscription := r.Sub(EventFilters{EventFilter{IDs: []string{evt.ID}}}) + subscription := r.Sub(Filters{Filter{IDs: []string{evt.ID}}}) for { select { case event := <-subscription.UniqueEvents: diff --git a/subscription.go b/subscription.go index 880bc73..4c41b99 100644 --- a/subscription.go +++ b/subscription.go @@ -4,7 +4,7 @@ type Subscription struct { channel string relays map[string]*Connection - filters EventFilters + filters Filters Events chan EventMessage started bool