mirror of
https://github.com/nbd-wtf/go-nostr.git
synced 2025-06-27 01:02:44 +02:00
negentropy: use bytes.Reader instead of pointer to byte slice.
This commit is contained in:
parent
fd1fc8340f
commit
36bf0b4f14
@ -58,69 +58,52 @@ func (n *Negentropy) Initiate() ([]byte, error) {
|
|||||||
return output, nil
|
return output, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Negentropy) Reconcile(query []byte) ([]byte, error) {
|
func (n *Negentropy) Reconcile(query []byte) (output []byte, haveIds []string, needIds []string, err error) {
|
||||||
if n.IsInitiator {
|
if n.IsInitiator {
|
||||||
return []byte{}, fmt.Errorf("initiator not asking for have/need IDs")
|
return nil, nil, nil, fmt.Errorf("initiator not asking for have/need IDs")
|
||||||
}
|
}
|
||||||
var haveIds, needIds []string
|
|
||||||
|
|
||||||
output, err := n.ReconcileAux(query, &haveIds, &needIds)
|
reader := bytes.NewReader(query)
|
||||||
|
haveIds = make([]string, 0, 100)
|
||||||
|
needIds = make([]string, 0, 100)
|
||||||
|
|
||||||
|
output, err = n.ReconcileAux(reader, &haveIds, &needIds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(output) == 1 && n.IsInitiator {
|
if len(output) == 1 && n.IsInitiator {
|
||||||
return nil, nil
|
return nil, haveIds, needIds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return output, nil
|
return output, haveIds, needIds, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReconcileWithIDs when IDs are expected to be returned.
|
func (n *Negentropy) ReconcileAux(reader *bytes.Reader, haveIds, needIds *[]string) ([]byte, error) {
|
||||||
func (n *Negentropy) ReconcileWithIDs(query []byte, haveIds, needIds *[]string) ([]byte, error) {
|
|
||||||
if !n.IsInitiator {
|
|
||||||
return nil, fmt.Errorf("non-initiator asking for have/need IDs")
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := n.ReconcileAux(query, haveIds, needIds)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if len(output) == 1 {
|
|
||||||
// Assuming an empty string is a special case indicating a condition similar to std::nullopt
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return output, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]byte, error) {
|
|
||||||
n.lastTimestampIn, n.lastTimestampOut = 0, 0 // Reset for each message
|
n.lastTimestampIn, n.lastTimestampOut = 0, 0 // Reset for each message
|
||||||
|
|
||||||
var fullOutput []byte
|
var fullOutput []byte
|
||||||
fullOutput = append(fullOutput, protocolVersion)
|
fullOutput = append(fullOutput, protocolVersion)
|
||||||
|
|
||||||
protocolVersion, err := getByte(&query)
|
pv, err := reader.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if protocolVersion < 0x60 || protocolVersion > 0x6F {
|
if pv < 0x60 || pv > 0x6F {
|
||||||
return nil, fmt.Errorf("invalid negentropy protocol version byte")
|
return nil, fmt.Errorf("invalid negentropy protocol version byte")
|
||||||
}
|
}
|
||||||
if protocolVersion != protocolVersion {
|
if pv != protocolVersion {
|
||||||
if n.IsInitiator {
|
if n.IsInitiator {
|
||||||
return nil, fmt.Errorf("unsupported negentropy protocol version requested")
|
return nil, fmt.Errorf("unsupported negentropy protocol version requested")
|
||||||
}
|
}
|
||||||
return fullOutput, nil
|
return fullOutput, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
storageSize := n.storage.Size()
|
|
||||||
var prevBound Bound
|
var prevBound Bound
|
||||||
prevIndex := 0
|
prevIndex := 0
|
||||||
skip := false
|
skip := false
|
||||||
|
|
||||||
// convert the loop to process the query until it's consumed
|
for reader.Len() > 0 {
|
||||||
for len(query) > 0 {
|
|
||||||
var o []byte
|
var o []byte
|
||||||
|
|
||||||
doSkip := func() {
|
doSkip := func() {
|
||||||
@ -135,18 +118,18 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
currBound, err := n.DecodeBound(&query)
|
currBound, err := n.DecodeBound(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
modeVal, err := decodeVarInt(&query)
|
modeVal, err := decodeVarInt(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mode := Mode(modeVal)
|
mode := Mode(modeVal)
|
||||||
|
|
||||||
lower := prevIndex
|
lower := prevIndex
|
||||||
upper, err := n.storage.FindLowerBound(prevIndex, storageSize, currBound)
|
upper, err := n.storage.FindLowerBound(prevIndex, n.storage.Size(), currBound)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -156,7 +139,8 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
|
|||||||
skip = true
|
skip = true
|
||||||
|
|
||||||
case FingerprintMode:
|
case FingerprintMode:
|
||||||
theirFingerprint, err := getBytes(&query, FingerprintSize)
|
theirFingerprint := make([]byte, FingerprintSize)
|
||||||
|
_, err := reader.Read(theirFingerprint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -173,19 +157,20 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
|
|||||||
}
|
}
|
||||||
|
|
||||||
case IdListMode:
|
case IdListMode:
|
||||||
numIds64, err := decodeVarInt(&query)
|
numIds64, err := decodeVarInt(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
numIds := int(numIds64)
|
numIds := int(numIds64)
|
||||||
|
|
||||||
theirElems := make(map[string]struct{})
|
theirElems := make(map[string]struct{})
|
||||||
|
idb := make([]byte, n.idSize)
|
||||||
for i := 0; i < numIds; i++ {
|
for i := 0; i < numIds; i++ {
|
||||||
e, err := getBytes(&query, n.idSize)
|
_, err := reader.Read(idb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
theirElems[hex.EncodeToString(e)] = struct{}{}
|
theirElems[hex.EncodeToString(idb)] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
n.storage.Iterate(lower, upper, func(item Item, _ int) bool {
|
n.storage.Iterate(lower, upper, func(item Item, _ int) bool {
|
||||||
@ -249,7 +234,7 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
|
|||||||
// Check if the frame size limit is exceeded
|
// Check if the frame size limit is exceeded
|
||||||
if n.ExceededFrameSizeLimit(len(fullOutput) + len(o)) {
|
if n.ExceededFrameSizeLimit(len(fullOutput) + len(o)) {
|
||||||
// Frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range
|
// Frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range
|
||||||
remainingFingerprint, err := n.storage.Fingerprint(upper, storageSize)
|
remainingFingerprint, err := n.storage.Fingerprint(upper, n.storage.Size())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -277,15 +262,14 @@ func (n *Negentropy) ReconcileAux(query []byte, haveIds, needIds *[]string) ([]b
|
|||||||
|
|
||||||
func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *[]byte) {
|
func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *[]byte) {
|
||||||
numElems := upper - lower
|
numElems := upper - lower
|
||||||
const Buckets = 16
|
const buckets = 16
|
||||||
|
|
||||||
if numElems < Buckets*2 {
|
if numElems < buckets*2 {
|
||||||
boundEncoded, err := n.encodeBound(upperBound)
|
boundEncoded, err := n.encodeBound(upperBound)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
fmt.Println("upp", upperBound, boundEncoded)
|
|
||||||
*output = append(*output, boundEncoded...)
|
*output = append(*output, boundEncoded...)
|
||||||
*output = append(*output, encodeVarInt(IdListMode)...)
|
*output = append(*output, encodeVarInt(IdListMode)...)
|
||||||
*output = append(*output, encodeVarInt(numElems)...)
|
*output = append(*output, encodeVarInt(numElems)...)
|
||||||
@ -296,11 +280,11 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *[]by
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
itemsPerBucket := numElems / Buckets
|
itemsPerBucket := numElems / buckets
|
||||||
bucketsWithExtra := numElems % Buckets
|
bucketsWithExtra := numElems % buckets
|
||||||
curr := lower
|
curr := lower
|
||||||
|
|
||||||
for i := 0; i < Buckets; i++ {
|
for i := 0; i < buckets; i++ {
|
||||||
bucketSize := itemsPerBucket
|
bucketSize := itemsPerBucket
|
||||||
if i < bucketsWithExtra {
|
if i < bucketsWithExtra {
|
||||||
bucketSize++
|
bucketSize++
|
||||||
@ -350,8 +334,8 @@ func (n *Negentropy) ExceededFrameSizeLimit(size int) bool {
|
|||||||
|
|
||||||
// Decoding
|
// Decoding
|
||||||
|
|
||||||
func (n *Negentropy) DecodeTimestampIn(encoded *[]byte) (nostr.Timestamp, error) {
|
func (n *Negentropy) DecodeTimestampIn(reader *bytes.Reader) (nostr.Timestamp, error) {
|
||||||
t, err := decodeVarInt(encoded)
|
t, err := decodeVarInt(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@ -371,19 +355,19 @@ func (n *Negentropy) DecodeTimestampIn(encoded *[]byte) (nostr.Timestamp, error)
|
|||||||
return timestamp, nil
|
return timestamp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *Negentropy) DecodeBound(encoded *[]byte) (Bound, error) {
|
func (n *Negentropy) DecodeBound(reader *bytes.Reader) (Bound, error) {
|
||||||
timestamp, err := n.DecodeTimestampIn(encoded)
|
timestamp, err := n.DecodeTimestampIn(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Bound{}, err
|
return Bound{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
length, err := decodeVarInt(encoded)
|
length, err := decodeVarInt(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Bound{}, err
|
return Bound{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
id, err := getBytes(encoded, length)
|
id := make([]byte, length)
|
||||||
if err != nil {
|
if _, err = reader.Read(id); err != nil {
|
||||||
return Bound{}, err
|
return Bound{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,58 +1,18 @@
|
|||||||
package negentropy
|
package negentropy
|
||||||
|
|
||||||
import (
|
import "bytes"
|
||||||
"errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrParseEndsPrematurely = errors.New("parse ends prematurely")
|
func decodeVarInt(reader *bytes.Reader) (int, error) {
|
||||||
|
|
||||||
func getByte(encoded *[]byte) (byte, error) {
|
|
||||||
if len(*encoded) < 1 {
|
|
||||||
return 0, ErrParseEndsPrematurely
|
|
||||||
}
|
|
||||||
b := (*encoded)[0]
|
|
||||||
*encoded = (*encoded)[1:]
|
|
||||||
|
|
||||||
return b, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBytes(encoded *[]byte, n int) ([]byte, error) {
|
|
||||||
// fmt.Fprintln(os.Stderr, "getBytes", len(*encoded), n)
|
|
||||||
if len(*encoded) < n {
|
|
||||||
return nil, errors.New("parse ends prematurely")
|
|
||||||
}
|
|
||||||
result := (*encoded)[:n]
|
|
||||||
*encoded = (*encoded)[n:]
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeVarInt(encoded *[]byte) (int, error) {
|
|
||||||
//var res uint64
|
|
||||||
//
|
|
||||||
//for i := 0; i < len(*encoded); i++ {
|
|
||||||
// byte := (*encoded)[i]
|
|
||||||
// res = (res << 7) | uint64(byte&0x7F)
|
|
||||||
// if (byte & 0x80) == 0 {
|
|
||||||
// fmt.Fprintln(os.Stderr, "decodeVarInt", encoded, i)
|
|
||||||
// *encoded = (*encoded)[i+1:] // Advance the slice to reflect consumed bytes
|
|
||||||
// return res, nil
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
//return 0, ErrParseEndsPrematurely
|
|
||||||
var res int = 0
|
var res int = 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if len(*encoded) == 0 {
|
b, err := reader.ReadByte()
|
||||||
return 0, errors.New("parse ends prematurely")
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the first byte from the slice and update the slice.
|
res = (res << 7) | (int(b) & 127)
|
||||||
// This simulates JavaScript's shift operation on arrays.
|
if (b & 128) == 0 {
|
||||||
byte := (*encoded)[0]
|
|
||||||
*encoded = (*encoded)[1:]
|
|
||||||
|
|
||||||
res = (res << 7) | (int(byte) & 127)
|
|
||||||
if (byte & 128) == 0 {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ func TestSimple(t *testing.T) {
|
|||||||
n2.Insert(events[i])
|
n2.Insert(events[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
q, err = n2.Reconcile(q)
|
q, _, _, err = n2.Reconcile(q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@ -57,7 +57,7 @@ func TestSimple(t *testing.T) {
|
|||||||
{
|
{
|
||||||
var have []string
|
var have []string
|
||||||
var need []string
|
var need []string
|
||||||
q, err = n1.ReconcileWithIDs(q, &have, &need)
|
q, have, need, err = n1.Reconcile(q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
|
Loading…
x
Reference in New Issue
Block a user