mirror of
https://github.com/fiatjaf/khatru.git
synced 2026-04-08 06:26:52 +02:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3ec0020baa | ||
|
|
d3a0c545d2 | ||
|
|
c09d21b621 | ||
|
|
5823515d27 | ||
|
|
9273a4b809 | ||
|
|
ddfc9ab64a | ||
|
|
375236cfe2 | ||
|
|
35e801379a | ||
|
|
22da06b629 | ||
|
|
7bfde76ab1 |
@@ -26,7 +26,7 @@ func main() {
|
||||
relay.CountEvents = append(relay.CountEvents, db.CountEvents)
|
||||
relay.DeleteEvent = append(relay.DeleteEvent, db.DeleteEvent)
|
||||
|
||||
relay.RejectEvent = append(relay.RejectEvent, policies.PreventTooManyIndexableTags(10))
|
||||
relay.RejectEvent = append(relay.RejectEvent, policies.PreventTooManyIndexableTags(10, nil, nil))
|
||||
relay.RejectFilter = append(relay.RejectFilter, policies.NoComplexFilters)
|
||||
|
||||
relay.OnEventSaved = append(relay.OnEventSaved, func(ctx context.Context, event *nostr.Event) {
|
||||
|
||||
@@ -60,15 +60,20 @@ func main() {
|
||||
return false, "" // anyone else can
|
||||
},
|
||||
)
|
||||
relay.OnConnect = append(relay.OnConnect,
|
||||
func(ctx context.Context) {
|
||||
// request NIP-42 AUTH from everybody
|
||||
relay.RequestAuth(ctx)
|
||||
|
||||
// you can request auth by rejecting an event or a request with the prefix "auth-required: "
|
||||
relay.RejectFilter = append(relay.RejectFilter,
|
||||
func(ctx context.Context, filter nostr.Filter) (reject bool, msg string) {
|
||||
if pubkey := khatru.GetAuthed(ctx); pubkey != "" {
|
||||
log.Printf("request from %s\n", pubkey)
|
||||
return false, ""
|
||||
}
|
||||
return true, "auth-required: only authenticated users can read from this relay"
|
||||
},
|
||||
)
|
||||
relay.OnAuth = append(relay.OnAuth,
|
||||
func(ctx context.Context, pubkey string) {
|
||||
// and when they auth we just log that for nothing
|
||||
// and when they auth we can just log that for nothing
|
||||
log.Println(pubkey + " is authed!")
|
||||
},
|
||||
)
|
||||
|
||||
1
go.mod
1
go.mod
@@ -44,6 +44,7 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect
|
||||
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a // indirect
|
||||
github.com/tidwall/gjson v1.14.4 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -102,6 +102,8 @@ github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
|
||||
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk=
|
||||
github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g=
|
||||
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a h1:iLcLb5Fwwz7g/DLK89F+uQBDeAhHhwdzB5fSlVdhGcM=
|
||||
github.com/sebest/xff v0.0.0-20210106013422-671bd2870b3a/go.mod h1:wozgYq9WEBQBaIJe4YZ0qTSFAMxmcwBhQH0fO0R34Z0=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
|
||||
87
handlers.go
87
handlers.go
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -19,6 +20,10 @@ import (
|
||||
|
||||
// ServeHTTP implements http.Handler interface.
|
||||
func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if rl.ServiceURL == "" {
|
||||
rl.ServiceURL = getServiceBaseURL(r)
|
||||
}
|
||||
|
||||
if r.Header.Get("Upgrade") == "websocket" {
|
||||
rl.HandleWebsocket(w, r)
|
||||
} else if r.Header.Get("Accept") == "application/nostr+json" {
|
||||
@@ -29,8 +34,6 @@ func (rl *Relay) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
conn, err := rl.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
rl.Log.Printf("failed to upgrade websocket: %v\n", err)
|
||||
@@ -50,18 +53,29 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
Authed: make(chan struct{}),
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, WS_KEY, ws)
|
||||
ctx, cancel := context.WithCancel(
|
||||
context.WithValue(
|
||||
context.Background(),
|
||||
WS_KEY, ws,
|
||||
),
|
||||
)
|
||||
|
||||
kill := func() {
|
||||
for _, ondisconnect := range rl.OnDisconnect {
|
||||
ondisconnect(ctx)
|
||||
}
|
||||
|
||||
ticker.Stop()
|
||||
cancel()
|
||||
if _, ok := rl.clients.Load(conn); ok {
|
||||
conn.Close()
|
||||
rl.clients.Delete(conn)
|
||||
removeListener(ws)
|
||||
}
|
||||
}
|
||||
|
||||
// reader
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
if _, ok := rl.clients.Load(conn); ok {
|
||||
conn.Close()
|
||||
rl.clients.Delete(conn)
|
||||
removeListener(ws)
|
||||
}
|
||||
}()
|
||||
defer kill()
|
||||
|
||||
conn.SetReadLimit(rl.MaxMessageSize)
|
||||
conn.SetReadDeadline(time.Now().Add(rl.PongWait))
|
||||
@@ -86,7 +100,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
) {
|
||||
rl.Log.Printf("unexpected close error from %s: %v\n", r.Header.Get("X-Forwarded-For"), err)
|
||||
}
|
||||
break
|
||||
return
|
||||
}
|
||||
|
||||
if typ == websocket.PingMessage {
|
||||
@@ -95,8 +109,6 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
go func(message []byte) {
|
||||
ctx := context.WithValue(context.Background(), WS_KEY, ws)
|
||||
|
||||
envelope := nostr.ParseMessage(message)
|
||||
if envelope == nil {
|
||||
// stop silently
|
||||
@@ -134,6 +146,9 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
ok = true
|
||||
} else {
|
||||
reason = nostr.NormalizeOKMessage(err.Error(), "blocked")
|
||||
if isAuthRequired(reason) {
|
||||
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
|
||||
}
|
||||
}
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: ok, Reason: reason})
|
||||
case *nostr.CountEnvelope:
|
||||
@@ -150,48 +165,56 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
eose := sync.WaitGroup{}
|
||||
eose.Add(len(env.Filters))
|
||||
|
||||
// a context just for the "stored events" request handler
|
||||
reqCtx, cancelReqCtx := context.WithCancelCause(ctx)
|
||||
|
||||
// handle each filter separately -- dispatching events as they're loaded from databases
|
||||
for _, filter := range env.Filters {
|
||||
err := rl.handleRequest(ctx, env.SubscriptionID, &eose, ws, filter)
|
||||
if err == nil {
|
||||
err := rl.handleRequest(reqCtx, env.SubscriptionID, &eose, ws, filter)
|
||||
if err != nil {
|
||||
// fail everything if any filter is rejected
|
||||
reason := nostr.NormalizeOKMessage(err.Error(), "blocked")
|
||||
if isAuthRequired(reason) {
|
||||
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
|
||||
}
|
||||
ws.WriteJSON(nostr.ClosedEnvelope{SubscriptionID: env.SubscriptionID, Reason: reason})
|
||||
cancelReqCtx(fmt.Errorf("filter rejected"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
// when all events have been loaded from databases and dispatched
|
||||
// we can cancel the context and fire the EOSE message
|
||||
eose.Wait()
|
||||
cancelReqCtx(nil)
|
||||
ws.WriteJSON(nostr.EOSEEnvelope(env.SubscriptionID))
|
||||
}()
|
||||
|
||||
setListener(env.SubscriptionID, ws, env.Filters)
|
||||
setListener(env.SubscriptionID, ws, env.Filters, cancelReqCtx)
|
||||
case *nostr.CloseEnvelope:
|
||||
removeListenerId(ws, string(*env))
|
||||
case *nostr.AuthEnvelope:
|
||||
if rl.ServiceURL != "" {
|
||||
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, rl.ServiceURL); ok {
|
||||
ws.AuthedPublicKey = pubkey
|
||||
close(ws.Authed)
|
||||
ctx = context.WithValue(ctx, AUTH_CONTEXT_KEY, pubkey)
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true})
|
||||
} else {
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"})
|
||||
}
|
||||
wsBaseUrl := strings.Replace(rl.ServiceURL, "http", "ws", 1)
|
||||
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok {
|
||||
ws.AuthedPublicKey = pubkey
|
||||
close(ws.Authed)
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: true})
|
||||
} else {
|
||||
ws.WriteJSON(nostr.OKEnvelope{EventID: env.Event.ID, OK: false, Reason: "error: failed to authenticate"})
|
||||
}
|
||||
}
|
||||
}(message)
|
||||
}
|
||||
}()
|
||||
|
||||
// writer
|
||||
go func() {
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
conn.Close()
|
||||
}()
|
||||
defer kill()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
err := ws.WriteMessage(websocket.PingMessage, nil)
|
||||
if err != nil {
|
||||
|
||||
55
helpers.go
Normal file
55
helpers.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package khatru
|
||||
|
||||
import (
|
||||
"hash/maphash"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/nbd-wtf/go-nostr"
|
||||
)
|
||||
|
||||
const (
|
||||
AUTH_CONTEXT_KEY = iota
|
||||
WS_KEY
|
||||
)
|
||||
|
||||
var nip20prefixmatcher = regexp.MustCompile(`^\w+: `)
|
||||
|
||||
func pointerHasher[V any](_ maphash.Seed, k *V) uint64 {
|
||||
return uint64(uintptr(unsafe.Pointer(k)))
|
||||
}
|
||||
|
||||
func isOlder(previous, next *nostr.Event) bool {
|
||||
return previous.CreatedAt < next.CreatedAt ||
|
||||
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
|
||||
}
|
||||
|
||||
func isAuthRequired(msg string) bool {
|
||||
idx := strings.IndexByte(msg, ':')
|
||||
return msg[0:idx] == "auth-required"
|
||||
}
|
||||
|
||||
func getServiceBaseURL(r *http.Request) string {
|
||||
host := r.Header.Get("X-Forwarded-Host")
|
||||
if host == "" {
|
||||
host = r.Host
|
||||
}
|
||||
proto := r.Header.Get("X-Forwarded-Proto")
|
||||
if proto == "" {
|
||||
if host == "localhost" {
|
||||
proto = "http"
|
||||
} else if strings.Index(host, ":") != -1 {
|
||||
// has a port number
|
||||
proto = "http"
|
||||
} else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil {
|
||||
// it's a naked IP
|
||||
proto = "http"
|
||||
} else {
|
||||
proto = "https"
|
||||
}
|
||||
}
|
||||
return proto + "://" + host
|
||||
}
|
||||
18
listener.go
18
listener.go
@@ -1,12 +1,16 @@
|
||||
package khatru
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/nbd-wtf/go-nostr"
|
||||
"github.com/puzpuzpuz/xsync/v2"
|
||||
)
|
||||
|
||||
type Listener struct {
|
||||
filters nostr.Filters
|
||||
cancel context.CancelCauseFunc
|
||||
}
|
||||
|
||||
var listeners = xsync.NewTypedMapOf[*WebSocket, *xsync.MapOf[string, *Listener]](pointerHasher[WebSocket])
|
||||
@@ -43,24 +47,28 @@ func GetListeningFilters() nostr.Filters {
|
||||
return respfilters
|
||||
}
|
||||
|
||||
func setListener(id string, ws *WebSocket, filters nostr.Filters) {
|
||||
func setListener(id string, ws *WebSocket, filters nostr.Filters, cancel context.CancelCauseFunc) {
|
||||
subs, _ := listeners.LoadOrCompute(ws, func() *xsync.MapOf[string, *Listener] {
|
||||
return xsync.NewMapOf[*Listener]()
|
||||
})
|
||||
subs.Store(id, &Listener{filters: filters})
|
||||
subs.Store(id, &Listener{filters: filters, cancel: cancel})
|
||||
}
|
||||
|
||||
// Remove a specific subscription id from listeners for a given ws client
|
||||
// remove a specific subscription id from listeners for a given ws client
|
||||
// and cancel its specific context
|
||||
func removeListenerId(ws *WebSocket, id string) {
|
||||
if subs, ok := listeners.Load(ws); ok {
|
||||
subs.Delete(id)
|
||||
if listener, ok := subs.LoadAndDelete(id); ok {
|
||||
listener.cancel(fmt.Errorf("subscription closed by client"))
|
||||
}
|
||||
if subs.Size() == 0 {
|
||||
listeners.Delete(ws)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove WebSocket conn from listeners
|
||||
// remove WebSocket conn from listeners
|
||||
// (no need to cancel contexts as they are all inherited from the main connection context)
|
||||
func removeListener(ws *WebSocket) {
|
||||
listeners.Delete(ws)
|
||||
}
|
||||
|
||||
8
relay.go
8
relay.go
@@ -40,7 +40,7 @@ func NewRelay() *Relay {
|
||||
}
|
||||
|
||||
type Relay struct {
|
||||
ServiceURL string // required for nip-42
|
||||
ServiceURL string
|
||||
|
||||
RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string)
|
||||
RejectFilter []func(ctx context.Context, filter nostr.Filter) (reject bool, msg string)
|
||||
@@ -56,6 +56,7 @@ type Relay struct {
|
||||
CountEvents []func(ctx context.Context, filter nostr.Filter) (int64, error)
|
||||
OnAuth []func(ctx context.Context, pubkey string)
|
||||
OnConnect []func(ctx context.Context)
|
||||
OnDisconnect []func(ctx context.Context)
|
||||
OnEventSaved []func(ctx context.Context, event *nostr.Event)
|
||||
|
||||
// editing info will affect
|
||||
@@ -82,8 +83,3 @@ type Relay struct {
|
||||
PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait.
|
||||
MaxMessageSize int64 // Maximum message size allowed from peer.
|
||||
}
|
||||
|
||||
func (rl *Relay) RequestAuth(ctx context.Context) {
|
||||
ws := GetConnection(ctx)
|
||||
ws.WriteJSON(nostr.AuthEnvelope{Challenge: &ws.Challenge})
|
||||
}
|
||||
|
||||
40
utils.go
40
utils.go
@@ -2,37 +2,31 @@ package khatru
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hash/maphash"
|
||||
"regexp"
|
||||
"unsafe"
|
||||
|
||||
"github.com/nbd-wtf/go-nostr"
|
||||
"github.com/sebest/xff"
|
||||
)
|
||||
|
||||
const (
|
||||
AUTH_CONTEXT_KEY = iota
|
||||
WS_KEY = iota
|
||||
)
|
||||
|
||||
var nip20prefixmatcher = regexp.MustCompile(`^\w+: `)
|
||||
|
||||
func GetConnection(ctx context.Context) *WebSocket {
|
||||
return ctx.Value(WS_KEY).(*WebSocket)
|
||||
}
|
||||
|
||||
func GetAuthed(ctx context.Context) string {
|
||||
authedPubkey := ctx.Value(AUTH_CONTEXT_KEY)
|
||||
if authedPubkey == nil {
|
||||
return ""
|
||||
return GetConnection(ctx).AuthedPublicKey
|
||||
}
|
||||
|
||||
func GetIP(ctx context.Context) string {
|
||||
return xff.GetRemoteAddr(GetConnection(ctx).Request)
|
||||
}
|
||||
|
||||
func GetOpenSubscriptions(ctx context.Context) []nostr.Filter {
|
||||
if listeners, ok := listeners.Load(GetConnection(ctx)); ok {
|
||||
res := make([]nostr.Filter, 0, listeners.Size()*2)
|
||||
listeners.Range(func(_ string, listener *Listener) bool {
|
||||
res = append(res, listener.filters...)
|
||||
return true
|
||||
})
|
||||
return res
|
||||
}
|
||||
return authedPubkey.(string)
|
||||
}
|
||||
|
||||
func pointerHasher[V any](_ maphash.Seed, k *V) uint64 {
|
||||
return uint64(uintptr(unsafe.Pointer(k)))
|
||||
}
|
||||
|
||||
func isOlder(previous, next *nostr.Event) bool {
|
||||
return previous.CreatedAt < next.CreatedAt ||
|
||||
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user