This commit is contained in:
fiatjaf 2024-08-01 15:36:14 -03:00
parent a103353254
commit aa9e53b61d
5 changed files with 347 additions and 32 deletions

4
go.mod
View File

@ -8,6 +8,7 @@ require (
github.com/nbd-wtf/go-nostr v0.34.5 github.com/nbd-wtf/go-nostr v0.34.5
github.com/puzpuzpuz/xsync/v3 v3.0.2 github.com/puzpuzpuz/xsync/v3 v3.0.2
github.com/rs/cors v1.7.0 github.com/rs/cors v1.7.0
github.com/stretchr/testify v1.9.0
) )
require ( require (
@ -17,6 +18,7 @@ require (
github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect github.com/btcsuite/btcd/btcec/v2 v2.3.2 // indirect
github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 // indirect github.com/btcsuite/btcd/chaincfg/chainhash v1.0.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/crypto/blake256 v1.0.1 // indirect github.com/decred/dcrd/crypto/blake256 v1.0.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/dgraph-io/badger/v4 v4.2.0 // indirect github.com/dgraph-io/badger/v4 v4.2.0 // indirect
@ -42,6 +44,7 @@ require (
github.com/mailru/easyjson v0.7.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-sqlite3 v1.14.18 // indirect github.com/mattn/go-sqlite3 v1.14.18 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
github.com/tidwall/gjson v1.17.0 // indirect github.com/tidwall/gjson v1.17.0 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
@ -53,4 +56,5 @@ require (
golang.org/x/net v0.18.0 // indirect golang.org/x/net v0.18.0 // indirect
golang.org/x/sys v0.20.0 // indirect golang.org/x/sys v0.20.0 // indirect
google.golang.org/protobuf v1.31.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

6
go.sum
View File

@ -105,6 +105,10 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
@ -224,6 +228,8 @@ google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQ
google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@ -80,20 +80,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
cancel() cancel()
conn.Close() conn.Close()
rl.clientsMutex.Lock() rl.removeClientAndListeners(ws)
defer rl.clientsMutex.Unlock()
if specs, ok := rl.clients[ws]; ok {
// swap delete listeners and delete client
for s, spec := range specs {
// no need to cancel contexts since they inherit from the main connection context
// just delete the listeners
srl := spec.subrelay
srl.listeners[spec.index] = srl.listeners[len(srl.listeners)-1]
specs[s] = specs[len(specs)-1]
srl.listeners = srl.listeners[0:len(srl.listeners)]
}
}
delete(rl.clients, ws)
} }
go func() { go func() {

View File

@ -3,6 +3,7 @@ package khatru
import ( import (
"context" "context"
"errors" "errors"
"slices"
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
) )
@ -10,16 +11,16 @@ import (
var ErrSubscriptionClosedByClient = errors.New("subscription closed by client") var ErrSubscriptionClosedByClient = errors.New("subscription closed by client")
type listenerSpec struct { type listenerSpec struct {
subscriptionId string // kept here so we can easily match against it removeListenerId id string // kept here so we can easily match against it removeListenerId
cancel context.CancelCauseFunc cancel context.CancelCauseFunc
index int index int
subrelay *Relay // this is important when we're dealing with routing, otherwise it will be always the same subrelay *Relay // this is important when we're dealing with routing, otherwise it will be always the same
} }
type listener struct { type listener struct {
subscriptionId string // duplicated here so we can easily send it on notifyListeners id string // duplicated here so we can easily send it on notifyListeners
filter nostr.Filter filter nostr.Filter
ws *WebSocket ws *WebSocket
} }
func (rl *Relay) GetListeningFilters() []nostr.Filter { func (rl *Relay) GetListeningFilters() []nostr.Filter {
@ -45,15 +46,15 @@ func (rl *Relay) addListener(
if specs, ok := rl.clients[ws]; ok /* this will always be true unless client has disconnected very rapidly */ { if specs, ok := rl.clients[ws]; ok /* this will always be true unless client has disconnected very rapidly */ {
idx := len(subrelay.listeners) idx := len(subrelay.listeners)
rl.clients[ws] = append(specs, listenerSpec{ rl.clients[ws] = append(specs, listenerSpec{
subscriptionId: id, id: id,
cancel: cancel, cancel: cancel,
subrelay: subrelay, subrelay: subrelay,
index: idx, index: idx,
}) })
subrelay.listeners = append(subrelay.listeners, listener{ subrelay.listeners = append(subrelay.listeners, listener{
ws: ws, ws: ws,
subscriptionId: id, id: id,
filter: filter, filter: filter,
}) })
} }
} }
@ -68,21 +69,63 @@ func (rl *Relay) removeListenerId(ws *WebSocket, id string) {
// swap delete specs that match this id // swap delete specs that match this id
nswaps := 0 nswaps := 0
for s, spec := range specs { for s, spec := range specs {
if spec.subscriptionId == id { if spec.id == id {
spec.cancel(ErrSubscriptionClosedByClient) spec.cancel(ErrSubscriptionClosedByClient)
specs[s] = specs[len(specs)-1-nswaps] specs[s] = specs[len(specs)-1-nswaps]
nswaps++ nswaps++
// swap delete listeners one at a time, as they may be each in a different subrelay // swap delete listeners one at a time, as they may be each in a different subrelay
srl := spec.subrelay // == rl in normal cases, but different when this came from a route srl := spec.subrelay // == rl in normal cases, but different when this came from a route
srl.listeners[spec.index] = srl.listeners[len(srl.listeners)-1]
srl.listeners = srl.listeners[0 : len(srl.listeners)-1] if spec.index != len(srl.listeners)-1 {
movedFromIndex := len(srl.listeners) - 1
moved := srl.listeners[movedFromIndex] // this wasn't removed, but will be moved
srl.listeners[spec.index] = moved
// update the index of the listener we just moved
movedSpecs := rl.clients[moved.ws]
idx := slices.IndexFunc(movedSpecs, func(ls listenerSpec) bool {
return ls.index == movedFromIndex
})
movedSpecs[idx].index = spec.index
rl.clients[moved.ws] = movedSpecs
}
srl.listeners = srl.listeners[0 : len(srl.listeners)-1] // finally reduce the slice length
} }
} }
rl.clients[ws] = specs[0 : len(specs)-nswaps] rl.clients[ws] = specs[0 : len(specs)-nswaps]
} }
} }
func (rl *Relay) removeClientAndListeners(ws *WebSocket) {
rl.clientsMutex.Lock()
defer rl.clientsMutex.Unlock()
if specs, ok := rl.clients[ws]; ok {
// swap delete listeners and delete client (all specs will be deleted)
for _, spec := range specs {
// no need to cancel contexts since they inherit from the main connection context
// just delete the listeners (swap-delete)
srl := spec.subrelay
if spec.index != len(srl.listeners)-1 {
movedFromIndex := len(srl.listeners) - 1
moved := srl.listeners[movedFromIndex] // this wasn't removed, but will be moved
srl.listeners[spec.index] = moved
// update the index of the listener we just moved
movedSpecs := rl.clients[moved.ws]
idx := slices.IndexFunc(movedSpecs, func(ls listenerSpec) bool {
return ls.index == movedFromIndex
})
movedSpecs[idx].index = spec.index
rl.clients[moved.ws] = movedSpecs
}
srl.listeners = srl.listeners[0 : len(srl.listeners)-1] // finally reduce the slice length
}
}
delete(rl.clients, ws)
}
func (rl *Relay) notifyListeners(event *nostr.Event) { func (rl *Relay) notifyListeners(event *nostr.Event) {
for _, listener := range rl.listeners { for _, listener := range rl.listeners {
if listener.filter.Matches(event) { if listener.filter.Matches(event) {
@ -91,7 +134,7 @@ func (rl *Relay) notifyListeners(event *nostr.Event) {
return return
} }
} }
listener.ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &listener.subscriptionId, Event: *event}) listener.ws.WriteJSON(nostr.EventEnvelope{SubscriptionID: &listener.id, Event: *event})
} }
} }
} }

