fix @staab's mess.

This commit is contained in:
fiatjaf 2024-12-31 22:15:06 -03:00
parent 5b9b89543f
commit e1de0432fe
4 changed files with 33 additions and 29 deletions

View File

@ -6,7 +6,6 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"net/http" "net/http"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -316,10 +315,7 @@ func (rl *Relay) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
id := string(*env) id := string(*env)
rl.removeListenerId(ws, id) rl.removeListenerId(ws, id)
case *nostr.AuthEnvelope: case *nostr.AuthEnvelope:
wsBaseUrl := os.Getenv("RELAY_URL") wsBaseUrl := strings.Replace(rl.getBaseURL(r), "http", "ws", 1)
if wsBaseUrl == "" {
wsBaseUrl = strings.Replace(getServiceBaseURL(r), "http", "ws", 1)
}
if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok { if pubkey, ok := nip42.ValidateAuthEvent(&env.Event, ws.Challenge, wsBaseUrl); ok {
ws.AuthedPublicKey = pubkey ws.AuthedPublicKey = pubkey
ws.authLock.Lock() ws.authLock.Lock()

View File

@ -3,7 +3,6 @@ package khatru
import ( import (
"net" "net"
"net/http" "net/http"
"strconv"
"strings" "strings"
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
@ -14,28 +13,6 @@ func isOlder(previous, next *nostr.Event) bool {
(previous.CreatedAt == next.CreatedAt && previous.ID > next.ID) (previous.CreatedAt == next.CreatedAt && previous.ID > next.ID)
} }
func getServiceBaseURL(r *http.Request) string {
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto == "" {
if host == "localhost" {
proto = "http"
} else if strings.Index(host, ":") != -1 {
// has a port number
proto = "http"
} else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil {
// it's a naked IP
proto = "http"
} else {
proto = "https"
}
}
return proto + "://" + host
}
var privateMasks = func() []net.IPNet { var privateMasks = func() []net.IPNet {
privateCIDRs := []string{ privateCIDRs := []string{
"127.0.0.0/8", "127.0.0.0/8",

View File

@ -80,7 +80,7 @@ func (rl *Relay) HandleNIP86(w http.ResponseWriter, r *http.Request) {
goto respond goto respond
} }
if uTag := evt.Tags.GetFirst([]string{"u", ""}); uTag == nil || getServiceBaseURL(r) != (*uTag)[1] { if uTag := evt.Tags.GetFirst([]string{"u", ""}); uTag == nil || rl.getBaseURL(r) != (*uTag)[1] {
resp.Error = "invalid 'u' tag" resp.Error = "invalid 'u' tag"
goto respond goto respond
} else if pht := evt.Tags.GetFirst([]string{"payload", hex.EncodeToString(payloadHash[:])}); pht == nil { } else if pht := evt.Tags.GetFirst([]string{"payload", hex.EncodeToString(payloadHash[:])}); pht == nil {

View File

@ -5,6 +5,8 @@ import (
"log" "log"
"net/http" "net/http"
"os" "os"
"strconv"
"strings"
"sync" "sync"
"time" "time"
@ -45,6 +47,9 @@ func NewRelay() *Relay {
} }
type Relay struct { type Relay struct {
// setting this variable overwrites the hackish workaround we do to try to figure out our own base URL
ServiceURL string
// hooks that will be called at various times // hooks that will be called at various times
RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string) RejectEvent []func(ctx context.Context, event *nostr.Event) (reject bool, msg string)
OverwriteDeletionOutcome []func(ctx context.Context, target *nostr.Event, deletion *nostr.Event) (acceptDeletion bool, msg string) OverwriteDeletionOutcome []func(ctx context.Context, target *nostr.Event, deletion *nostr.Event) (acceptDeletion bool, msg string)
@ -104,3 +109,29 @@ type Relay struct {
PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait. PingPeriod time.Duration // Send pings to peer with this period. Must be less than pongWait.
MaxMessageSize int64 // Maximum message size allowed from peer. MaxMessageSize int64 // Maximum message size allowed from peer.
} }
func (rl *Relay) getBaseURL(r *http.Request) string {
if rl.ServiceURL != "" {
return rl.ServiceURL
}
host := r.Header.Get("X-Forwarded-Host")
if host == "" {
host = r.Host
}
proto := r.Header.Get("X-Forwarded-Proto")
if proto == "" {
if host == "localhost" {
proto = "http"
} else if strings.Index(host, ":") != -1 {
// has a port number
proto = "http"
} else if _, err := strconv.Atoi(strings.ReplaceAll(host, ".", "")); err == nil {
// it's a naked IP
proto = "http"
} else {
proto = "https"
}
}
return proto + "://" + host
}