diff --git a/negentropy/negentropy.go b/negentropy/negentropy.go index 52b15d8..6bbdd8e 100644 --- a/negentropy/negentropy.go +++ b/negentropy/negentropy.go @@ -58,69 +58,52 @@ func (n *Negentropy) Initiate() ([]byte, error) { return output, nil } -func (n *Negentropy) Reconcile(query []byte) ([]byte, error) { +func (n *Negentropy) Reconcile(query []byte) (output []byte, haveIds []string, needIds []string, err error) { if n.IsInitiator { - return []byte{}, fmt.Errorf("initiator not asking for have/need IDs") + return nil, nil, nil, fmt.Errorf("initiator not asking for have/need IDs") } - var haveIds, needIds []string - output, err := n.ReconcileAux(query, &haveIds, &needIds) + reader := bytes.NewReader(query) + haveIds = make([]string, 0, 100) + needIds = make([]string, 0, 100) + + output, err = n.ReconcileAux(reader, &haveIds, &needIds) if err != nil { - return nil, err + return nil, nil, nil, err } if len(output) == 1 && n.IsInitiator { - return nil, nil + return nil, haveIds, needIds, nil } - return output, nil + return output, haveIds, needIds, nil } -// ReconcileWithIDs when IDs are expected to be returned. -func (n *Negentropy) ReconcileWithIDs(query []byte, haveIds, needIds *[]string) ([]byte, error) { - if !n.IsInitiator { - return nil, fmt.Errorf("non-initiator asking for have/need IDs") - } - - output, err := n.ReconcileAux(query, haveIds, needIds) - if err != nil { - return nil, err - } - if len(output) == 1 { - // Assuming an empty string is a special case indicating a condition similar to std::nullopt - return nil, nil - } - - return output, nil -} - -func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]byte, error) { +func (n *Negentropy) ReconcileAux(reader *bytes.Reader, haveIds, needIds *[]string) ([]byte, error) { n.lastTimestampIn, n.lastTimestampOut = 0, 0 // Reset for each message var fullOutput []byte fullOutput = append(fullOutput, protocolVersion) - protocolVersion, err := getByte(&query) + pv, err := reader.ReadByte() if err != nil { return nil, err } - if protocolVersion < 0x60 || protocolVersion > 0x6F { + if pv < 0x60 || pv > 0x6F { return nil, fmt.Errorf("invalid negentropy protocol version byte") } - if protocolVersion != protocolVersion { + if pv != protocolVersion { if n.IsInitiator { return nil, fmt.Errorf("unsupported negentropy protocol version requested") } return fullOutput, nil } - storageSize := n.storage.Size() var prevBound Bound prevIndex := 0 skip := false - // convert the loop to process the query until it's consumed - for len(query) > 0 { + for reader.Len() > 0 { var o []byte doSkip := func() { @@ -135,18 +118,18 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b } } - currBound, err := n.DecodeBound(&query) + currBound, err := n.DecodeBound(reader) if err != nil { return nil, err } - modeVal, err := decodeVarInt(&query) + modeVal, err := decodeVarInt(reader) if err != nil { return nil, err } mode := Mode(modeVal) lower := prevIndex - upper, err := n.storage.FindLowerBound(prevIndex, storageSize, currBound) + upper, err := n.storage.FindLowerBound(prevIndex, n.storage.Size(), currBound) if err != nil { return nil, err } @@ -156,7 +139,8 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b skip = true case FingerprintMode: - theirFingerprint, err := getBytes(&query, FingerprintSize) + theirFingerprint := make([]byte, FingerprintSize) + _, err := reader.Read(theirFingerprint) if err != nil { return nil, err } @@ -173,19 +157,20 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b } case IdListMode: - numIds64, err := decodeVarInt(&query) + numIds64, err := decodeVarInt(reader) if err != nil { return nil, err } numIds := int(numIds64) theirElems := make(map[string]struct{}) + idb := make([]byte, n.idSize) for i := 0; i < numIds; i++ { - e, err := getBytes(&query, n.idSize) + _, err := reader.Read(idb) if err != nil { return nil, err } - theirElems[hex.EncodeToString(e)] = struct{}{} + theirElems[hex.EncodeToString(idb)] = struct{}{} } n.storage.Iterate(lower, upper, func(item Item, _ int) bool { @@ -249,7 +234,7 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b // Check if the frame size limit is exceeded if n.ExceededFrameSizeLimit(len(fullOutput) + len(o)) { // Frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range - remainingFingerprint, err := n.storage.Fingerprint(upper, storageSize) + remainingFingerprint, err := n.storage.Fingerprint(upper, n.storage.Size()) if err != nil { panic(err) } @@ -277,15 +262,14 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *[]byte) { numElems := upper - lower - const Buckets = 16 + const buckets = 16 - if numElems < Buckets*2 { + if numElems < buckets*2 { boundEncoded, err := n.encodeBound(upperBound) if err != nil { fmt.Fprintln(os.Stderr, err) panic(err) } - fmt.Println("upp", upperBound, boundEncoded) *output = append(*output, boundEncoded...) *output = append(*output, encodeVarInt(IdListMode)...) *output = append(*output, encodeVarInt(numElems)...) @@ -296,11 +280,11 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *[]by return true }) } else { - itemsPerBucket := numElems / Buckets - bucketsWithExtra := numElems % Buckets + itemsPerBucket := numElems / buckets + bucketsWithExtra := numElems % buckets curr := lower - for i := 0; i < Buckets; i++ { + for i := 0; i < buckets; i++ { bucketSize := itemsPerBucket if i < bucketsWithExtra { bucketSize++ @@ -350,8 +334,8 @@ func (n *Negentropy) ExceededFrameSizeLimit(size int) bool { // Decoding -func (n *Negentropy) DecodeTimestampIn(encoded *[]byte) (nostr.Timestamp, error) { - t, err := decodeVarInt(encoded) +func (n *Negentropy) DecodeTimestampIn(reader *bytes.Reader) (nostr.Timestamp, error) { + t, err := decodeVarInt(reader) if err != nil { return 0, err } @@ -371,19 +355,19 @@ func (n *Negentropy) DecodeTimestampIn(encoded *[]byte) (nostr.Timestamp, error) return timestamp, nil } -func (n *Negentropy) DecodeBound(encoded *[]byte) (Bound, error) { - timestamp, err := n.DecodeTimestampIn(encoded) +func (n *Negentropy) DecodeBound(reader *bytes.Reader) (Bound, error) { + timestamp, err := n.DecodeTimestampIn(reader) if err != nil { return Bound{}, err } - length, err := decodeVarInt(encoded) + length, err := decodeVarInt(reader) if err != nil { return Bound{}, err } - id, err := getBytes(encoded, length) - if err != nil { + id := make([]byte, length) + if _, err = reader.Read(id); err != nil { return Bound{}, err } diff --git a/negentropy/utils.go b/negentropy/utils.go index ae7e87a..ce99e5c 100644 --- a/negentropy/utils.go +++ b/negentropy/utils.go @@ -1,58 +1,18 @@ package negentropy -import ( - "errors" -) +import "bytes" -var ErrParseEndsPrematurely = errors.New("parse ends prematurely") - -func getByte(encoded *[]byte) (byte, error) { - if len(*encoded) < 1 { - return 0, ErrParseEndsPrematurely - } - b := (*encoded)[0] - *encoded = (*encoded)[1:] - - return b, nil -} - -func getBytes(encoded *[]byte, n int) ([]byte, error) { - // fmt.Fprintln(os.Stderr, "getBytes", len(*encoded), n) - if len(*encoded) < n { - return nil, errors.New("parse ends prematurely") - } - result := (*encoded)[:n] - *encoded = (*encoded)[n:] - return result, nil -} - -func decodeVarInt(encoded *[]byte) (int, error) { - //var res uint64 - // - //for i := 0; i < len(*encoded); i++ { - // byte := (*encoded)[i] - // res = (res << 7) | uint64(byte&0x7F) - // if (byte & 0x80) == 0 { - // fmt.Fprintln(os.Stderr, "decodeVarInt", encoded, i) - // *encoded = (*encoded)[i+1:] // Advance the slice to reflect consumed bytes - // return res, nil - // } - //} - //return 0, ErrParseEndsPrematurely +func decodeVarInt(reader *bytes.Reader) (int, error) { var res int = 0 for { - if len(*encoded) == 0 { - return 0, errors.New("parse ends prematurely") + b, err := reader.ReadByte() + if err != nil { + return 0, err } - // Remove the first byte from the slice and update the slice. - // This simulates JavaScript's shift operation on arrays. - byte := (*encoded)[0] - *encoded = (*encoded)[1:] - - res = (res << 7) | (int(byte) & 127) - if (byte & 128) == 0 { + res = (res << 7) | (int(b) & 127) + if (b & 128) == 0 { break } } diff --git a/negentropy/whatever_test.go b/negentropy/whatever_test.go index e6723c5..d2f0567 100644 --- a/negentropy/whatever_test.go +++ b/negentropy/whatever_test.go @@ -46,7 +46,7 @@ func TestSimple(t *testing.T) { n2.Insert(events[i]) } - q, err = n2.Reconcile(q) + q, _, _, err = n2.Reconcile(q) if err != nil { t.Fatal(err) return @@ -57,7 +57,7 @@ func TestSimple(t *testing.T) { { var have []string var need []string - q, err = n1.ReconcileWithIDs(q, &have, &need) + q, have, need, err = n1.Reconcile(q) if err != nil { t.Fatal(err) return