Compare commits

..

7 Commits

Author SHA1 Message Date
fiatjaf
5823515d27 streamlined connection closes on failure.
account for the fact that the time.Ticker channel is
not closed when the ticker is stopped.
2023-12-09 00:00:22 -03:00
fiatjaf
9273a4b809 use a special context for each REQ stored-events handler that can be canceled. 2023-12-08 23:48:30 -03:00
fiatjaf
ddfc9ab64a fun with connection contexts and context cancelations. 2023-12-08 22:51:00 -03:00
fiatjaf
375236cfe2 fix sign on error checking. 2023-12-06 21:32:48 -03:00
fiatjaf
35e801379a make NIP-42 actually work, with inferred ServiceURL if that's not manually set. 2023-12-06 15:03:53 -03:00
fiatjaf
22da06b629 new flow for auth based on "auth-required: " rejection messages. 2023-12-06 12:14:58 -03:00
fiatjaf
7bfde76ab1 example fix. 2023-12-06 12:14:27 -03:00
7 changed files with 131 additions and 69 deletions

View File

@@ -26,7 +26,7 @@ func main() {
relay.CountEvents = append(relay.CountEvents, db.CountEvents)
relay.DeleteEvent = append(relay.DeleteEvent, db.DeleteEvent)
relay.RejectEvent = append(relay.RejectEvent, policies.PreventTooManyIndexableTags(10))
relay.RejectEvent = append(relay.RejectEvent, policies.PreventTooManyIndexableTags(10, nil, nil))
relay.RejectFilter = append(relay.RejectFilter, policies.NoComplexFilters)
relay.OnEventSaved = append(relay.OnEventSaved, func(ctx context.Context, event *nostr.Event) {

View File

@@ -60,15 +60,20 @@ func main() {
return false, "" // anyone else can
},
)
relay.OnConnect = append(relay.OnConnect,
func(ctx context.Context) {
// request NIP-42 AUTH from everybody
relay.RequestAuth(ctx)
// you can request auth by rejecting an event or a request with the prefix "auth-required: "
relay.RejectFilter = append(relay.RejectFilter,
func(ctx context.Context, filter nostr.Filter) (reject bool, msg string) {
if pubkey := khatru.GetAuthed(ctx); pubkey != "" {
log.Printf("request from %s\n", pubkey)
return false, ""
}
return true, "auth-required: only authenticated users can read from this relay"
},
)
relay.OnAuth = append(relay.OnAuth,
func(ctx context.Context, pubkey string) {
// and when they auth we just log that for nothing
// and when they auth we can just log that for nothing
log.Println(pubkey + " is authed!")
},
)

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
@@ -19,6 +20,10 @@ import (
// ServeHTTP implements http.Handler interface.
func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if rl.ServiceURL == "" {
rl.ServiceURL = getServiceBaseURL(r)
}
if r.Header.Get("Upgrade") == "websocket" {
rl.HandleWebsocket(w, r)
} else if r.Header.Get("Accept") == "application/nostr+json" {
@@ -29,8 +34,6 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
conn, err := rl.upgrader.Upgrade(w, r, nil)
if err != nil {
rl.Log.Printf("failed to upgrade websocket: %v\n", err)
@@ -50,18 +53,25 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
Authed: make(chan struct{}),
}
ctx = context.WithValue(ctx, WS_KEY, ws)
ctx, cancel := context.WithCancel(
context.WithValue(
context.Background(),
WS_KEY, ws,
),
)
kill := func() {
ticker.Stop()
cancel()
if _, ok := rl.clients.Load(conn); ok {
conn.Close()
rl.clients.Delete(conn)
removeListener(ws)
}
}
// reader
go func() {
defer func() {
ticker.Stop()
if _, ok := rl.clients.Load(conn); ok {
conn.Close()
rl.clients.Delete(conn)
removeListener(ws)
}
}()
defer kill()
conn.SetReadLimit(rl.MaxMessageSize)
conn.SetReadDeadline(time.Now().Add(rl.PongWait))
@@ -95,8 +105,6 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
}
go func(message []byte) {
ctx := context.WithValue(context.Background(), WS_KEY, ws)
envelope := nostr.ParseMessage(message)
if envelope == nil {
// stop silently
@@ -134,6 +142,9 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
ok = true
} else {
reason = nostr.NormalizeOKMessage(err.Error(), "blocked")
if isAuthRequired(reason) {
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
}
}
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: ok, Reason: reason})
case *nostr.CountEnvelope:
@@ -150,48 +161,57 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
eose := sync.WaitGroup{}
eose.Add(len(env.Filters))
// a context just for the "stored events" request handler
reqCtx, cancelReqCtx := context.WithCancelCause(ctx)
// handle each filter separately -- dispatching events as they're loaded from databases
for _, filter := range env.Filters {
err := rl.handleRequest(ctx, env.SubscriptionID, &eose, ws, filter)
if err == nil {
err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
if err != nil {
// fail everything if any filter is rejected
reason := nostr.NormalizeOKMessage(err.Error(), "blocked")
if isAuthRequired(reason) {
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
}
ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason})
cancelReqCtx(fmt.Errorf("filter rejected"))
return
}
}
go func() {
// when all events have been loaded from databases and dispatched
// we can cancel the context and fire the EOSE message
eose.Wait()
cancelReqCtx(nil)
ws.WriteJSON(nostr.EOSEEnvelope(env.SubscriptionID))
}()
setListener(env.SubscriptionID, ws, env.Filters)
setListener(env.SubscriptionID, ws, env.Filters, cancelReqCtx)
case *nostr.CloseEnvelope:
removeListenerId(ws, string(*env))
case *nostr.AuthEnvelope:
if rl.ServiceURL != "" {
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, rl.ServiceURL); ok {
ws.AuthedPublicKey = pubkey
close(ws.Authed)
ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true})
} else {
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"})
}
wsBaseUrl := strings.Replace(rl.ServiceURL, "http", "ws", 1)
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok {
ws.AuthedPublicKey = pubkey
close(ws.Authed)
ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true})
} else {
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"})
}
}
}(message)
}
}()
// writer
go func() {
defer func() {
ticker.Stop()
conn.Close()
}()
defer kill()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
err := ws.WriteMessage(websocket.PingMessage, nil)
if err != nil {

55
helpers.go Normal file
View File

@@ -0,0 +1,55 @@
package khatru
import (
"hash/maphash"
"net/http"
"regexp"
"strconv"
"strings"
"unsafe"
"github.com/nbd-wtf/go-nostr"
)
const (
AUTH_CONTEXT_KEY = iota
WS_KEY
)
var nip20prefixmatcher = regexp.MustCompile(`^\w+: `)
func pointerHasher[V any](_ maphash.Seed, k *V) uint64 {
return uint64(uintptr(unsafe.Pointer(k)))
}
func isOlder(previous, next *nostr.Event) bool {
return previous.CreatedAt < next.CreatedAt ||
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
}
func isAuthRequired(msg string) bool {
idx := strings.IndexByte(msg, ':')
return msg[0:idx] == "auth-required"
}
func getServiceBaseURL(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto == "" {
if host == "localhost" {
proto = "http"
} else if strings.Index(host, ":") != -1 {
// has a port number
proto = "http"
} else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil {
// it's a naked IP
proto = "http"
} else {
proto = "https"
}
}
return proto + "://" + host
}

View File

@@ -1,12 +1,16 @@
package khatru
import (
"context"
"fmt"
"github.com/nbd-wtf/go-nostr"
"github.com/puzpuzpuz/xsync/v2"
)
type Listener struct {
filters nostr.Filters
cancel context.CancelCauseFunc
}
var listeners = xsync.NewTypedMapOf[*WebSocket, *xsync.MapOf[string, *Listener]](pointerHasher[WebSocket])
@@ -43,24 +47,28 @@ func GetListeningFilters() nostr.Filters {
return respfilters
}
func setListener(id string, ws *WebSocket, filters nostr.Filters) {
func setListener(id string, ws *WebSocket, filters nostr.Filters, cancel context.CancelCauseFunc) {
subs, _ := listeners.LoadOrCompute(ws, func() *xsync.MapOf[string, *Listener] {
return xsync.NewMapOf[*Listener]()
})
subs.Store(id, &Listener{filters: filters})
subs.Store(id, &Listener{filters: filters, cancel: cancel})
}
// Remove a specific subscription id from listeners for a given ws client
// remove a specific subscription id from listeners for a given ws client
// and cancel its specific context
func removeListenerId(ws *WebSocket, id string) {
if subs, ok := listeners.Load(ws); ok {
subs.Delete(id)
if listener, ok := subs.LoadAndDelete(id); ok {
listener.cancel(fmt.Errorf("subscription closed by client"))
}
if subs.Size() == 0 {
listeners.Delete(ws)
}
}
}
// Remove WebSocket conn from listeners
// remove WebSocket conn from listeners
// (no need to cancel contexts as they are all inherited from the main connection context)
func removeListener(ws *WebSocket) {
listeners.Delete(ws)
}

View File

@@ -40,7 +40,7 @@ func NewRelay() *Relay {
}
type Relay struct {
ServiceURL string // required for nip-42
ServiceURL string
RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string)
RejectFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
@@ -82,8 +82,3 @@ type Relay struct {
PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait.
MaxMessageSize int64 // Maximum message size allowed from peer.
}
func (rl *Relay) RequestAuth(ctx context.Context) {
ws := GetConnection(ctx)
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
}

View File

@@ -2,20 +2,8 @@ package khatru
import (
"context"
"hash/maphash"
"regexp"
"unsafe"
"github.com/nbd-wtf/go-nostr"
)
const (
AUTH_CONTEXT_KEY = iota
WS_KEY = iota
)
var nip20prefixmatcher = regexp.MustCompile(`^\w+: `)
func GetConnection(ctx context.Context) *WebSocket {
return ctx.Value(WS_KEY).(*WebSocket)
}
@@ -27,12 +15,3 @@ func GetAuthed(ctx context.Context) string {
}
return authedPubkey.(string)
}
func pointerHasher[V any](_ maphash.Seed, k *V) uint64 {
return uint64(uintptr(unsafe.Pointer(k)))
}
func isOlder(previous, next *nostr.Event) bool {
return previous.CreatedAt < next.CreatedAt ||
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
}