negentropy: use bytes.Reader instead of pointer to byte slice.

This commit is contained in:
fiatjaf 2024-07-21 17:26:11 -03:00
parent fd1fc8340f
commit 36bf0b4f14
3 changed files with 46 additions and 102 deletions

View File

@ -58,69 +58,52 @@ func (n *Negentropy) Initiate() ([]byte, error) {
return output, nil
}
func (n *Negentropy) Reconcile(query []byte) ([]byte, error) {
func (n *Negentropy) Reconcile(query []byte) (output []byte, haveIds []string, needIds []string, err error) {
if n.IsInitiator {
return []byte{}, fmt.Errorf("initiator not asking for have/need IDs")
return nil, nil, nil, fmt.Errorf("initiator not asking for have/need IDs")
}
var haveIds, needIds []string
output, err := n.ReconcileAux(query, &haveIds, &needIds)
reader := bytes.NewReader(query)
haveIds = make([]string, 0, 100)
needIds = make([]string, 0, 100)
output, err = n.ReconcileAux(reader, &haveIds, &needIds)
if err != nil {
return nil, err
return nil, nil, nil, err
}
if len(output) == 1 && n.IsInitiator {
return nil, nil
return nil, haveIds, needIds, nil
}
return output, nil
return output, haveIds, needIds, nil
}
// ReconcileWithIDs when IDs are expected to be returned.
func (n *Negentropy) ReconcileWithIDs(query []byte, haveIds, needIds *[]string) ([]byte, error) {
if !n.IsInitiator {
return nil, fmt.Errorf("non-initiator asking for have/need IDs")
}
output, err := n.ReconcileAux(query, haveIds, needIds)
if err != nil {
return nil, err
}
if len(output) == 1 {
// Assuming an empty string is a special case indicating a condition similar to std::nullopt
return nil, nil
}
return output, nil
}
func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]byte, error) {
func (n *Negentropy) ReconcileAux(reader *bytes.Reader, haveIds, needIds *[]string) ([]byte, error) {
n.lastTimestampIn, n.lastTimestampOut = 0, 0 // Reset for each message
var fullOutput []byte
fullOutput = append(fullOutput, protocolVersion)
protocolVersion, err := getByte(&query)
pv, err := reader.ReadByte()
if err != nil {
return nil, err
}
if protocolVersion < 0x60 || protocolVersion > 0x6F {
if pv < 0x60 || pv > 0x6F {
return nil, fmt.Errorf("invalid negentropy protocol version byte")
}
if protocolVersion != protocolVersion {
if pv != protocolVersion {
if n.IsInitiator {
return nil, fmt.Errorf("unsupported negentropy protocol version requested")
}
return fullOutput, nil
}
storageSize := n.storage.Size()
var prevBound Bound
prevIndex := 0
skip := false
// convert the loop to process the query until it's consumed
for len(query) > 0 {
for reader.Len() > 0 {
var o []byte
doSkip := func() {
@ -135,18 +118,18 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
}
}
currBound, err := n.DecodeBound(&query)
currBound, err := n.DecodeBound(reader)
if err != nil {
return nil, err
}
modeVal, err := decodeVarInt(&query)
modeVal, err := decodeVarInt(reader)
if err != nil {
return nil, err
}
mode := Mode(modeVal)
lower := prevIndex
upper, err := n.storage.FindLowerBound(prevIndex, storageSize, currBound)
upper, err := n.storage.FindLowerBound(prevIndex, n.storage.Size(), currBound)
if err != nil {
return nil, err
}
@ -156,7 +139,8 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
skip = true
case FingerprintMode:
theirFingerprint, err := getBytes(&query, FingerprintSize)
theirFingerprint := make([]byte, FingerprintSize)
_, err := reader.Read(theirFingerprint)
if err != nil {
return nil, err
}
@ -173,19 +157,20 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
}
case IdListMode:
numIds64, err := decodeVarInt(&query)
numIds64, err := decodeVarInt(reader)
if err != nil {
return nil, err
}
numIds := int(numIds64)
theirElems := make(map[string]struct{})
idb := make([]byte, n.idSize)
for i := 0; i < numIds; i++ {
e, err := getBytes(&query, n.idSize)
_, err := reader.Read(idb)
if err != nil {
return nil, err
}
theirElems[hex.EncodeToString(e)] = struct{}{}
theirElems[hex.EncodeToString(idb)] = struct{}{}
}
n.storage.Iterate(lower, upper, func(item Item, _ int) bool {
@ -249,7 +234,7 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
// Check if the frame size limit is exceeded
if n.ExceededFrameSizeLimit(len(fullOutput) + len(o)) {
// Frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range
remainingFingerprint, err := n.storage.Fingerprint(upper, storageSize)
remainingFingerprint, err := n.storage.Fingerprint(upper, n.storage.Size())
if err != nil {
panic(err)
}
@ -277,15 +262,14 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *[]byte) {
numElems := upper - lower
const Buckets = 16
const buckets = 16
if numElems < Buckets*2 {
if numElems < buckets*2 {
boundEncoded, err := n.encodeBound(upperBound)
if err != nil {
fmt.Fprintln(os.Stderr, err)
panic(err)
}
fmt.Println("upp", upperBound, boundEncoded)
*output = append(*output, boundEncoded...)
*output = append(*output, encodeVarInt(IdListMode)...)
*output = append(*output, encodeVarInt(numElems)...)
@ -296,11 +280,11 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *[]by
return true
})
} else {
itemsPerBucket := numElems / Buckets
bucketsWithExtra := numElems % Buckets
itemsPerBucket := numElems / buckets
bucketsWithExtra := numElems % buckets
curr := lower
for i := 0; i < Buckets; i++ {
for i := 0; i < buckets; i++ {
bucketSize := itemsPerBucket
if i < bucketsWithExtra {
bucketSize++
@ -350,8 +334,8 @@ func (n *Negentropy) ExceededFrameSizeLimit(size int) bool {
// Decoding
func (n *Negentropy) DecodeTimestampIn(encoded *[]byte) (nostr.Timestamp, error) {
t, err := decodeVarInt(encoded)
func (n *Negentropy) DecodeTimestampIn(reader *bytes.Reader) (nostr.Timestamp, error) {
t, err := decodeVarInt(reader)
if err != nil {
return 0, err
}
@ -371,19 +355,19 @@ func (n *Negentropy) DecodeTimestampIn(encoded *[]byte) (nostr.Timestamp, error)
return timestamp, nil
}
func (n *Negentropy) DecodeBound(encoded *[]byte) (Bound, error) {
timestamp, err := n.DecodeTimestampIn(encoded)
func (n *Negentropy) DecodeBound(reader *bytes.Reader) (Bound, error) {
timestamp, err := n.DecodeTimestampIn(reader)
if err != nil {
return Bound{}, err
}
length, err := decodeVarInt(encoded)
length, err := decodeVarInt(reader)
if err != nil {
return Bound{}, err
}
id, err := getBytes(encoded, length)
if err != nil {
id := make([]byte, length)
if _, err = reader.Read(id); err != nil {
return Bound{}, err
}

View File

@ -1,58 +1,18 @@
package negentropy
import (
"errors"
)
import "bytes"
var ErrParseEndsPrematurely = errors.New("parse ends prematurely")
func getByte(encoded *[]byte) (byte, error) {
if len(*encoded) < 1 {
return 0, ErrParseEndsPrematurely
}
b := (*encoded)[0]
*encoded = (*encoded)[1:]
return b, nil
}
func getBytes(encoded *[]byte, n int) ([]byte, error) {
// fmt.Fprintln(os.Stderr, "getBytes", len(*encoded), n)
if len(*encoded) < n {
return nil, errors.New("parse ends prematurely")
}
result := (*encoded)[:n]
*encoded = (*encoded)[n:]
return result, nil
}
func decodeVarInt(encoded *[]byte) (int, error) {
//var res uint64
//
//for i := 0; i < len(*encoded); i++ {
// byte := (*encoded)[i]
// res = (res << 7) | uint64(byte&0x7F)
// if (byte & 0x80) == 0 {
// fmt.Fprintln(os.Stderr, "decodeVarInt", encoded, i)
// *encoded = (*encoded)[i+1:] // Advance the slice to reflect consumed bytes
// return res, nil
// }
//}
//return 0, ErrParseEndsPrematurely
func decodeVarInt(reader *bytes.Reader) (int, error) {
var res int = 0
for {
if len(*encoded) == 0 {
return 0, errors.New("parse ends prematurely")
b, err := reader.ReadByte()
if err != nil {
return 0, err
}
// Remove the first byte from the slice and update the slice.
// This simulates JavaScript's shift operation on arrays.
byte := (*encoded)[0]
*encoded = (*encoded)[1:]
res = (res << 7) | (int(byte) & 127)
if (byte & 128) == 0 {
res = (res << 7) | (int(b) & 127)
if (b & 128) == 0 {
break
}
}

View File

@ -46,7 +46,7 @@ func TestSimple(t *testing.T) {
n2.Insert(events[i])
}
q, err = n2.Reconcile(q)
q, _, _, err = n2.Reconcile(q)
if err != nil {
t.Fatal(err)
return
@ -57,7 +57,7 @@ func TestSimple(t *testing.T) {
{
var have []string
var need []string
q, err = n1.ReconcileWithIDs(q, &have, &need)
q, have, need, err = n1.Reconcile(q)
if err != nil {
t.Fatal(err)
return