Compare commits

...

14 Commits

Author SHA1 Message Date
fiatjaf
5823515d27 streamlined connection closes on failure.
account for the fact that the time.Ticker channel is
not closed when the ticker is stopped.
2023-12-09 00:00:22 -03:00
fiatjaf
9273a4b809 use a special context for each REQ stored-events handler that can be canceled. 2023-12-08 23:48:30 -03:00
fiatjaf
ddfc9ab64a fun with connection contexts and context cancelations. 2023-12-08 22:51:00 -03:00
fiatjaf
375236cfe2 fix sign on error checking. 2023-12-06 21:32:48 -03:00
fiatjaf
35e801379a make NIP-42 actually work, with inferred ServiceURL if that's not manually set. 2023-12-06 15:03:53 -03:00
fiatjaf
22da06b629 new flow for auth based on "auth-required: " rejection messages. 2023-12-06 12:14:58 -03:00
fiatjaf
7bfde76ab1 example fix. 2023-12-06 12:14:27 -03:00
fiatjaf
ad92d0b051 return CLOSED if any of the filters get rejected. 2023-12-06 11:56:56 -03:00
fiatjaf
728417852e fix nip04 policy. 2023-11-29 12:30:18 -03:00
fiatjaf
3c1b062eb8 include original http.Request in WebSocket struct. 2023-11-29 12:26:04 -03:00
fiatjaf
84d01dc1d3 rename auth-related fields on WebSocket struct. 2023-11-29 12:23:21 -03:00
fiatjaf
888ac8c1c0 use updated released go-nostr. 2023-11-29 12:23:02 -03:00
fiatjaf
e1fd6aaa56 update examples plugins->policies 2023-11-29 12:22:37 -03:00
fiatjaf
386a89676a use go-nostr envelopes and support CLOSED when filters are rejected. 2023-11-28 22:43:06 -03:00
12 changed files with 186 additions and 166 deletions

View File

@@ -8,7 +8,7 @@ import (
"github.com/fiatjaf/eventstore/lmdb" "github.com/fiatjaf/eventstore/lmdb"
"github.com/fiatjaf/khatru" "github.com/fiatjaf/khatru"
"github.com/fiatjaf/khatru/plugins" "github.com/fiatjaf/khatru/policies"
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
) )
@@ -26,8 +26,8 @@ func main() {
relay.CountEvents = append(relay.CountEvents, db.CountEvents) relay.CountEvents = append(relay.CountEvents, db.CountEvents)
relay.DeleteEvent = append(relay.DeleteEvent, db.DeleteEvent) relay.DeleteEvent = append(relay.DeleteEvent, db.DeleteEvent)
relay.RejectEvent = append(relay.RejectEvent, plugins.PreventTooManyIndexableTags(10)) relay.RejectEvent = append(relay.RejectEvent, policies.PreventTooManyIndexableTags(10, nil, nil))
relay.RejectFilter = append(relay.RejectFilter, plugins.NoComplexFilters) relay.RejectFilter = append(relay.RejectFilter, policies.NoComplexFilters)
relay.OnEventSaved = append(relay.OnEventSaved, func(ctx context.Context, event *nostr.Event) { relay.OnEventSaved = append(relay.OnEventSaved, func(ctx context.Context, event *nostr.Event) {
}) })

View File

@@ -60,15 +60,20 @@ func main() {
return false, "" // anyone else can return false, "" // anyone else can
}, },
) )
relay.OnConnect = append(relay.OnConnect,
func(ctx context.Context) { // you can request auth by rejecting an event or a request with the prefix "auth-required: "
// request NIP-42 AUTH from everybody relay.RejectFilter = append(relay.RejectFilter,
relay.RequestAuth(ctx) func(ctx context.Context, filter nostr.Filter) (reject bool, msg string) {
if pubkey := khatru.GetAuthed(ctx); pubkey != "" {
log.Printf("request from %s\n", pubkey)
return false, ""
}
return true, "auth-required: only authenticated users can read from this relay"
}, },
) )
relay.OnAuth = append(relay.OnAuth, relay.OnAuth = append(relay.OnAuth,
func(ctx context.Context, pubkey string) { func(ctx context.Context, pubkey string) {
// and when they auth we just log that for nothing // and when they auth we can just log that for nothing
log.Println(pubkey + " is authed!") log.Println(pubkey + " is authed!")
}, },
) )