275
listener_test.go Normal file
View File

@ -0,0 +1,275 @@
package khatru
import (
"testing"
"github.com/nbd-wtf/go-nostr"
"github.com/stretchr/testify/require"
)
func TestListenerSetupAndRemoveOnce(t *testing.T) {
rl := NewRelay()
ws1 := &WebSocket{}
ws2 := &WebSocket{}
f1 := nostr.Filter{Kinds: []int{1}}
f2 := nostr.Filter{Kinds: []int{2}}
f3 := nostr.Filter{Kinds: []int{3}}
rl.clients[ws1] = nil
rl.clients[ws2] = nil
var cancel func(cause error) = nil
t.Run("adding listeners", func(t *testing.T) {
rl.addListener(ws1, "1a", rl, f1, cancel)
rl.addListener(ws1, "1b", rl, f2, cancel)
rl.addListener(ws2, "2a", rl, f3, cancel)
rl.addListener(ws1, "1c", rl, f3, cancel)
require.Equal(t, map[*WebSocket][]listenerSpec{
ws1: {
{"1a", cancel, 0, rl},
{"1b", cancel, 1, rl},
{"1c", cancel, 3, rl},
},
ws2: {
{"2a", cancel, 2, rl},
},
}, rl.clients)
require.Equal(t, []listener{
{"1a", f1, ws1},
{"1b", f2, ws1},
{"2a", f3, ws2},
{"1c", f3, ws1},
}, rl.listeners)
})
t.Run("removing a client", func(t *testing.T) {
rl.removeClientAndListeners(ws1)
require.Equal(t, map[*WebSocket][]listenerSpec{
ws2: {
{"2a", cancel, 0, rl},
},
}, rl.clients)
require.Equal(t, []listener{
{"2a", f3, ws2},
}, rl.listeners)
})
}
func TestListenerMoreConvolutedCase(t *testing.T) {
rl := NewRelay()
ws1 := &WebSocket{}
ws2 := &WebSocket{}
ws3 := &WebSocket{}
ws4 := &WebSocket{}
f1 := nostr.Filter{Kinds: []int{1}}
f2 := nostr.Filter{Kinds: []int{2}}
f3 := nostr.Filter{Kinds: []int{3}}
rl.clients[ws1] = nil
rl.clients[ws2] = nil
rl.clients[ws3] = nil
rl.clients[ws4] = nil
var cancel func(cause error) = nil
t.Run("adding listeners", func(t *testing.T) {
rl.addListener(ws1, "c", rl, f1, cancel)
rl.addListener(ws2, "b", rl, f2, cancel)
rl.addListener(ws3, "a", rl, f3, cancel)
rl.addListener(ws4, "d", rl, f3, cancel)
rl.addListener(ws2, "b", rl, f1, cancel)
require.Equal(t, map[*WebSocket][]listenerSpec{
ws1: {
{"c", cancel, 0, rl},
},
ws2: {
{"b", cancel, 1, rl},
{"b", cancel, 4, rl},
},
ws3: {
{"a", cancel, 2, rl},
},
ws4: {
{"d", cancel, 3, rl},
},
}, rl.clients)
require.Equal(t, []listener{
{"c", f1, ws1},
{"b", f2, ws2},
{"a", f3, ws3},
{"d", f3, ws4},
{"b", f1, ws2},
}, rl.listeners)
})
t.Run("removing a client", func(t *testing.T) {
rl.removeClientAndListeners(ws2)
require.Equal(t, map[*WebSocket][]listenerSpec{
ws1: {
{"c", cancel, 0, rl},
},
ws3: {
{"a", cancel, 2, rl},
},
ws4: {
{"d", cancel, 1, rl},
},
}, rl.clients)
require.Equal(t, []listener{
{"c", f1, ws1},
{"d", f3, ws4},
{"a", f3, ws3},
}, rl.listeners)
})
t.Run("reorganize the first case differently and then remove again", func(t *testing.T) {
rl.clients = map[*WebSocket][]listenerSpec{
ws1: {
{"c", cancel, 1, rl},
},
ws2: {
{"b", cancel, 2, rl},
{"b", cancel, 4, rl},
},
ws3: {
{"a", cancel, 0, rl},
},
ws4: {
{"d", cancel, 3, rl},
},
}
rl.listeners = []listener{
{"a", f3, ws3},
{"c", f1, ws1},
{"b", f2, ws2},
{"d", f3, ws4},
{"b", f1, ws2},
}
rl.removeClientAndListeners(ws2)
require.Equal(t, map[*WebSocket][]listenerSpec{
ws1: {
{"c", cancel, 1, rl},
},
ws3: {
{"a", cancel, 0, rl},
},
ws4: {
{"d", cancel, 2, rl},
},
}, rl.clients)
require.Equal(t, []listener{
{"a", f3, ws3},
{"c", f1, ws1},
{"d", f3, ws4},
}, rl.listeners)
})
}
func TestListenerMoreStuffWithMultipleRelays(t *testing.T) {
rl := NewRelay()
ws1 := &WebSocket{}
ws2 := &WebSocket{}
ws3 := &WebSocket{}
ws4 := &WebSocket{}
f1 := nostr.Filter{Kinds: []int{1}}
f2 := nostr.Filter{Kinds: []int{2}}
f3 := nostr.Filter{Kinds: []int{3}}
rlx := NewRelay()
rly := NewRelay()
rlz := NewRelay()
rl.clients[ws1] = nil
rl.clients[ws2] = nil
rl.clients[ws3] = nil
rl.clients[ws4] = nil
var cancel func(cause error) = nil
t.Run("adding listeners", func(t *testing.T) {
rl.addListener(ws1, "c", rlx, f1, cancel)
rl.addListener(ws2, "b", rly, f2, cancel)
rl.addListener(ws3, "a", rlz, f3, cancel)
rl.addListener(ws4, "d", rlx, f3, cancel)
rl.addListener(ws4, "e", rlx, f3, cancel)
rl.addListener(ws3, "a", rlx, f3, cancel)
rl.addListener(ws4, "e", rly, f3, cancel)
rl.addListener(ws3, "f", rly, f3, cancel)
rl.addListener(ws1, "g", rlz, f1, cancel)
rl.addListener(ws2, "g", rlz, f2, cancel)
require.Equal(t, map[*WebSocket][]listenerSpec{
ws1: {
{"c", cancel, 0, rlx},
{"g", cancel, 1, rlz},
},
ws2: {
{"b", cancel, 0, rly},
{"g", cancel, 2, rlz},
},
ws3: {
{"a", cancel, 0, rlz},
{"a", cancel, 3, rlx},
{"f", cancel, 3, rly},
},
ws4: {
{"d", cancel, 1, rlx},
{"e", cancel, 2, rlx},
{"e", cancel, 2, rly},
},
}, rl.clients)
require.Equal(t, []listener{
{"c", f1, ws1},
{"b", f2, ws2},
{"a", f3, ws3},
{"d", f3, ws4},
{"e", f3, ws4},
{"a", f3, ws3},
{"e", f3, ws4},
{"f", f3, ws3},
{"g", f1, ws1},
{"g", f2, ws2},
}, rl.listeners)
})
// t.Run("removing a client", func(t *testing.T) {
// rl.removeClientAndListeners(ws2)
// require.Equal(t, map[*WebSocket][]listenerSpec{
// ws1: {
// {"c", cancel, 0, rl},
// },
// ws3: {
// {"a", cancel, 2, rl},
// },
// ws4: {
// {"d", cancel, 1, rl},
// },
// }, rl.clients)
// require.Equal(t, []listener{
// {"c", f1, ws1},
// {"d", f3, ws4},
// {"a", f3, ws3},
// }, rl.listeners)
// })
}