Compare commits

...

7 Commits

Author SHA1 Message Date
fiatjaf
3ec0020baa add OnDisconnect() handlers. 2023-12-09 09:00:11 -03:00
fiatjaf
d3a0c545d2 GetIP() and GetOpenSubscriptions() utils. 2023-12-09 08:19:37 -03:00
fiatjaf
c09d21b621 clarity: break->return 2023-12-09 00:14:08 -03:00
fiatjaf
5823515d27 streamlined connection closes on failure.
account for the fact that the time.Ticker channel is
not closed when the ticker is stopped.
2023-12-09 00:00:22 -03:00
fiatjaf
9273a4b809 use a special context for each REQ stored-events handler that can be canceled. 2023-12-08 23:48:30 -03:00
fiatjaf
ddfc9ab64a fun with connection contexts and context cancelations. 2023-12-08 22:51:00 -03:00
fiatjaf
375236cfe2 fix sign on error checking. 2023-12-06 21:32:48 -03:00
6 changed files with 76 additions and 41 deletions

1
go.mod
View File

@@ -44,6 +44,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.17 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect

2
go.sum
View File

@@ -102,6 +102,8 @@ github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a h1:iLcLb5Fwwz7g/DLK89F+uQBDeAhHhwdzB5fSlVdhGcM=
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a/go.mod h1:wozgYq9WEBQBaIJe4YZ0qTSFAMxmcwBhQH0fO0R34Z0=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
@@ -33,8 +34,6 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
connectionContext := r.Context()
conn, err := rl.upgrader.Upgrade(w, r, nil)
if err != nil {
rl.Log.Printf("failed to upgrade websocket: %v\n", err)
@@ -54,18 +53,29 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
Authed: make(chan struct{}),
}
connectionContext = context.WithValue(connectionContext, WS_KEY, ws)
ctx, cancel := context.WithCancel(
context.WithValue(
context.Background(),
WS_KEY, ws,
),
)
kill := func() {
for _, ondisconnect := range rl.OnDisconnect {
ondisconnect(ctx)
}
ticker.Stop()
cancel()
if _, ok := rl.clients.Load(conn); ok {
conn.Close()
rl.clients.Delete(conn)
removeListener(ws)
}
}
// reader
go func() {
defer func() {
ticker.Stop()
if _, ok := rl.clients.Load(conn); ok {
conn.Close()
rl.clients.Delete(conn)
removeListener(ws)
}
}()
defer kill()
conn.SetReadLimit(rl.MaxMessageSize)
conn.SetReadDeadline(time.Now().Add(rl.PongWait))
@@ -75,7 +85,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
})
for _, onconnect := range rl.OnConnect {
onconnect(connectionContext)
onconnect(ctx)
}
for {
@@ -90,7 +100,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
) {
rl.Log.Printf("unexpected close error from %s: %v\n", r.Header.Get("X-Forwarded-For"), err)
}
break
return
}
if typ == websocket.PingMessage {
@@ -99,14 +109,6 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
}
go func(message []byte) {
ctx := context.WithValue(
context.WithValue(
context.Background(),
AUTH_CONTEXT_KEY, connectionContext.Value(AUTH_CONTEXT_KEY),
),
WS_KEY, ws,
)
envelope := nostr.ParseMessage(message)
if envelope == nil {
// stop silently
@@ -163,24 +165,33 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
eose := sync.WaitGroup{}
eose.Add(len(env.Filters))
// a context just for the "stored events" request handler
reqCtx, cancelReqCtx := context.WithCancelCause(ctx)
// handle each filter separately -- dispatching events as they're loaded from databases
for _, filter := range env.Filters {
err := rl.handleRequest(ctx, env.SubscriptionID, &eose, ws, filter)
if err == nil {
err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
if err != nil {
// fail everything if any filter is rejected
reason := nostr.NormalizeOKMessage(err.Error(), "blocked")
if isAuthRequired(reason) {
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
}
ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason})
cancelReqCtx(fmt.Errorf("filter rejected"))
return
}
}
go func() {
// when all events have been loaded from databases and dispatched
// we can cancel the context and fire the EOSE message
eose.Wait()
cancelReqCtx(nil)
ws.WriteJSON(nostr.EOSEEnvelope(env.SubscriptionID))
}()
setListener(env.SubscriptionID, ws, env.Filters)
setListener(env.SubscriptionID, ws, env.Filters, cancelReqCtx)
case *nostr.CloseEnvelope:
removeListenerId(ws, string(*env))
case *nostr.AuthEnvelope:
@@ -188,7 +199,6 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok {
ws.AuthedPublicKey = pubkey
close(ws.Authed)
connectionContext = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true})
} else {
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"})
@@ -198,15 +208,13 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
}
}()
// writer
go func() {
defer func() {
ticker.Stop()
conn.Close()
}()
defer kill()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
err := ws.WriteMessage(websocket.PingMessage, nil)
if err != nil {

View File

@@ -1,12 +1,16 @@
package khatru
import (
"context"
"fmt"
"github.com/nbd-wtf/go-nostr"
"github.com/puzpuzpuz/xsync/v2"
)
type Listener struct {
filters nostr.Filters
cancel context.CancelCauseFunc
}
var listeners = xsync.NewTypedMapOf[*WebSocket, *xsync.MapOf[string, *Listener]](pointerHasher[WebSocket])
@@ -43,24 +47,28 @@ func GetListeningFilters() nostr.Filters {
return respfilters
}
func setListener(id string, ws *WebSocket, filters nostr.Filters) {
func setListener(id string, ws *WebSocket, filters nostr.Filters, cancel context.CancelCauseFunc) {
subs, _ := listeners.LoadOrCompute(ws, func() *xsync.MapOf[string, *Listener] {
return xsync.NewMapOf[*Listener]()
})
subs.Store(id, &Listener{filters: filters})
subs.Store(id, &Listener{filters: filters, cancel: cancel})
}
// Remove a specific subscription id from listeners for a given ws client
// remove a specific subscription id from listeners for a given ws client
// and cancel its specific context
func removeListenerId(ws *WebSocket, id string) {
if subs, ok := listeners.Load(ws); ok {
subs.Delete(id)
if listener, ok := subs.LoadAndDelete(id); ok {
listener.cancel(fmt.Errorf("subscription closed by client"))
}
if subs.Size() == 0 {
listeners.Delete(ws)
}
}
}
// Remove WebSocket conn from listeners
// remove WebSocket conn from listeners
// (no need to cancel contexts as they are all inherited from the main connection context)
func removeListener(ws *WebSocket) {
listeners.Delete(ws)
}

View File

@@ -56,6 +56,7 @@ type Relay struct {
CountEvents []func(ctx context.Context, filter nostr.Filter) (int64, error)
OnAuth []func(ctx context.Context, pubkey string)
OnConnect []func(ctx context.Context)
OnDisconnect []func(ctx context.Context)
OnEventSaved []func(ctx context.Context, event *nostr.Event)
// editing info will affect

View File

@@ -2,6 +2,9 @@ package khatru
import (
"context"
"github.com/nbd-wtf/go-nostr"
"github.com/sebest/xff"
)
func GetConnection(ctx context.Context) *WebSocket {
@@ -9,9 +12,21 @@ func GetConnection(ctx context.Context) *WebSocket {
}
func GetAuthed(ctx context.Context) string {
authedPubkey := ctx.Value(AUTH_CONTEXT_KEY)
if authedPubkey == nil {
return ""
}
return authedPubkey.(string)
return GetConnection(ctx).AuthedPublicKey
}
func GetIP(ctx context.Context) string {
return xff.GetRemoteAddr(GetConnection(ctx).Request)
}
func GetOpenSubscriptions(ctx context.Context) []nostr.Filter {
if listeners, ok := listeners.Load(GetConnection(ctx)); ok {
res := make([]nostr.Filter, 0, listeners.Size()*2)
listeners.Range(func(_ string, listener *Listener) bool {
res = append(res, listener.filters...)
return true
})
return res
}
return nil
}