2
go.mod
View File

@@ -5,7 +5,7 @@ go 1.21.0
require ( require (
github.com/fasthttp/websocket v1.5.3 github.com/fasthttp/websocket v1.5.3
github.com/fiatjaf/eventstore v0.1.0 github.com/fiatjaf/eventstore v0.1.0
github.com/nbd-wtf/go-nostr v0.25.7 github.com/nbd-wtf/go-nostr v0.26.0
github.com/puzpuzpuz/xsync/v2 v2.5.1 github.com/puzpuzpuz/xsync/v2 v2.5.1
github.com/rs/cors v1.7.0 github.com/rs/cors v1.7.0
golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53 golang.org/x/exp v0.0.0-20230425010034-47ecfdc1ba53

4
go.sum
View File

@@ -90,8 +90,8 @@ github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJ
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/nbd-wtf/go-nostr v0.25.7 h1:DcGOSgKVr/L6w62tRtKeV2t46sRyFcq9pWcyIFkh0eM= github.com/nbd-wtf/go-nostr v0.26.0 h1:Tofbs9i8DD5iEKIhLlWFO7kfWpvmUG16fEyW30MzHVQ=
github.com/nbd-wtf/go-nostr v0.25.7/go.mod h1:bkffJI+x914sPQWum9ZRUn66D7NpDnAoWo1yICvj3/0= github.com/nbd-wtf/go-nostr v0.26.0/go.mod h1:bkffJI+x914sPQWum9ZRUn66D7NpDnAoWo1yICvj3/0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=

View File

@@ -6,6 +6,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@@ -19,6 +20,10 @@ import (
// ServeHTTP implements http.Handler interface. // ServeHTTP implements http.Handler interface.
func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if rl.ServiceURL == "" {
rl.ServiceURL = getServiceBaseURL(r)
}
if r.Header.Get("Upgrade") == "websocket" { if r.Header.Get("Upgrade") == "websocket" {
rl.HandleWebsocket(w, r) rl.HandleWebsocket(w, r)
} else if r.Header.Get("Accept") == "application/nostr+json" { } else if r.Header.Get("Accept") == "application/nostr+json" {
@@ -29,8 +34,6 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) { func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
conn, err := rl.upgrader.Upgrade(w, r, nil) conn, err := rl.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
rl.Log.Printf("failed to upgrade websocket: %v\n", err) rl.Log.Printf("failed to upgrade websocket: %v\n", err)
@@ -44,23 +47,31 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
rand.Read(challenge) rand.Read(challenge)
ws := &WebSocket{ ws := &WebSocket{
conn: conn, conn: conn,
Challenge: hex.EncodeToString(challenge), Request: r,
WaitingForAuth: make(chan struct{}), Challenge: hex.EncodeToString(challenge),
Authed: make(chan struct{}),
} }
ctx = context.WithValue(ctx, WS_KEY, ws) ctx, cancel := context.WithCancel(
context.WithValue(
context.Background(),
WS_KEY, ws,
),
)
kill := func() {
ticker.Stop()
cancel()
if _, ok := rl.clients.Load(conn); ok {
conn.Close()
rl.clients.Delete(conn)
removeListener(ws)
}
}
// reader
go func() { go func() {
defer func() { defer kill()
ticker.Stop()
if _, ok := rl.clients.Load(conn); ok {
conn.Close()
rl.clients.Delete(conn)
removeListener(ws)
}
}()
conn.SetReadLimit(rl.MaxMessageSize) conn.SetReadLimit(rl.MaxMessageSize)
conn.SetReadDeadline(time.Now().Add(rl.PongWait)) conn.SetReadDeadline(time.Now().Add(rl.PongWait))
@@ -94,153 +105,113 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
} }
go func(message []byte) { go func(message []byte) {
ctx = context.Background() envelope := nostr.ParseMessage(message)
if envelope == nil {
var request []json.RawMessage
if err := json.Unmarshal(message, &request); err != nil {
// stop silently // stop silently
return return
} }
if len(request) < 2 { switch env := envelope.(type) {
ws.WriteJSON(nostr.NoticeEnvelope("request has less than 2 parameters")) case *nostr.EventEnvelope:
return
}
var typ string
json.Unmarshal(request[0], &typ)
switch typ {
case "EVENT":
// it's a new event
var evt nostr.Event
if err := json.Unmarshal(request[1], &evt); err != nil {
ws.WriteJSON(nostr.NoticeEnvelope("failed to decode event: " + err.Error()))
return
}
// check id // check id
hash := sha256.Sum256(evt.Serialize()) hash := sha256.Sum256(env.Event.Serialize())
id := hex.EncodeToString(hash[:]) id := hex.EncodeToString(hash[:])
if id != evt.ID { if id != env.Event.ID {
ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "invalid: id is computed incorrectly"}) ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "invalid: id is computed incorrectly"})
return return
} }
// check signature // check signature
if ok, err := evt.CheckSignature(); err != nil { if ok, err := env.Event.CheckSignature(); err != nil {
ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "error: failed to verify signature"}) ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to verify signature"})
return return
} else if !ok { } else if !ok {
ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "invalid: signature is invalid"}) ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "invalid: signature is invalid"})
return return
} }
var ok bool var ok bool
if evt.Kind == 5 { if env.Event.Kind == 5 {
err = rl.handleDeleteRequest(ctx, &evt) err = rl.handleDeleteRequest(ctx, &env.Event)
} else { } else {
err = rl.AddEvent(ctx, &evt) err = rl.AddEvent(ctx, &env.Event)
} }
var reason string var reason string
if err == nil { if err == nil {
ok = true ok = true
} else { } else {
reason = err.Error() reason = nostr.NormalizeOKMessage(err.Error(), "blocked")
if isAuthRequired(reason) {
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
}
} }
ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: ok, Reason: reason}) ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: ok, Reason: reason})
case "COUNT": case *nostr.CountEnvelope:
if rl.CountEvents == nil { if rl.CountEvents == nil {
ws.WriteJSON(nostr.NoticeEnvelope("this relay does not support NIP-45")) ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: "unsupported: this relay does not support NIP-45"})
return return
} }
var id string
json.Unmarshal(request[1], &id)
if id == "" {
ws.WriteJSON(nostr.NoticeEnvelope("COUNT has no <id>"))
return
}
var total int64 var total int64
filters := make(nostr.Filters, len(request)-2) for _, filter := range env.Filters {
for i, filterReq := range request[2:] { total += rl.handleCountRequest(ctx, ws, filter)
if err := json.Unmarshal(filterReq, &filters[i]); err != nil {
ws.WriteJSON(nostr.NoticeEnvelope("failed to decode filter"))
continue
}
total += rl.handleCountRequest(ctx, ws, filters[i])
} }
ws.WriteJSON(nostr.CountEnvelope{SubscriptionID: env.SubscriptionID, Count: &total})
ws.WriteJSON([]interface{}{"COUNT", id, map[string]int64{"count": total}}) case *nostr.ReqEnvelope:
case "REQ":
var id string
json.Unmarshal(request[1], &id)
if id == "" {
ws.WriteJSON(nostr.NoticeEnvelope("REQ has no <id>"))
return
}
filters := make(nostr.Filters, len(request)-2)
eose := sync.WaitGroup{} eose := sync.WaitGroup{}
eose.Add(len(request[2:])) eose.Add(len(env.Filters))
for i, filterReq := range request[2:] { // a context just for the "stored events" request handler
if err := json.Unmarshal(filterReq, &filters[i]); err != nil { reqCtx, cancelReqCtx := context.WithCancelCause(ctx)
ws.WriteJSON(nostr.NoticeEnvelope("failed to decode filter"))
eose.Done() // handle each filter separately -- dispatching events as they're loaded from databases
continue for _, filter := range env.Filters {
err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
if err != nil {
// fail everything if any filter is rejected
reason := nostr.NormalizeOKMessage(err.Error(), "blocked")
if isAuthRequired(reason) {
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
}
ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason})
cancelReqCtx(fmt.Errorf("filter rejected"))
return
} }
go rl.handleRequest(ctx, id, &eose, ws, filters[i])
} }
go func() { go func() {
// when all events have been loaded from databases and dispatched
// we can cancel the context and fire the EOSE message
eose.Wait() eose.Wait()
ws.WriteJSON(nostr.EOSEEnvelope(id)) cancelReqCtx(nil)
ws.WriteJSON(nostr.EOSEEnvelope(env.SubscriptionID))
}() }()
setListener(id, ws, filters) setListener(env.SubscriptionID, ws, env.Filters, cancelReqCtx)
case "CLOSE": case *nostr.CloseEnvelope:
var id string removeListenerId(ws, string(*env))
json.Unmarshal(request[1], &id) case *nostr.AuthEnvelope:
if id == "" { wsBaseUrl := strings.Replace(rl.ServiceURL, "http", "ws", 1)
ws.WriteJSON(nostr.NoticeEnvelope("CLOSE has no <id>")) if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok {
return ws.AuthedPublicKey = pubkey
} close(ws.Authed)
ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
removeListenerId(ws, id) ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true})
case "AUTH": } else {
if rl.ServiceURL != "" { ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"})
var evt nostr.Event
if err := json.Unmarshal(request[1], &evt); err != nil {
ws.WriteJSON(nostr.NoticeEnvelope("failed to decode auth event: " + err.Error()))
return
}
if pubkey, ok := nip42.ValidateAuthEvent(&evt, ws.Challenge, rl.ServiceURL); ok {
ws.Authed = pubkey
close(ws.WaitingForAuth)
ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: true})
} else {
ws.WriteJSON(nostr.OKEnvelope{EventID: evt.ID, OK: false, Reason: "error: failed to authenticate"})
}
} }
} }
}(message) }(message)
} }
}() }()
// writer
go func() { go func() {
defer func() { defer kill()
ticker.Stop()
conn.Close()
}()
for { for {
select { select {
case <-ctx.Done():
return
case <-ticker.C: case <-ticker.C:
err := ws.WriteMessage(websocket.PingMessage, nil) err := ws.WriteMessage(websocket.PingMessage, nil)
if err != nil { if err != nil {

55
helpers.go Normal file
View File

@@ -0,0 +1,55 @@
package khatru
import (
"hash/maphash"
"net/http"
"regexp"
"strconv"
"strings"
"unsafe"
"github.com/nbd-wtf/go-nostr"
)
const (
AUTH_CONTEXT_KEY = iota
WS_KEY
)
var nip20prefixmatcher = regexp.MustCompile(`^\w+: `)
func pointerHasher[V any](_ maphash.Seed, k *V) uint64 {
return uint64(uintptr(unsafe.Pointer(k)))
}
func isOlder(previous, next *nostr.Event) bool {
return previous.CreatedAt < next.CreatedAt ||
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
}
func isAuthRequired(msg string) bool {
idx := strings.IndexByte(msg, ':')
return msg[0:idx] == "auth-required"
}
func getServiceBaseURL(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto == "" {
if host == "localhost" {
proto = "http"
} else if strings.Index(host, ":") != -1 {
// has a port number
proto = "http"
} else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil {
// it's a naked IP
proto = "http"
} else {
proto = "https"
}
}
return proto + "://" + host
}

View File

@@ -1,12 +1,16 @@
package khatru package khatru
import ( import (
"context"
"fmt"
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
"github.com/puzpuzpuz/xsync/v2" "github.com/puzpuzpuz/xsync/v2"
) )
type Listener struct { type Listener struct {
filters nostr.Filters filters nostr.Filters
cancel context.CancelCauseFunc
} }
var listeners = xsync.NewTypedMapOf[*WebSocket, *xsync.MapOf[string, *Listener]](pointerHasher[WebSocket]) var listeners = xsync.NewTypedMapOf[*WebSocket, *xsync.MapOf[string, *Listener]](pointerHasher[WebSocket])
@@ -43,24 +47,28 @@ func GetListeningFilters() nostr.Filters {
return respfilters return respfilters
} }
func setListener(id string, ws *WebSocket, filters nostr.Filters) { func setListener(id string, ws *WebSocket, filters nostr.Filters, cancel context.CancelCauseFunc) {
subs, _ := listeners.LoadOrCompute(ws, func() *xsync.MapOf[string, *Listener] { subs, _ := listeners.LoadOrCompute(ws, func() *xsync.MapOf[string, *Listener] {
return xsync.NewMapOf[*Listener]() return xsync.NewMapOf[*Listener]()
}) })
subs.Store(id, &Listener{filters: filters}) subs.Store(id, &Listener{filters: filters, cancel: cancel})
} }
// Remove a specific subscription id from listeners for a given ws client // remove a specific subscription id from listeners for a given ws client
// and cancel its specific context
func removeListenerId(ws *WebSocket, id string) { func removeListenerId(ws *WebSocket, id string) {
if subs, ok := listeners.Load(ws); ok { if subs, ok := listeners.Load(ws); ok {
subs.Delete(id) if listener, ok := subs.LoadAndDelete(id); ok {
listener.cancel(fmt.Errorf("subscription closed by client"))
}
if subs.Size() == 0 { if subs.Size() == 0 {
listeners.Delete(ws) listeners.Delete(ws)
} }
} }
} }
// Remove WebSocket conn from listeners // remove WebSocket conn from listeners
// (no need to cancel contexts as they are all inherited from the main connection context)
func removeListener(ws *WebSocket) { func removeListener(ws *WebSocket) {
listeners.Delete(ws) listeners.Delete(ws)
} }

View File

@@ -20,13 +20,13 @@ func RejectKind04Snoopers(ctx context.Context, filter nostr.Filter) (bool, strin
senders := filter.Authors senders := filter.Authors
receivers, _ := filter.Tags["p"] receivers, _ := filter.Tags["p"]
switch { switch {
case ws.Authed == "": case ws.AuthedPublicKey == "":
// not authenticated // not authenticated
return true, "restricted: this relay does not serve kind-4 to unauthenticated users, does your client implement NIP-42?" return true, "restricted: this relay does not serve kind-4 to unauthenticated users, does your client implement NIP-42?"
case len(senders) == 1 && len(receivers) < 2 && (senders[0] == ws.Authed): case len(senders) == 1 && len(receivers) < 2 && (senders[0] == ws.AuthedPublicKey):
// allowed filter: ws.authed is sole sender (filter specifies one or all receivers) // allowed filter: ws.authed is sole sender (filter specifies one or all receivers)
return false, "" return false, ""
case len(receivers) == 1 && len(senders) < 2 && (receivers[0] == ws.Authed): case len(receivers) == 1 && len(senders) < 2 && (receivers[0] == ws.AuthedPublicKey):
// allowed filter: ws.authed is sole receiver (filter specifies one or all senders) // allowed filter: ws.authed is sole receiver (filter specifies one or all senders)
return false, "" return false, ""
default: default:

View File

@@ -40,7 +40,7 @@ func NewRelay() *Relay {
} }
type Relay struct { type Relay struct {
ServiceURL string // required for nip-42 ServiceURL string
RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string) RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string)
RejectFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string) RejectFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
@@ -82,8 +82,3 @@ type Relay struct {
PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait. PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait.
MaxMessageSize int64 // Maximum message size allowed from peer. MaxMessageSize int64 // Maximum message size allowed from peer.
} }
func (rl *Relay) RequestAuth(ctx context.Context) {
ws := GetConnection(ctx)
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
}

View File

@@ -2,12 +2,13 @@ package khatru
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
) )
func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGroup, ws *WebSocket, filter nostr.Filter) { func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGroup, ws *WebSocket, filter nostr.Filter) error {
defer eose.Done() defer eose.Done()
// overwrite the filter (for example, to eliminate some kinds or // overwrite the filter (for example, to eliminate some kinds or
@@ -17,7 +18,7 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr
} }
if filter.Limit < 0 { if filter.Limit < 0 {
return return fmt.Errorf("filter invalidated")
} }
// then check if we'll reject this filter (we apply this after overwriting // then check if we'll reject this filter (we apply this after overwriting
@@ -27,7 +28,7 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr
for _, reject := range rl.RejectFilter { for _, reject := range rl.RejectFilter {
if reject, msg := reject(ctx, filter); reject { if reject, msg := reject(ctx, filter); reject {
ws.WriteJSON(nostr.NoticeEnvelope(msg)) ws.WriteJSON(nostr.NoticeEnvelope(msg))
return return fmt.Errorf(msg)
} }
} }
@@ -52,6 +53,8 @@ func (rl *Relay) handleRequest(ctx context.Context, id string, eose *sync.WaitGr
eose.Done() eose.Done()
}(ch) }(ch)
} }
return nil
} }
func (rl *Relay) handleCountRequest(ctx context.Context, ws *WebSocket, filter nostr.Filter) int64 { func (rl *Relay) handleCountRequest(ctx context.Context, ws *WebSocket, filter nostr.Filter) int64 {

View File

@@ -2,20 +2,8 @@ package khatru
import ( import (
"context" "context"
"hash/maphash"
"regexp"
"unsafe"
"github.com/nbd-wtf/go-nostr"
) )
const (
AUTH_CONTEXT_KEY = iota
WS_KEY = iota
)
var nip20prefixmatcher = regexp.MustCompile(`^\w+: `)
func GetConnection(ctx context.Context) *WebSocket { func GetConnection(ctx context.Context) *WebSocket {
return ctx.Value(WS_KEY).(*WebSocket) return ctx.Value(WS_KEY).(*WebSocket)
} }
@@ -27,12 +15,3 @@ func GetAuthed(ctx context.Context) string {
} }
return authedPubkey.(string) return authedPubkey.(string)
} }
func pointerHasher[V any](_ maphash.Seed, k *V) uint64 {
return uint64(uintptr(unsafe.Pointer(k)))
}
func isOlder(previous, next *nostr.Event) bool {
return previous.CreatedAt < next.CreatedAt ||
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
}

View File

@@ -1,6 +1,7 @@
package khatru package khatru
import ( import (
"net/http"
"sync" "sync"
"github.com/fasthttp/websocket" "github.com/fasthttp/websocket"
@@ -10,10 +11,13 @@ type WebSocket struct {
conn *websocket.Conn conn *websocket.Conn
mutex sync.Mutex mutex sync.Mutex
// original request
Request *http.Request
// nip42 // nip42
Challenge string Challenge string
Authed string AuthedPublicKey string
WaitingForAuth chan struct{} Authed chan struct{}
} }
func (ws *WebSocket) WriteJSON(any any) error { func (ws *WebSocket) WriteJSON(any any) error {