From 4fb6fcd9a21345b07fd6046645bafde8603a34b6 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Wed, 5 Mar 2025 23:42:16 -0300 Subject: [PATCH] make simdjson great again. now it is generally a little faster than the easyjson approach. goos: linux goarch: amd64 pkg: github.com/nbd-wtf/go-nostr cpu: AMD Ryzen 3 3200G with Radeon Vega Graphics BenchmarkParseMessage/stdlib-4 90 15616341 ns/op BenchmarkParseMessage/easyjson-4 110 11306466 ns/op BenchmarkParseMessage/simdjson-4 162 7779856 ns/op PASS ok github.com/nbd-wtf/go-nostr 5.547s --- envelopes.go | 226 ------------------------------------ envelopes_benchmark_test.go | 20 +++- envelopes_simdjson.go | 210 +++++++++++++++++++++++++++++++++ envelopes_test.go | 5 +- event_simdjson.go | 29 ++--- filter_simdjson.go | 25 ++-- 6 files changed, 258 insertions(+), 257 deletions(-) create mode 100644 envelopes_simdjson.go diff --git a/envelopes.go b/envelopes.go index 9c3a45c..41ad2e2 100644 --- a/envelopes.go +++ b/envelopes.go @@ -9,7 +9,6 @@ import ( "github.com/mailru/easyjson" jwriter "github.com/mailru/easyjson/jwriter" - "github.com/minio/simdjson-go" "github.com/tidwall/gjson" ) @@ -27,55 +26,6 @@ var ( UnknownLabel = errors.New("unknown envelope label") ) -// ParseMessageSIMD parses a message using the experimental simdjson-go library. -func ParseMessageSIMD(message []byte, reuse *simdjson.ParsedJson) (Envelope, error) { - parsed, err := simdjson.Parse(message, reuse) - if err != nil { - return nil, fmt.Errorf("simdjson parse failed: %w", err) - } - - iter := parsed.Iter() - iter.AdvanceInto() - if t := iter.Advance(); t != simdjson.TypeArray { - return nil, fmt.Errorf("top-level must be an array") - } - arr, _ := iter.Array(nil) - iter = arr.Iter() - iter.Advance() - label, _ := iter.StringBytes() - - var v EnvelopeSIMD - - switch { - case bytes.Equal(label, labelEvent): - v = &EventEnvelope{} - case bytes.Equal(label, labelReq): - v = &ReqEnvelope{} - case bytes.Equal(label, labelCount): - v = &CountEnvelope{} - case bytes.Equal(label, labelNotice): - x := NoticeEnvelope("") - v = &x - case bytes.Equal(label, labelEose): - x := EOSEEnvelope("") - v = &x - case bytes.Equal(label, labelOk): - v = &OKEnvelope{} - case bytes.Equal(label, labelAuth): - v = &AuthEnvelope{} - case bytes.Equal(label, labelClosed): - v = &ClosedEnvelope{} - case bytes.Equal(label, labelClose): - x := CloseEnvelope("") - v = &x - default: - return nil, UnknownLabel - } - - err = v.UnmarshalSIMD(iter) - return v, err -} - // ParseMessage parses a message into an Envelope. func ParseMessage(message []byte) Envelope { firstComma := bytes.Index(message, []byte{','}) @@ -125,12 +75,6 @@ type Envelope interface { String() string } -// EnvelopeSIMD extends Envelope with SIMD unmarshaling capability. -type EnvelopeSIMD interface { - Envelope - UnmarshalSIMD(simdjson.Iter) error -} - var ( _ Envelope = (*EventEnvelope)(nil) _ Envelope = (*ReqEnvelope)(nil) @@ -164,25 +108,6 @@ func (v *EventEnvelope) UnmarshalJSON(data []byte) error { } } -func (v *EventEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - // we may or may not have a subscription ID, so peek - if iter.PeekNext() == simdjson.TypeString { - iter.Advance() - // we have a subscription ID - subID, err := iter.String() - if err != nil { - return err - } - v.SubscriptionID = &subID - } - - // now get the event - if typ := iter.Advance(); typ == simdjson.TypeNone { - return fmt.Errorf("missing event") - } - return v.Event.UnmarshalSIMD(&iter) -} - func (v EventEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["EVENT",`) @@ -223,44 +148,6 @@ func (v *ReqEnvelope) UnmarshalJSON(data []byte) error { return nil } -func (v *ReqEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - var err error - - // we must have a subscription id - if typ := iter.Advance(); typ == simdjson.TypeString { - v.SubscriptionID, err = iter.String() - if err != nil { - return err - } - } else { - return fmt.Errorf("unexpected %s for REQ subscription id", typ) - } - - // now get the filters - v.Filters = make(Filters, 0, 1) - tempIter := &simdjson.Iter{} // make a new iterator here because there may come multiple filters - for { - if typ, err := iter.AdvanceIter(tempIter); err != nil { - return err - } else if typ == simdjson.TypeNone { - break - } else { - } - - var filter Filter - if err := filter.UnmarshalSIMD(tempIter); err != nil { - return err - } - v.Filters = append(v.Filters, filter) - } - - if len(v.Filters) == 0 { - return fmt.Errorf("need at least one filter") - } - - return nil -} - func (v ReqEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["REQ","`) @@ -326,53 +213,6 @@ func (v *CountEnvelope) UnmarshalJSON(data []byte) error { return nil } -func (v *CountEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - var err error - - // this has two cases: - // in the first case (request from client) this is like REQ except with always one filter - // in the other (response from relay) we have a json object response - // but both cases start with a subscription id - - if typ := iter.Advance(); typ == simdjson.TypeString { - v.SubscriptionID, err = iter.String() - if err != nil { - return err - } - } else { - return fmt.Errorf("unexpected %s for COUNT subscription id", typ) - } - - // now get either a single filter or stuff from the json object - if typ := iter.Advance(); typ == simdjson.TypeNone { - return fmt.Errorf("missing json object") - } - - if el, err := iter.FindElement(nil, "count"); err == nil { - c, _ := el.Iter.Uint() - count := int64(c) - v.Count = &count - if el, err = iter.FindElement(nil, "hll"); err == nil { - if hllHex, err := el.Iter.StringBytes(); err != nil || len(hllHex) != 512 { - return fmt.Errorf("hll is malformed") - } else { - v.HyperLogLog = make([]byte, 256) - if _, err := hex.Decode(v.HyperLogLog, hllHex); err != nil { - return fmt.Errorf("hll is invalid hex") - } - } - } - } else { - var filter Filter - if err := filter.UnmarshalSIMD(&iter); err != nil { - return err - } - v.Filters = Filters{filter} - } - - return nil -} - func (v CountEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["COUNT","`) @@ -418,14 +258,6 @@ func (v *NoticeEnvelope) UnmarshalJSON(data []byte) error { return nil } -func (v *NoticeEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - if typ := iter.Advance(); typ == simdjson.TypeString { - msg, _ := iter.String() - *v = NoticeEnvelope(msg) - } - return nil -} - func (v NoticeEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["NOTICE",`) @@ -453,14 +285,6 @@ func (v *EOSEEnvelope) UnmarshalJSON(data []byte) error { return nil } -func (v *EOSEEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - if typ := iter.Advance(); typ == simdjson.TypeString { - msg, _ := iter.String() - *v = EOSEEnvelope(msg) - } - return nil -} - func (v EOSEEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["EOSE",`) @@ -490,14 +314,6 @@ func (v *CloseEnvelope) UnmarshalJSON(data []byte) error { } } -func (v *CloseEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - if typ := iter.Advance(); typ == simdjson.TypeString { - msg, _ := iter.String() - *v = CloseEnvelope(msg) - } - return nil -} - func (v CloseEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["CLOSE",`) @@ -530,16 +346,6 @@ func (v *ClosedEnvelope) UnmarshalJSON(data []byte) error { } } -func (v *ClosedEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - if typ := iter.Advance(); typ == simdjson.TypeString { - v.SubscriptionID, _ = iter.String() - } - if typ := iter.Advance(); typ == simdjson.TypeString { - v.Reason, _ = iter.String() - } - return nil -} - func (v ClosedEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["CLOSED",`) @@ -576,23 +382,6 @@ func (v *OKEnvelope) UnmarshalJSON(data []byte) error { return nil } -func (v *OKEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - if typ := iter.Advance(); typ == simdjson.TypeString { - v.EventID, _ = iter.String() - } else { - return fmt.Errorf("unexpected %s for OK id", typ) - } - if typ := iter.Advance(); typ == simdjson.TypeBool { - v.OK, _ = iter.Bool() - } else { - return fmt.Errorf("unexpected %s for OK status", typ) - } - if typ := iter.Advance(); typ == simdjson.TypeString { - v.Reason, _ = iter.String() - } - return nil -} - func (v OKEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["OK","`) @@ -635,21 +424,6 @@ func (v *AuthEnvelope) UnmarshalJSON(data []byte) error { return nil } -func (v *AuthEnvelope) UnmarshalSIMD(iter simdjson.Iter) error { - if typ := iter.Advance(); typ == simdjson.TypeString { - // we have a challenge - subID, err := iter.String() - if err != nil { - return err - } - v.Challenge = &subID - return nil - } else { - // we have an event - return v.Event.UnmarshalSIMD(&iter) - } -} - func (v AuthEnvelope) MarshalJSON() ([]byte, error) { w := jwriter.Writer{NoEscapeHTML: true} w.RawString(`["AUTH",`) diff --git a/envelopes_benchmark_test.go b/envelopes_benchmark_test.go index 5666e0b..323454e 100644 --- a/envelopes_benchmark_test.go +++ b/envelopes_benchmark_test.go @@ -1,6 +1,7 @@ package nostr import ( + stdlibjson "encoding/json" "fmt" "math/rand/v2" "testing" @@ -12,7 +13,16 @@ import ( func BenchmarkParseMessage(b *testing.B) { messages := generateTestMessages(2000) - b.Run("golang", func(b *testing.B) { + b.Run("stdlib", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, msg := range messages { + var v any + stdlibjson.Unmarshal(msg, &v) + } + } + }) + + b.Run("easyjson", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, msg := range messages { _ = ParseMessage(msg) @@ -21,10 +31,10 @@ func BenchmarkParseMessage(b *testing.B) { }) b.Run("simdjson", func(b *testing.B) { - pj := &simdjson.ParsedJson{} + smp := SIMDMessageParser{ParsedJSON: &simdjson.ParsedJson{}, AuxIter: &simdjson.Iter{}} for i := 0; i < b.N; i++ { for _, msg := range messages { - _, _ = ParseMessageSIMD(msg, pj) + _, _ = smp.ParseMessage(msg) } } }) @@ -76,9 +86,9 @@ func generateRandomEvent() Event { tags := make(Tags, tagCount) for i := 0; i < tagCount; i++ { tagType := string([]byte{byte('a' + rand.IntN(26))}) - tagValues := make([]string, rand.IntN(5)+1) + tagValues := make([]string, rand.IntN(3)+1) for j := range tagValues { - tagValues[j] = fmt.Sprintf("value_%d_%d", i, j) + tagValues[j] = fmt.Sprintf("%d", j) } tags[i] = append([]string{tagType}, tagValues...) } diff --git a/envelopes_simdjson.go b/envelopes_simdjson.go new file mode 100644 index 0000000..72f8a2b --- /dev/null +++ b/envelopes_simdjson.go @@ -0,0 +1,210 @@ +package nostr + +import ( + "bytes" + "encoding/hex" + "fmt" + + "github.com/minio/simdjson-go" +) + +type SIMDMessageParser struct { + ParsedJSON *simdjson.ParsedJson + TopLevelArray *simdjson.Array // used for the top-level envelope + TargetObject *simdjson.Object // used for the event object itself, or for the count object, or the filter object + TargetInternalArray *simdjson.Array // used for tags array inside the event or each of the values in a filter + AuxArray *simdjson.Array // used either for each of the tags inside the event or for each of the multiple filters that may code + AuxIter *simdjson.Iter +} + +func (smp *SIMDMessageParser) ParseMessage(message []byte) (Envelope, error) { + var err error + + smp.ParsedJSON, err = simdjson.Parse(message, smp.ParsedJSON) + if err != nil { + return nil, fmt.Errorf("simdjson parse failed: %w", err) + } + + iter := smp.ParsedJSON.Iter() + iter.AdvanceInto() + if t := iter.Advance(); t != simdjson.TypeArray { + return nil, fmt.Errorf("top-level must be an array") + } + arr, _ := iter.Array(nil) + iter = arr.Iter() + iter.Advance() + label, _ := iter.StringBytes() + + switch { + case bytes.Equal(label, labelEvent): + v := &EventEnvelope{} + // we may or may not have a subscription ID, so peek + if iter.PeekNext() == simdjson.TypeString { + iter.Advance() + // we have a subscription ID + subID, err := iter.String() + if err != nil { + return nil, err + } + v.SubscriptionID = &subID + } + // now get the event + if typ := iter.Advance(); typ == simdjson.TypeNone { + return nil, fmt.Errorf("missing event") + } + + smp.TargetObject, smp.TargetInternalArray, smp.AuxArray, err = v.Event.UnmarshalSIMD( + &iter, smp.TargetObject, smp.TargetInternalArray, smp.AuxArray) + return v, err + case bytes.Equal(label, labelReq): + v := &ReqEnvelope{} + + // we must have a subscription id + if typ := iter.Advance(); typ == simdjson.TypeString { + v.SubscriptionID, err = iter.String() + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("unexpected %s for REQ subscription id", typ) + } + + // now get the filters + v.Filters = make(Filters, 0, 1) + for { + if typ, err := iter.AdvanceIter(smp.AuxIter); err != nil { + return nil, err + } else if typ == simdjson.TypeNone { + break + } else { + } + + var filter Filter + smp.TargetObject, smp.TargetInternalArray, err = filter.UnmarshalSIMD( + smp.AuxIter, smp.TargetObject, smp.TargetInternalArray) + if err != nil { + return nil, err + } + v.Filters = append(v.Filters, filter) + } + + if len(v.Filters) == 0 { + return nil, fmt.Errorf("need at least one filter") + } + + return v, nil + case bytes.Equal(label, labelCount): + v := &CountEnvelope{} + // this has two cases: + // in the first case (request from client) this is like REQ except with always one filter + // in the other (response from relay) we have a json object response + // but both cases start with a subscription id + + if typ := iter.Advance(); typ == simdjson.TypeString { + v.SubscriptionID, err = iter.String() + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("unexpected %s for COUNT subscription id", typ) + } + + // now get either a single filter or stuff from the json object + if typ := iter.Advance(); typ == simdjson.TypeNone { + return nil, fmt.Errorf("missing json object") + } + + if el, err := iter.FindElement(nil, "count"); err == nil { + c, _ := el.Iter.Uint() + count := int64(c) + v.Count = &count + if el, err = iter.FindElement(nil, "hll"); err == nil { + if hllHex, err := el.Iter.StringBytes(); err != nil || len(hllHex) != 512 { + return nil, fmt.Errorf("hll is malformed") + } else { + v.HyperLogLog = make([]byte, 256) + if _, err := hex.Decode(v.HyperLogLog, hllHex); err != nil { + return nil, fmt.Errorf("hll is invalid hex") + } + } + } + } else { + var filter Filter + smp.TargetObject, smp.TargetInternalArray, err = filter.UnmarshalSIMD( + &iter, smp.TargetObject, smp.TargetInternalArray) + if err != nil { + return nil, err + } + v.Filters = Filters{filter} + } + + return v, nil + case bytes.Equal(label, labelNotice): + x := NoticeEnvelope("") + v := &x + if typ := iter.Advance(); typ == simdjson.TypeString { + msg, _ := iter.String() + *v = NoticeEnvelope(msg) + } + return v, nil + case bytes.Equal(label, labelEose): + x := EOSEEnvelope("") + v := &x + if typ := iter.Advance(); typ == simdjson.TypeString { + msg, _ := iter.String() + *v = EOSEEnvelope(msg) + } + return v, nil + case bytes.Equal(label, labelOk): + v := &OKEnvelope{} + if typ := iter.Advance(); typ == simdjson.TypeString { + v.EventID, _ = iter.String() + } else { + return nil, fmt.Errorf("unexpected %s for OK id", typ) + } + if typ := iter.Advance(); typ == simdjson.TypeBool { + v.OK, _ = iter.Bool() + } else { + return nil, fmt.Errorf("unexpected %s for OK status", typ) + } + if typ := iter.Advance(); typ == simdjson.TypeString { + v.Reason, _ = iter.String() + } + return v, nil + case bytes.Equal(label, labelAuth): + v := &AuthEnvelope{} + if typ := iter.Advance(); typ == simdjson.TypeString { + // we have a challenge + subID, err := iter.String() + if err != nil { + return nil, err + } + v.Challenge = &subID + return v, nil + } else { + // we have an event + smp.TargetObject, smp.TargetInternalArray, smp.AuxArray, err = v.Event.UnmarshalSIMD( + &iter, smp.TargetObject, smp.TargetInternalArray, smp.AuxArray) + return v, err + } + case bytes.Equal(label, labelClosed): + v := &ClosedEnvelope{} + if typ := iter.Advance(); typ == simdjson.TypeString { + v.SubscriptionID, _ = iter.String() + } + if typ := iter.Advance(); typ == simdjson.TypeString { + v.Reason, _ = iter.String() + } + return v, nil + case bytes.Equal(label, labelClose): + x := CloseEnvelope("") + v := &x + if typ := iter.Advance(); typ == simdjson.TypeString { + msg, _ := iter.String() + *v = CloseEnvelope(msg) + } + return v, nil + default: + return nil, UnknownLabel + } +} diff --git a/envelopes_test.go b/envelopes_test.go index ca64dcd..91410e1 100644 --- a/envelopes_test.go +++ b/envelopes_test.go @@ -304,9 +304,10 @@ func TestParseMessageSIMD(t *testing.T) { } for _, testCase := range testCases { + smp := SIMDMessageParser{AuxIter: &simdjson.Iter{}} + t.Run(testCase.Name, func(t *testing.T) { - var pj *simdjson.ParsedJson - envelope, err := ParseMessageSIMD(testCase.Message, pj) + envelope, err := smp.ParseMessage(testCase.Message) if testCase.ExpectedErrorSubstring == "" { require.NoError(t, err) diff --git a/event_simdjson.go b/event_simdjson.go index 4fc3bf9..a3d7cd8 100644 --- a/event_simdjson.go +++ b/event_simdjson.go @@ -17,16 +17,21 @@ var ( attrSig = []byte("sig") ) -func (event *Event) UnmarshalSIMD(iter *simdjson.Iter) error { - obj, err := iter.Object(nil) +func (event *Event) UnmarshalSIMD( + iter *simdjson.Iter, + obj *simdjson.Object, + arr *simdjson.Array, + subArr *simdjson.Array, +) (*simdjson.Object, *simdjson.Array, *simdjson.Array, error) { + obj, err := iter.Object(obj) if err != nil { - return fmt.Errorf("unexpected at event: %w", err) + return obj, arr, subArr, fmt.Errorf("unexpected at event: %w", err) } for { name, t, err := obj.NextElementBytes(iter) if err != nil { - return err + return obj, arr, subArr, err } else if t == simdjson.TypeNone { break } @@ -49,36 +54,34 @@ func (event *Event) UnmarshalSIMD(iter *simdjson.Iter) error { kind, err = iter.Uint() event.Kind = int(kind) case bytes.Equal(name, attrTags): - var arr *simdjson.Array - arr, err = iter.Array(nil) + arr, err = iter.Array(arr) if err != nil { - return err + return obj, arr, subArr, err } event.Tags = make(Tags, 0, 10) titer := arr.Iter() - var subArr *simdjson.Array for { if t := titer.Advance(); t == simdjson.TypeNone { break } subArr, err = titer.Array(subArr) if err != nil { - return err + return obj, arr, subArr, err } tag, err := subArr.AsString() if err != nil { - return err + return obj, arr, subArr, err } event.Tags = append(event.Tags, tag) } default: - return fmt.Errorf("unexpected event field '%s'", name) + return obj, arr, subArr, fmt.Errorf("unexpected event field '%s'", name) } if err != nil { - return err + return obj, arr, subArr, err } } - return nil + return obj, arr, subArr, nil } diff --git a/filter_simdjson.go b/filter_simdjson.go index 69fabd4..700a1dc 100644 --- a/filter_simdjson.go +++ b/filter_simdjson.go @@ -17,17 +17,20 @@ var ( attrSearch = []byte("search") ) -func (filter *Filter) UnmarshalSIMD(iter *simdjson.Iter) error { - obj, err := iter.Object(nil) +func (filter *Filter) UnmarshalSIMD( + iter *simdjson.Iter, + obj *simdjson.Object, + arr *simdjson.Array, +) (*simdjson.Object, *simdjson.Array, error) { + obj, err := iter.Object(obj) if err != nil { - return fmt.Errorf("unexpected at filter: %w", err) + return obj, arr, fmt.Errorf("unexpected at filter: %w", err) } - var arr *simdjson.Array for { name, t, err := obj.NextElementBytes(iter) if err != nil { - return err + return obj, arr, err } else if t == simdjson.TypeNone { break } @@ -51,7 +54,7 @@ func (filter *Filter) UnmarshalSIMD(iter *simdjson.Iter) error { break } if kind, err := i.Uint(); err != nil { - return err + return obj, arr, err } else { filter.Kinds = append(filter.Kinds, int(kind)) } @@ -84,24 +87,24 @@ func (filter *Filter) UnmarshalSIMD(iter *simdjson.Iter) error { arr, err := iter.Array(arr) if err != nil { - return err + return obj, arr, err } vals, err := arr.AsString() if err != nil { - return err + return obj, arr, err } filter.Tags[string(name[1:])] = vals continue } - return fmt.Errorf("unexpected filter field '%s'", name) + return obj, arr, fmt.Errorf("unexpected filter field '%s'", name) } if err != nil { - return err + return obj, arr, err } } - return nil + return obj, arr, nil }