negentropy: small refactors here and there, comments and making the code clearer.

This commit is contained in:
fiatjaf
2024-09-18 15:47:08 -03:00
parent 47243fdcc4
commit 6910f391fe
3 changed files with 69 additions and 78 deletions

View File

@@ -6,34 +6,38 @@ import (
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
) )
func (n *Negentropy) DecodeTimestampIn(reader *StringHexReader) (nostr.Timestamp, error) { func (n *Negentropy) readTimestamp(reader *StringHexReader) (nostr.Timestamp, error) {
t, err := decodeVarInt(reader) delta, err := readVarInt(reader)
if err != nil { if err != nil {
return 0, err return 0, err
} }
timestamp := nostr.Timestamp(t) if delta == 0 {
if timestamp == 0 { // zeroes are infinite
timestamp = maxTimestamp timestamp := maxTimestamp
} else { n.lastTimestampIn = timestamp
timestamp-- return timestamp, nil
} }
timestamp += n.lastTimestampIn // remove 1 as we always add 1 when encoding
if timestamp < n.lastTimestampIn { // Check for overflow delta--
timestamp = maxTimestamp
} // we add the previously cached timestamp to get the current
timestamp := n.lastTimestampIn + nostr.Timestamp(delta)
// cache this so we can apply it to the delta next time
n.lastTimestampIn = timestamp n.lastTimestampIn = timestamp
return timestamp, nil return timestamp, nil
} }
func (n *Negentropy) DecodeBound(reader *StringHexReader) (Bound, error) { func (n *Negentropy) readBound(reader *StringHexReader) (Bound, error) {
timestamp, err := n.DecodeTimestampIn(reader) timestamp, err := n.readTimestamp(reader)
if err != nil { if err != nil {
return Bound{}, fmt.Errorf("failed to decode bound timestamp: %w", err) return Bound{}, fmt.Errorf("failed to decode bound timestamp: %w", err)
} }
length, err := decodeVarInt(reader) length, err := readVarInt(reader)
if err != nil { if err != nil {
return Bound{}, fmt.Errorf("failed to decode bound length: %w", err) return Bound{}, fmt.Errorf("failed to decode bound length: %w", err)
} }
@@ -46,22 +50,28 @@ func (n *Negentropy) DecodeBound(reader *StringHexReader) (Bound, error) {
return Bound{Item{timestamp, id}}, nil return Bound{Item{timestamp, id}}, nil
} }
func (n *Negentropy) encodeTimestampOut(w *StringHexWriter, timestamp nostr.Timestamp) { func (n *Negentropy) writeTimestamp(w *StringHexWriter, timestamp nostr.Timestamp) {
if timestamp == maxTimestamp { if timestamp == maxTimestamp {
n.lastTimestampOut = maxTimestamp // zeroes are infinite
encodeVarIntToHex(w, 0) n.lastTimestampOut = maxTimestamp // cache this (see below)
writeVarInt(w, 0)
return return
} }
temp := timestamp
timestamp -= n.lastTimestampOut // we will only encode the difference between this timestamp and the previous
n.lastTimestampOut = temp delta := timestamp - n.lastTimestampOut
encodeVarIntToHex(w, int(timestamp+1))
// we cache this here as the next timestamp we encode will be just a delta from this
n.lastTimestampOut = timestamp
// add 1 to prevent zeroes from being read as infinites
writeVarInt(w, int(delta+1))
return return
} }
func (n *Negentropy) encodeBound(w *StringHexWriter, bound Bound) { func (n *Negentropy) writeBound(w *StringHexWriter, bound Bound) {
n.encodeTimestampOut(w, bound.Timestamp) n.writeTimestamp(w, bound.Timestamp)
encodeVarIntToHex(w, len(bound.ID)/2) writeVarInt(w, len(bound.ID)/2)
w.WriteHex(bound.Item.ID) w.WriteHex(bound.Item.ID)
} }
@@ -83,7 +93,7 @@ func getMinimalBound(prev, curr Item) Bound {
return Bound{Item{curr.Timestamp, curr.ID[:(sharedPrefixBytes+1)*2]}} return Bound{Item{curr.Timestamp, curr.ID[:(sharedPrefixBytes+1)*2]}}
} }
func decodeVarInt(reader *StringHexReader) (int, error) { func readVarInt(reader *StringHexReader) (int, error) {
var res int = 0 var res int = 0
for { for {
@@ -101,6 +111,15 @@ func decodeVarInt(reader *StringHexReader) (int, error) {
return res, nil return res, nil
} }
func writeVarInt(w *StringHexWriter, n int) {
if n == 0 {
w.WriteByte(0)
return
}
w.WriteBytes(encodeVarInt(n))
}
func encodeVarInt(n int) []byte { func encodeVarInt(n int) []byte {
if n == 0 { if n == 0 {
return []byte{0} return []byte{0}
@@ -118,21 +137,3 @@ func encodeVarInt(n int) []byte {
return o return o
} }
func encodeVarIntToHex(w *StringHexWriter, n int) {
if n == 0 {
w.WriteByte(0)
}
var o []byte
for n != 0 {
o = append([]byte{byte(n & 0x7F)}, o...)
n >>= 7
}
for i := 0; i < len(o)-1; i++ {
o[i] |= 0x80
}
w.WriteBytes(o)
}

View File

@@ -125,16 +125,16 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
// end skip range, if necessary, so we can start a new bound that isn't a skip // end skip range, if necessary, so we can start a new bound that isn't a skip
if skipping { if skipping {
skipping = false skipping = false
n.encodeBound(partialOutput, prevBound) n.writeBound(partialOutput, prevBound)
partialOutput.WriteByte(byte(SkipMode)) partialOutput.WriteByte(byte(SkipMode))
} }
} }
currBound, err := n.DecodeBound(reader) currBound, err := n.readBound(reader)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode bound: %w", err) return "", fmt.Errorf("failed to decode bound: %w", err)
} }
modeVal, err := decodeVarInt(reader) modeVal, err := readVarInt(reader)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode mode: %w", err) return "", fmt.Errorf("failed to decode mode: %w", err)
} }
@@ -162,7 +162,7 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
} }
case IdListMode: case IdListMode:
numIds, err := decodeVarInt(reader) numIds, err := readVarInt(reader)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode number of ids: %w", err) return "", fmt.Errorf("failed to decode number of ids: %w", err)
} }
@@ -222,9 +222,9 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
responses++ responses++
} }
n.encodeBound(partialOutput, endBound) n.writeBound(partialOutput, endBound)
partialOutput.WriteByte(byte(IdListMode)) partialOutput.WriteByte(byte(IdListMode))
encodeVarIntToHex(partialOutput, responses) writeVarInt(partialOutput, responses)
partialOutput.WriteHex(responseIds.String()) partialOutput.WriteHex(responseIds.String())
fullOutput.WriteHex(partialOutput.Hex()) fullOutput.WriteHex(partialOutput.Hex())
@@ -238,7 +238,7 @@ func (n *Negentropy) reconcileAux(reader *StringHexReader) (string, error) {
if n.frameSizeLimit-200 < fullOutput.Len()+partialOutput.Len() { if n.frameSizeLimit-200 < fullOutput.Len()+partialOutput.Len() {
// frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range // frame size limit exceeded, handle by encoding a boundary and fingerprint for the remaining range
remainingFingerprint := n.storage.Fingerprint(upper, n.storage.Size()) remainingFingerprint := n.storage.Fingerprint(upper, n.storage.Size())
n.encodeBound(fullOutput, infiniteBound) n.writeBound(fullOutput, infiniteBound)
fullOutput.WriteByte(byte(FingerprintMode)) fullOutput.WriteByte(byte(FingerprintMode))
fullOutput.WriteBytes(remainingFingerprint[:]) fullOutput.WriteBytes(remainingFingerprint[:])
@@ -261,9 +261,9 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *Stri
if numElems < buckets*2 { if numElems < buckets*2 {
// we just send the full ids here // we just send the full ids here
n.encodeBound(output, upperBound) n.writeBound(output, upperBound)
output.WriteByte(byte(IdListMode)) output.WriteByte(byte(IdListMode))
encodeVarIntToHex(output, numElems) writeVarInt(output, numElems)
for _, item := range n.storage.Range(lower, upper) { for _, item := range n.storage.Range(lower, upper) {
output.WriteHex(item.ID) output.WriteHex(item.ID)
@@ -299,7 +299,7 @@ func (n *Negentropy) SplitRange(lower, upper int, upperBound Bound, output *Stri
nextBound = minBound nextBound = minBound
} }
n.encodeBound(output, nextBound) n.writeBound(output, nextBound)
output.WriteByte(byte(FingerprintMode)) output.WriteByte(byte(FingerprintMode))
output.WriteBytes(ourFingerprint[:]) output.WriteBytes(ourFingerprint[:])
} }

View File

@@ -2,7 +2,6 @@ package negentropy
import ( import (
"fmt" "fmt"
"log"
"slices" "slices"
"sync" "sync"
"testing" "testing"
@@ -106,27 +105,7 @@ func runTestWith(t *testing.T,
i := 1 i := 1
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(3) wg.Add(2)
var fatal error
go func() {
defer wg.Done()
for n := n1; q != ""; n = invert[n] {
i++
fmt.Println("processing reconcile", n)
q, err = n.Reconcile(q)
if err != nil {
fatal = err
return
}
if q == "" {
return
}
}
}()
go func() { go func() {
defer wg.Done() defer wg.Done()
@@ -166,8 +145,19 @@ func runTestWith(t *testing.T,
require.Equal(t, expectedNeed, havenots, "wrong need") require.Equal(t, expectedNeed, havenots, "wrong need")
}() }()
wg.Wait() for n := n1; q != ""; n = invert[n] {
if fatal != nil { i++
log.Fatal(fatal)
fmt.Println("processing reconcile", n)
q, err = n.Reconcile(q)
if err != nil {
t.Fatalf("reconciliation failed: %s", err)
}
if q == "" {
wg.Wait()
return
}
} }
} }