From 4705719c7666a46fd264da7266d33c32ed3e1ad8 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Fri, 4 Apr 2025 12:58:32 -0300 Subject: [PATCH] sdk: fix wot filter. --- sdk/wot.go | 181 +++++++++++++++++++++++++++++++++++------------- sdk/wot_test.go | 134 +++++++++++++++++++++++++++++++++++ 2 files changed, 265 insertions(+), 50 deletions(-) create mode 100644 sdk/wot_test.go diff --git a/sdk/wot.go b/sdk/wot.go index 96f128d..a6acf14 100644 --- a/sdk/wot.go +++ b/sdk/wot.go @@ -2,13 +2,12 @@ package sdk import ( "context" - "maps" - "slices" "strconv" + "sync" + "time" "github.com/FastFilter/xorfilter" "golang.org/x/sync/errgroup" - "sync" ) func PubKeyToShid(pubkey string) uint64 { @@ -16,59 +15,141 @@ func PubKeyToShid(pubkey string) uint64 { return shid } -func (sys *System) GetWoT(ctx context.Context, pubkey string) (map[uint64]struct{}, error) { - g, ctx := errgroup.WithContext(ctx) - g.SetLimit(30) - - res := make(chan uint64, 100) // Add buffer to prevent blocking - result := make(map[uint64]struct{}) - var resultMu sync.Mutex // Add mutex to protect map access - - // Start consumer goroutine - done := make(chan struct{}) - go func() { - defer close(done) - for shid := range res { - resultMu.Lock() - result[shid] = struct{}{} - resultMu.Unlock() - } - }() - - // Process follow lists - for _, f := range sys.FetchFollowList(ctx, pubkey).Items { - f := f // Capture loop variable - g.Go(func() error { - for _, f2 := range sys.FetchFollowList(ctx, f.Pubkey).Items { - select { - case res <- PubKeyToShid(f2.Pubkey): - case <-ctx.Done(): - return ctx.Err() - } - } - return nil - }) - } - - err := g.Wait() - close(res) // Close channel after all goroutines are done - <-done // Wait for consumer to finish - - return result, err +type wotCall struct { + id uint64 // basically the pubkey we're targeting here + mutex sync.Mutex + resultbacks []chan WotXorFilter // all callers waiting for results + errorbacks []chan error // all callers waiting for errors + done chan struct{} // this is closed when this call is fully resolved and deleted } -func (sys *System) GetWoTFilter(ctx context.Context, pubkey string) (WotXorFilter, error) { - m, err := sys.GetWoT(ctx, pubkey) - if err != nil { - return WotXorFilter{}, err +const wotCallsSize = 8 + +var ( + wotCallsMutex sync.Mutex + wotCallsInPlace [wotCallsSize]*wotCall +) + +func (sys *System) LoadWoTFilter(ctx context.Context, pubkey string) (WotXorFilter, error) { + id := PubKeyToShid(pubkey) + pos := int(id % wotCallsSize) + +start: + wotCallsMutex.Lock() + wc := wotCallsInPlace[pos] + if wc == nil { + // we are the first to call at this position + wc = &wotCall{ + id: id, + resultbacks: make([]chan WotXorFilter, 0), + errorbacks: make([]chan error, 0), + done: make(chan struct{}), + } + wotCallsInPlace[pos] = wc + wotCallsMutex.Unlock() + goto actualcall + } else { + wotCallsMutex.Unlock() } - xf, err := xorfilter.Populate(slices.Collect(maps.Keys(m))) - if err != nil { - return WotXorFilter{}, err + wc.mutex.Lock() + if wc.id == id { + // there is already a call for this exact pubkey ongoing, so we just wait + resch := make(chan WotXorFilter) + errch := make(chan error) + wc.resultbacks = append(wc.resultbacks, resch) + wc.errorbacks = append(wc.errorbacks, errch) + wc.mutex.Unlock() + select { + case res := <-resch: + return res, nil + case err := <-errch: + return WotXorFilter{}, err + } + } else { + wc.mutex.Unlock() + // there is already a call in this place, but it's for a different pubkey, so wait + <-wc.done + // when it's done restart + goto start } - return WotXorFilter{*xf}, nil +actualcall: + var res WotXorFilter + m, err := sys.loadWoT(ctx, pubkey) + if err != nil { + wc.mutex.Lock() + for _, ch := range wc.errorbacks { + ch <- err + } + } else { + res = makeWoTFilter(m) + wc.mutex.Lock() + for _, ch := range wc.resultbacks { + ch <- res + } + } + + wotCallsMutex.Lock() + wotCallsInPlace[pos] = nil + wc.mutex.Unlock() + close(wc.done) + wotCallsMutex.Unlock() + + return res, err +} + +func (sys *System) loadWoT(ctx context.Context, pubkey string) (chan string, error) { + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(45) + + res := make(chan string) + + // process follow lists + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + for _, f := range sys.FetchFollowList(ctx, pubkey).Items { + wg.Add(1) + + g.Go(func() error { + ctx, cancel := context.WithTimeout(ctx, time.Second*7) + defer cancel() + + ff := sys.FetchFollowList(ctx, f.Pubkey).Items + for _, f2 := range ff { + res <- f2.Pubkey + } + wg.Done() + return nil + }) + } + + wg.Done() + }() + + go func() { + wg.Wait() + close(res) + }() + + return res, nil +} + +func makeWoTFilter(m chan string) WotXorFilter { + shids := make([]uint64, 0, 60000) + shidMap := make(map[uint64]struct{}, 60000) + for pk := range m { + shid := PubKeyToShid(pk) + if _, alreadyAdded := shidMap[shid]; !alreadyAdded { + shidMap[shid] = struct{}{} + shids = append(shids, shid) + } + } + + xf, _ := xorfilter.Populate(shids) + return WotXorFilter{*xf} } type WotXorFilter struct { diff --git a/sdk/wot_test.go b/sdk/wot_test.go new file mode 100644 index 0000000..b08663c --- /dev/null +++ b/sdk/wot_test.go @@ -0,0 +1,134 @@ +package sdk + +import ( + "sync" + "testing" + "time" + + "github.com/nbd-wtf/go-nostr" + "github.com/stretchr/testify/require" +) + +func TestLoadWoT(t *testing.T) { + sys := NewSystem() + ctx := t.Context() + + // test with fiatjaf's pubkey + wotch, err := sys.loadWoT(ctx, "3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d") + require.NoError(t, err) + + wot := make([]string, 0, 100000) + wotch2 := make(chan string) + + var filter WotXorFilter + done := make(chan struct{}) + go func() { + // test that we can get a filter from the WoT + filter = makeWoTFilter(wotch2) + close(done) + }() + + for pk := range wotch { + wot = append(wot, pk) + wotch2 <- pk + } + close(wotch2) + + // we should get a decent number of pubkeys in the WoT + require.Greater(t, len(wot), 10000, "should have more than 10000 pubkeys in WoT") + + // test that the filter contains some known pubkeys from the WoT + <-done + for _, pk := range wot { + require.True(t, filter.Contains(pk), "filter should contain all WoT pubkeys") + } +} + +func TestLoadWoTManyPeople(t *testing.T) { + sys := NewSystem() + ctx := t.Context() + + wg := sync.WaitGroup{} + wg.Add(3 + 2 + 2) + + diffs := make([]nostr.Timestamp, 5) + var rabble1 WotXorFilter + var rabble2 WotXorFilter + var rabble3 WotXorFilter + var alex1 WotXorFilter + var alex2 WotXorFilter + + // these are the same pubkey + go func() { + rabble, err := sys.LoadWoTFilter(ctx, "76c71aae3a491f1d9eec47cba17e229cda4113a0bbb6e6ae1776d7643e29cafa") + require.NoError(t, err) + diffs[0] = nostr.Now() + rabble1 = rabble + wg.Done() + }() + + time.Sleep(time.Millisecond * 20) + go func() { + rabble, err := sys.LoadWoTFilter(ctx, "76c71aae3a491f1d9eec47cba17e229cda4113a0bbb6e6ae1776d7643e29cafa") + require.NoError(t, err) + diffs[1] = nostr.Now() + rabble2 = rabble + wg.Done() + }() + + time.Sleep(time.Millisecond * 20) + go func() { + rabble, err := sys.LoadWoTFilter(ctx, "76c71aae3a491f1d9eec47cba17e229cda4113a0bbb6e6ae1776d7643e29cafa") + require.NoError(t, err) + diffs[2] = nostr.Now() + rabble3 = rabble + wg.Done() + }() + + // these should map to the same pos + time.Sleep(time.Millisecond * 20) + go func() { + alex, err := sys.LoadWoTFilter(ctx, "9ce71f1506ccf4b99f234af49bd6202be883a80f95a155c6e9a1c36fd7e780c7") + require.NoError(t, err) + diffs[3] = nostr.Now() + alex1 = alex + wg.Done() + }() + + time.Sleep(time.Millisecond * 20) + go func() { + alex, err := sys.LoadWoTFilter(ctx, "9ce71f1506ccf4b99f234af49bd6202be883a80f95a155c6e9a1c36fd7e780c7") + require.NoError(t, err) + diffs[4] = nostr.Now() + alex2 = alex + wg.Done() + }() + + // these are independent + go func() { + hodlbod, err := sys.LoadWoTFilter(ctx, "97c70a44366a6535c145b333f973ea86dfdc2d7a99da618c40c64705ad98e322") + require.NoError(t, err) + require.True(t, hodlbod.Contains("ee11a5dff40c19a555f41fe42b48f00e618c91225622ae37b6c2bb67b76c4e49")) + require.True(t, hodlbod.Contains("76c71aae3a491f1d9eec47cba17e229cda4113a0bbb6e6ae1776d7643e29cafa")) + require.True(t, hodlbod.Contains("3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d")) + wg.Done() + }() + go func() { + mikedilger, err := sys.LoadWoTFilter(ctx, "ee11a5dff40c19a555f41fe42b48f00e618c91225622ae37b6c2bb67b76c4e49") + require.NoError(t, err) + require.True(t, mikedilger.Contains("97c70a44366a6535c145b333f973ea86dfdc2d7a99da618c40c64705ad98e322")) + require.True(t, mikedilger.Contains("3bf0c63fcb93463407af97a5e5ee64fa883d107ef9e558472c4eb9aaaefa459d")) + wg.Done() + }() + + wg.Wait() + + require.Equal(t, rabble1, rabble2) + require.Equal(t, rabble2, rabble3) + require.Equal(t, alex1, alex2) + + require.Less(t, int(diffs[1]-diffs[0]), 1, "second duplicated call should resolve immediately") + require.Less(t, int(diffs[2]-diffs[1]), 1, "third duplicated call should resolve immediately") + require.Greater(t, int(diffs[3]-diffs[2]), 10, "the next call should take a long time") + require.Less(t, int(diffs[4]-diffs[3]), 1, "and then a duplicated call should resolve immediately") +}