negentropy: do the algorithm entirely in hex.

plus:
  - nicer iterators
  - some optimizations here and there.
  - something else I forgot.
This commit is contained in:
fiatjaf 2024-09-14 16:28:19 -03:00
parent b5f8d48f79
commit 286040c4ce
6 changed files with 293 additions and 194 deletions

View File

@ -1,13 +1,12 @@
package negentropy
import (
"bytes"
"encoding/hex"
"fmt"
"github.com/nbd-wtf/go-nostr"
)
func (n *Negentropy) DecodeTimestampIn(reader *bytes.Reader) (nostr.Timestamp, error) {
func (n *Negentropy) DecodeTimestampIn(reader *StringHexReader) (nostr.Timestamp, error) {
t, err := decodeVarInt(reader)
if err != nil {
return 0, err
@ -28,47 +27,42 @@ func (n *Negentropy) DecodeTimestampIn(reader *bytes.Reader) (nostr.Timestamp, e
return timestamp, nil
}
func (n *Negentropy) DecodeBound(reader *bytes.Reader) (Bound, error) {
func (n *Negentropy) DecodeBound(reader *StringHexReader) (Bound, error) {
timestamp, err := n.DecodeTimestampIn(reader)
if err != nil {
return Bound{}, err
return Bound{}, fmt.Errorf("failed to decode bound timestamp: %w", err)
}
length, err := decodeVarInt(reader)
if err != nil {
return Bound{}, err
return Bound{}, fmt.Errorf("failed to decode bound length: %w", err)
}
id := make([]byte, length)
if _, err = reader.Read(id); err != nil {
return Bound{}, err
id, err := reader.ReadString(length * 2)
if err != nil {
return Bound{}, fmt.Errorf("failed to read bound id: %w", err)
}
return Bound{Item{timestamp, hex.EncodeToString(id)}}, nil
return Bound{Item{timestamp, id}}, nil
}
func (n *Negentropy) encodeTimestampOut(timestamp nostr.Timestamp) []byte {
func (n *Negentropy) encodeTimestampOut(w *StringHexWriter, timestamp nostr.Timestamp) {
if timestamp == maxTimestamp {
n.lastTimestampOut = maxTimestamp
return encodeVarInt(0)
encodeVarIntToHex(w, 0)
return
}
temp := timestamp
timestamp -= n.lastTimestampOut
n.lastTimestampOut = temp
return encodeVarInt(int(timestamp + 1))
encodeVarIntToHex(w, int(timestamp+1))
return
}
func (n *Negentropy) encodeBound(bound Bound) []byte {
var output []byte
t := n.encodeTimestampOut(bound.Timestamp)
idlen := encodeVarInt(len(bound.ID) / 2)
output = append(output, t...)
output = append(output, idlen...)
id, _ := hex.DecodeString(bound.Item.ID)
output = append(output, id...)
return output
func (n *Negentropy) encodeBound(w *StringHexWriter, bound Bound) {
n.encodeTimestampOut(w, bound.Timestamp)
encodeVarIntToHex(w, len(bound.ID)/2)
w.WriteHex(bound.Item.ID)
}
func getMinimalBound(prev, curr Item) Bound {
@ -89,11 +83,11 @@ func getMinimalBound(prev, curr Item) Bound {
return Bound{Item{curr.Timestamp, curr.ID[:(sharedPrefixBytes+1)*2]}}
}
func decodeVarInt(reader *bytes.Reader) (int, error) {
func decodeVarInt(reader *StringHexReader) (int, error) {
var res int = 0
for {
b, err := reader.ReadByte()
b, err := reader.ReadHexByte()
if err != nil {
return 0, err
}
@ -124,3 +118,21 @@ func encodeVarInt(n int) []byte {
return o
}
func encodeVarIntToHex(w *StringHexWriter, n int) {
if n == 0 {
w.WriteByte(0)
}
var o []byte
for n != 0 {
o = append([]byte{byte(n & 0x7F)}, o...)
n >>= 7
}
for i := 0; i < len(o)-1; i++ {
o[i] |= 0x80
}
w.WriteBytes(o)
}

96
nip77/negentropy/hex.go Normal file
View File

@ -0,0 +1,96 @@
package negentropy
import (
"encoding/hex"
"io"
)
func NewStringHexReader(source string) *StringHexReader {
return &StringHexReader{source, 0, make([]byte, 1)}
}
type StringHexReader struct {
source string
idx int
tmp []byte
}
func (r *StringHexReader) Len() int {
return len(r.source) - r.idx
}
func (r *StringHexReader) ReadHexBytes(buf []byte) error {
n := len(buf) * 2
r.idx += n
if len(r.source) < r.idx {
return io.EOF
}
_, err := hex.Decode(buf, []byte(r.source[r.idx-n:r.idx]))
return err
}
func (r *StringHexReader) ReadHexByte() (byte, error) {
err := r.ReadHexBytes(r.tmp)
return r.tmp[0], err
}
func (r *StringHexReader) ReadString(size int) (string, error) {
if size == 0 {
return "", nil
}
r.idx += size
if len(r.source) < r.idx {
return "", io.EOF
}
return r.source[r.idx-size : r.idx], nil
}
func NewStringHexWriter(buf []byte) *StringHexWriter {
return &StringHexWriter{buf, make([]byte, 2)}
}
type StringHexWriter struct {
hexbuf []byte
tmp []byte
}
func (r *StringHexWriter) Len() int {
return len(r.hexbuf)
}
func (r *StringHexWriter) Hex() string {
return string(r.hexbuf)
}
func (r *StringHexWriter) Reset() {
r.hexbuf = r.hexbuf[:0]
}
func (r *StringHexWriter) WriteHex(hexString string) {
r.hexbuf = append(r.hexbuf, hexString...)
return
}
func (r *StringHexWriter) WriteByte(b byte) error {
hex.Encode(r.tmp, []byte{b})
r.hexbuf = append(r.hexbuf, r.tmp...)
return nil
}
func (r *StringHexWriter) WriteBytes(in []byte) {
r.hexbuf = hex.AppendEncode(r.hexbuf, in)
// curr := len(r.hexbuf)
// next := curr + len(in)*2
// for cap(r.hexbuf) < next {
// r.hexbuf = append(r.hexbuf, in...)
// }
// r.hexbuf = r.hexbuf[0:next]
// dst := r.hexbuf[curr:next]
// hex.Encode(dst, in)
return
}

View File

@ -1,11 +1,10 @@
package negentropy
import (
"bytes"
"encoding/hex"
"fmt"
"math"
"os"
"slices"
"strings"
"unsafe"
"github.com/nbd-wtf/go-nostr"
@ -22,7 +21,7 @@ type Negentropy struct {
storage Storage
sealed bool
frameSizeLimit int
isInitiator bool
isClient bool
lastTimestampIn nostr.Timestamp
lastTimestampOut nostr.Timestamp
@ -37,6 +36,17 @@ func NewNegentropy(storage Storage, frameSizeLimit int) *Negentropy {
}
}
func (n *Negentropy) String() string {
label := "unsealed"
if n.sealed {
label = "server"
if n.isClient {
label = "client"
}
}
return fmt.Sprintf("<Negentropy %s with %d items>", label, n.storage.Size())
}
func (n *Negentropy) Insert(evt *nostr.Event) {
err := n.storage.Insert(evt.CreatedAt, evt.ID)
if err != nil {
@ -51,83 +61,76 @@ func (n *Negentropy) seal() {
n.sealed = true
}
func (n *Negentropy) Initiate() []byte {
func (n *Negentropy) Initiate() string {
n.seal()
n.isInitiator = true
n.isClient = true
n.Haves = make(chan string, n.storage.Size()/2)
n.HaveNots = make(chan string, n.storage.Size()/2)
output := bytes.NewBuffer(make([]byte, 0, 1+n.storage.Size()*32))
output := NewStringHexWriter(make([]byte, 0, 1+n.storage.Size()*64))
output.WriteByte(protocolVersion)
n.SplitRange(0, n.storage.Size(), infiniteBound, output)
return output.Bytes()
return output.Hex()
}
func (n *Negentropy) Reconcile(msg []byte) (output []byte, err error) {
func (n *Negentropy) Reconcile(msg string) (output string, err error) {
n.seal()
reader := bytes.NewReader(msg)
reader := NewStringHexReader(msg)
output, err = n.reconcileAux(reader)
if err != nil {
return nil, err
return "", err
}
if len(output) == 1 && n.isInitiator {
if len(output) == 2 && n.isClient {
close(n.Haves)
close(n.HaveNots)
return nil, nil
return "", nil
}
return output, nil
}
func (n *Negentropy) reconcileAux(reader *bytes.Reader) ([]byte, error) {
func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
n.lastTimestampIn, n.lastTimestampOut = 0, 0 // reset for each message
fullOutput := bytes.NewBuffer(make([]byte, 0, 5000))
fullOutput := NewStringHexWriter(make([]byte, 0, 5000))
fullOutput.WriteByte(protocolVersion)
pv, err := reader.ReadByte()
pv, err := reader.ReadHexByte()
if err != nil {
return nil, err
}
if pv < 0x60 || pv > 0x6f {
return nil, fmt.Errorf("invalid protocol version byte")
return "", fmt.Errorf("failed to read pv: %w", err)
}
if pv != protocolVersion {
if n.isInitiator {
return nil, fmt.Errorf("unsupported negentropy protocol version requested")
}
return fullOutput.Bytes(), nil
return "", fmt.Errorf("unsupported negentropy protocol version %v", pv)
}
var prevBound Bound
prevIndex := 0
skip := false
skipping := false // this means we are currently coalescing ranges into skip
partialOutput := bytes.NewBuffer(make([]byte, 0, 100))
partialOutput := NewStringHexWriter(make([]byte, 0, 100))
for reader.Len() > 0 {
partialOutput.Reset()
doSkip := func() {
if skip {
skip = false
encodedBound := n.encodeBound(prevBound)
partialOutput.Write(encodedBound)
partialOutput.WriteByte(SkipMode)
finishSkip := func() {
// end skip range, if necessary, so we can start a new bound that isn't a skip
if skipping {
skipping = false
n.encodeBound(partialOutput, prevBound)
partialOutput.WriteByte(byte(SkipMode))
}
}
currBound, err := n.DecodeBound(reader)
if err != nil {
return nil, err
return "", fmt.Errorf("failed to decode bound: %w", err)
}
modeVal, err := decodeVarInt(reader)
if err != nil {
return nil, err
return "", fmt.Errorf("failed to decode mode: %w", err)
}
mode := Mode(modeVal)
@ -136,134 +139,129 @@ func (n *Negentropy) reconcileAux(reader *bytes.Reader) ([]byte, error) {
switch mode {
case SkipMode:
skip = true
skipping = true
case FingerprintMode:
var theirFingerprint [FingerprintSize]byte
_, err := reader.Read(theirFingerprint[:])
if err != nil {
return nil, err
}
ourFingerprint, err := n.storage.Fingerprint(lower, upper)
if err != nil {
return nil, err
if err := reader.ReadHexBytes(theirFingerprint[:]); err != nil {
return "", fmt.Errorf("failed to read fingerprint: %w", err)
}
ourFingerprint := n.storage.Fingerprint(lower, upper)
if theirFingerprint == ourFingerprint {
skip = true
skipping = true
} else {
doSkip()
finishSkip()
n.SplitRange(lower, upper, currBound, partialOutput)
}
case IdListMode:
numIds, err := decodeVarInt(reader)
if err != nil {
return nil, err
return "", fmt.Errorf("failed to decode number of ids: %w", err)
}
theirElems := make(map[string]struct{})
var idb [32]byte
// what they have
theirItems := make([]string, 0, numIds)
for i := 0; i < numIds; i++ {
_, err := reader.Read(idb[:])
if err != nil {
return nil, err
if id, err := reader.ReadString(64); err != nil {
return "", fmt.Errorf("failed to read id (#%d/%d) in list: %w", i, numIds, err)
} else {
theirItems = append(theirItems, id)
}
id := hex.EncodeToString(idb[:])
theirElems[id] = struct{}{}
}
n.storage.Iterate(lower, upper, func(item Item, _ int) bool {
// what we have
for _, item := range n.storage.Range(lower, upper) {
id := item.ID
if _, exists := theirElems[id]; !exists {
if n.isInitiator {
if idx, theyHave := slices.BinarySearch(theirItems, id); theyHave {
// if we have and they have, ignore
theirItems[idx] = ""
} else {
// if we have and they don't, notify client
if n.isClient {
n.Haves <- id
}
} else {
delete(theirElems, id)
}
return true
})
}
if n.isInitiator {
skip = true
for id := range theirElems {
n.HaveNots <- id
if n.isClient {
// notify client of what they have and we don't
for _, id := range theirItems {
if id != "" {
n.HaveNots <- id
}
}
// client got list of ids, it's done, skip
skipping = true
} else {
doSkip()
// server got list of ids, reply with their own ids for the same range
finishSkip()
responseIds := strings.Builder{}
responseIds.Grow(64 * 100)
responses := 0
responseIds := make([]byte, 0, 32*n.storage.Size())
endBound := currBound
n.storage.Iterate(lower, upper, func(item Item, index int) bool {
if n.frameSizeLimit-200 < fullOutput.Len()+len(responseIds) {
for index, item := range n.storage.Range(lower, upper) {
if n.frameSizeLimit-200 < fullOutput.Len()+1+8+responseIds.Len() {
endBound = Bound{item}
upper = index
return false
break
}
responseIds.WriteString(item.ID)
responses++
}
id, _ := hex.DecodeString(item.ID)
responseIds = append(responseIds, id...)
return true
})
n.encodeBound(partialOutput, endBound)
partialOutput.WriteByte(byte(IdListMode))
encodeVarIntToHex(partialOutput, responses)
partialOutput.WriteHex(responseIds.String())
encodedBound := n.encodeBound(endBound)
partialOutput.Write(encodedBound)
partialOutput.WriteByte(IdListMode)
partialOutput.Write(encodeVarInt(len(responseIds) / 32))
partialOutput.Write(responseIds)
partialOutput.WriteTo(fullOutput)
fullOutput.WriteHex(partialOutput.Hex())
partialOutput.Reset()
}
default:
return nil, fmt.Errorf("unexpected mode %d", mode)
return "", fmt.Errorf("unexpected mode %d", mode)
}
if n.frameSizeLimit-200 < fullOutput.Len()+partialOutput.Len() {
// frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range
remainingFingerprint, err := n.storage.Fingerprint(upper, n.storage.Size())
if err != nil {
panic(err)
}
fullOutput.Write(n.encodeBound(infiniteBound))
fullOutput.WriteByte(FingerprintMode)
fullOutput.Write(remainingFingerprint[:])
remainingFingerprint := n.storage.Fingerprint(upper, n.storage.Size())
n.encodeBound(fullOutput, infiniteBound)
fullOutput.WriteByte(byte(FingerprintMode))
fullOutput.WriteBytes(remainingFingerprint[:])
break // stop processing further
} else {
// append the constructed output for this iteration
partialOutput.WriteTo(fullOutput)
fullOutput.WriteHex(partialOutput.Hex())
}
prevIndex = upper
prevBound = currBound
}
return fullOutput.Bytes(), nil
return fullOutput.Hex(), nil
}
func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *bytes.Buffer) {
func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *StringHexWriter) {
numElems := upper - lower
const buckets = 16
if numElems < buckets*2 {
// we just send the full ids here
boundEncoded := n.encodeBound(upperBound)
output.Write(boundEncoded)
output.WriteByte(IdListMode)
output.Write(encodeVarInt(numElems))
n.encodeBound(output, upperBound)
output.WriteByte(byte(IdListMode))
encodeVarIntToHex(output, numElems)
n.storage.Iterate(lower, upper, func(item Item, _ int) bool {
id, _ := hex.DecodeString(item.ID)
output.Write(id)
return true
})
for _, item := range n.storage.Range(lower, upper) {
output.WriteHex(item.ID)
}
} else {
itemsPerBucket := numElems / buckets
bucketsWithExtra := numElems % buckets
@ -274,12 +272,7 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *byte
if i < bucketsWithExtra {
bucketSize++
}
ourFingerprint, err := n.storage.Fingerprint(curr, curr+bucketSize)
if err != nil {
fmt.Fprintln(os.Stderr, err)
panic(err)
}
ourFingerprint := n.storage.Fingerprint(curr, curr+bucketSize)
curr += bucketSize
var nextBound Bound
@ -288,23 +281,21 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *byte
} else {
var prevItem, currItem Item
n.storage.Iterate(curr-1, curr+1, func(item Item, index int) bool {
for index, item := range n.storage.Range(curr-1, curr+1) {
if index == curr-1 {
prevItem = item
} else {
currItem = item
}
return true
})
}
minBound := getMinimalBound(prevItem, currItem)
nextBound = minBound
}
boundEncoded := n.encodeBound(nextBound)
output.Write(boundEncoded)
output.WriteByte(FingerprintMode)
output.Write(ourFingerprint[:])
n.encodeBound(output, nextBound)
output.WriteByte(byte(FingerprintMode))
output.WriteBytes(ourFingerprint[:])
}
}
}

View File

@ -1,10 +1,11 @@
package negentropy
import (
"cmp"
"crypto/sha256"
"encoding/binary"
"encoding/hex"
"fmt"
"iter"
"strings"
"github.com/nbd-wtf/go-nostr"
@ -12,22 +13,35 @@ import (
const FingerprintSize = 16
type Mode int
type Mode uint8
const (
SkipMode = 0
FingerprintMode = 1
IdListMode = 2
SkipMode Mode = 0
FingerprintMode Mode = 1
IdListMode Mode = 2
)
func (v Mode) String() string {
switch v {
case SkipMode:
return "SKIP"
case FingerprintMode:
return "FINGERPRINT"
case IdListMode:
return "IDLIST"
default:
return "<UNKNOWN-ERROR>"
}
}
type Storage interface {
Insert(nostr.Timestamp, string) error
Seal()
Size() int
Iterate(begin, end int, cb func(item Item, i int) bool) error
Range(begin, end int) iter.Seq2[int, Item]
FindLowerBound(begin, end int, value Bound) int
GetBound(idx int) Bound
Fingerprint(begin, end int) ([FingerprintSize]byte, error)
Fingerprint(begin, end int) [FingerprintSize]byte
}
type Item struct {
@ -36,10 +50,10 @@ type Item struct {
}
func itemCompare(a, b Item) int {
if a.Timestamp != b.Timestamp {
return int(a.Timestamp - b.Timestamp)
if a.Timestamp == b.Timestamp {
return strings.Compare(a.ID, b.ID)
}
return strings.Compare(a.ID, b.ID)
return cmp.Compare(a.Timestamp, b.Timestamp)
}
func (i Item) String() string { return fmt.Sprintf("Item<%d:%s>", i.Timestamp, i.ID) }
@ -61,11 +75,6 @@ func (acc *Accumulator) SetToZero() {
acc.Buf = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
}
func (acc *Accumulator) Add(id string) {
b, _ := hex.DecodeString(id)
acc.AddBytes(b)
}
func (acc *Accumulator) AddAccumulator(other Accumulator) {
acc.AddBytes(other.Buf)
}
@ -95,12 +104,8 @@ func (acc *Accumulator) AddBytes(other []byte) {
}
}
func (acc *Accumulator) SV() []byte {
return acc.Buf[:]
}
func (acc *Accumulator) GetFingerprint(n int) [FingerprintSize]byte {
input := acc.SV()
input := acc.Buf[:]
input = append(input, encodeVarInt(n)...)
hash := sha256.Sum256(input)

View File

@ -1,7 +1,9 @@
package negentropy
import (
"encoding/hex"
"fmt"
"iter"
"slices"
"github.com/nbd-wtf/go-nostr"
@ -45,13 +47,14 @@ func (v *Vector) GetBound(idx int) Bound {
return infiniteBound
}
func (v *Vector) Iterate(begin, end int, cb func(Item, int) bool) error {
for i := begin; i < end; i++ {
if !cb(v.items[i], i) {
break
func (v *Vector) Range(begin, end int) iter.Seq2[int, Item] {
return func(yield func(int, Item) bool) {
for i := begin; i < end; i++ {
if !yield(i, v.items[i]) {
break
}
}
}
return nil
}
func (v *Vector) FindLowerBound(begin, end int, bound Bound) int {
@ -59,16 +62,15 @@ func (v *Vector) FindLowerBound(begin, end int, bound Bound) int {
return begin + idx
}
func (v *Vector) Fingerprint(begin, end int) ([FingerprintSize]byte, error) {
func (v *Vector) Fingerprint(begin, end int) [FingerprintSize]byte {
var out Accumulator
out.SetToZero()
if err := v.Iterate(begin, end, func(item Item, _ int) bool {
out.Add(item.ID)
return true
}); err != nil {
return [FingerprintSize]byte{}, err
tmp := make([]byte, 32)
for _, item := range v.Range(begin, end) {
hex.Decode(tmp, []byte(item.ID))
out.AddBytes(tmp)
}
return out.GetFingerprint(end - begin), nil
return out.GetFingerprint(end - begin)
}

View File

@ -1,10 +1,9 @@
package negentropy
import (
"encoding/hex"
"fmt"
"log"
"slices"
"strings"
"sync"
"testing"
@ -60,7 +59,7 @@ func runTestWith(t *testing.T,
expectedN1NeedRanges [][]int, expectedN1HaveRanges [][]int,
) {
var err error
var q []byte
var q string
var n1 *Negentropy
var n2 *Negentropy
@ -109,18 +108,21 @@ func runTestWith(t *testing.T,
wg := sync.WaitGroup{}
wg.Add(3)
var fatal error
go func() {
wg.Done()
for n := n1; q != nil; n = invert[n] {
defer wg.Done()
for n := n1; q != ""; n = invert[n] {
i++
fmt.Println("processing reconcile", n)
q, err = n.Reconcile(q)
if err != nil {
t.Fatal(err)
fatal = err
return
}
if q == nil {
if q == "" {
return
}
}
@ -141,6 +143,7 @@ func runTestWith(t *testing.T,
}
haves = append(haves, item)
}
slices.Sort(haves)
require.ElementsMatch(t, expectedHave, haves, "wrong have")
}()
@ -159,22 +162,12 @@ func runTestWith(t *testing.T,
}
havenots = append(havenots, item)
}
slices.Sort(havenots)
require.ElementsMatch(t, expectedNeed, havenots, "wrong need")
}()
wg.Wait()
}
func hexedBytes(o []byte) string {
s := strings.Builder{}
s.Grow(2 + 1 + len(o)*5)
s.WriteString("[ ")
for _, b := range o {
x := hex.EncodeToString([]byte{b})
s.WriteString("0x")
s.WriteString(x)
s.WriteString(" ")
if fatal != nil {
log.Fatal(fatal)
}
s.WriteString("]")
return s.String()
}