make tests pass on base package.

This commit is contained in:
fiatjaf
2023-05-01 20:38:11 -03:00
parent a4512da371
commit e84f5df1f0
3 changed files with 42 additions and 74 deletions

View File

@@ -34,7 +34,6 @@ type Server struct {
// outputting to stderr. // outputting to stderr.
Log Logger Log Logger
addr string
relay Relay relay Relay
// keep a connection reference to all connected clients for Server.Shutdown // keep a connection reference to all connected clients for Server.Shutdown
@@ -42,6 +41,7 @@ type Server struct {
clients map[*websocket.Conn]struct{} clients map[*websocket.Conn]struct{}
// in case you call Server.Start // in case you call Server.Start
Addr string
httpServer *http.Server httpServer *http.Server
} }
@@ -83,13 +83,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Server) Start(host string, port int) error { func (s *Server) Start(host string, port int, started ...chan bool) error {
addr := net.JoinHostPort(host, strconv.Itoa(port)) addr := net.JoinHostPort(host, strconv.Itoa(port))
ln, err := net.Listen("tcp", addr) ln, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return err return err
} }
s.Addr = ln.Addr().String()
s.httpServer = &http.Server{ s.httpServer = &http.Server{
Handler: cors.Default().Handler(s), Handler: cors.Default().Handler(s),
Addr: addr, Addr: addr,
@@ -98,6 +99,11 @@ func (s *Server) Start(host string, port int) error {
IdleTimeout: 30 * time.Second, IdleTimeout: 30 * time.Second,
} }
// notify caller that we're starting
for _, started := range started {
close(started)
}
if err := s.httpServer.Serve(ln); err == http.ErrServerClosed { if err := s.httpServer.Serve(ln); err == http.ErrServerClosed {
return nil return nil
} else if err != nil { } else if err != nil {

View File

@@ -12,38 +12,28 @@ import (
func TestServerStartShutdown(t *testing.T) { func TestServerStartShutdown(t *testing.T) {
var ( var (
serverHost string
inited bool inited bool
storeInited bool storeInited bool
shutdown bool shutdown bool
) )
ready := make(chan struct{})
rl := &testRelay{ rl := &testRelay{
name: "test server start", name: "test server start",
init: func() error { init: func() error {
inited = true inited = true
return nil return nil
}, },
onInitialized: func(s *Server) {
serverHost = s.Addr()
close(ready)
},
onShutdown: func(context.Context) { shutdown = true }, onShutdown: func(context.Context) { shutdown = true },
storage: &testStorage{ storage: &testStorage{
init: func() error { storeInited = true; return nil }, init: func() error { storeInited = true; return nil },
}, },
} }
srv := NewServer("127.0.0.1:0", rl) srv, _ := NewServer(rl)
ready := make(chan bool)
done := make(chan error) done := make(chan error)
go func() { done <- srv.Start(); close(done) }() go func() { done <- srv.Start("127.0.0.1", 0, ready); close(done) }()
<-ready
// verify everything's initialized // verify everything's initialized
select {
case <-ready:
// continue
case <-time.After(time.Second):
t.Fatal("srv.Start too long to initialize")
}
if !inited { if !inited {
t.Error("didn't call testRelay.init") t.Error("didn't call testRelay.init")
} }
@@ -52,16 +42,14 @@ func TestServerStartShutdown(t *testing.T) {
} }
// check that http requests are served // check that http requests are served
if _, err := http.Get("http://" + serverHost); err != nil { if _, err := http.Get("http://" + srv.Addr); err != nil {
t.Errorf("GET %s: %v", serverHost, err) t.Errorf("GET %s: %v", srv.Addr, err)
} }
// verify server shuts down // verify server shuts down
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
if err := srv.Shutdown(ctx); err != nil { srv.Shutdown(ctx)
t.Errorf("srv.Shutdown: %v", err)
}
if !shutdown { if !shutdown {
t.Error("didn't call testRelay.onShutdown") t.Error("didn't call testRelay.onShutdown")
} }
@@ -82,7 +70,7 @@ func TestServerShutdownWebsocket(t *testing.T) {
// connect a client to it // connect a client to it
ctx1, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx1, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
client, err := nostr.RelayConnect(ctx1, "ws://"+srv.Addr()) client, err := nostr.RelayConnect(ctx1, "ws://"+srv.Addr)
if err != nil { if err != nil {
t.Fatalf("nostr.RelayConnectContext: %v", err) t.Fatalf("nostr.RelayConnectContext: %v", err)
} }
@@ -90,17 +78,12 @@ func TestServerShutdownWebsocket(t *testing.T) {
// now, shut down the server // now, shut down the server
ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel() defer cancel()
if err := srv.Shutdown(ctx2); err != nil { srv.Shutdown(ctx2)
t.Errorf("srv.Shutdown: %v", err)
}
// wait for the client to receive a "connection close" // wait for the client to receive a "connection close"
select { time.Sleep(1 * time.Second)
case err := <-client.ConnectionError: err = client.ConnectionError
if _, ok := err.(*websocket.CloseError); !ok { if _, ok := err.(*websocket.CloseError); !ok {
t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err) t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err)
} }
case <-time.After(2 * time.Second):
t.Error("client took too long to disconnect")
}
} }

View File

@@ -3,30 +3,16 @@ package relayer
import ( import (
"context" "context"
"testing" "testing"
"time"
"github.com/nbd-wtf/go-nostr" "github.com/nbd-wtf/go-nostr"
) )
func startTestRelay(t *testing.T, tr *testRelay) *Server { func startTestRelay(t *testing.T, tr *testRelay) *Server {
t.Helper() t.Helper()
ready := make(chan struct{}) srv, _ := NewServer(tr)
started := make(chan bool)
onInitializedFn := tr.onInitialized go srv.Start("127.0.0.1", 0, started)
tr.onInitialized = func(s *Server) { <-started
close(ready)
if onInitializedFn != nil {
onInitializedFn(s)
}
}
srv := NewServer("127.0.0.1:0", tr)
go srv.Start()
select {
case <-ready:
case <-time.After(time.Second):
t.Fatal("server took too long to start up")
}
return srv return srv
} }
@@ -34,13 +20,12 @@ type testRelay struct {
name string name string
storage Storage storage Storage
init func() error init func() error
onInitialized func(*Server)
onShutdown func(context.Context) onShutdown func(context.Context)
acceptEvent func(*nostr.Event) bool acceptEvent func(*nostr.Event) bool
} }
func (tr *testRelay) Name() string { return tr.name } func (tr *testRelay) Name() string { return tr.name }
func (tr *testRelay) Storage() Storage { return tr.storage } func (tr *testRelay) Storage(context.Context) Storage { return tr.storage }
func (tr *testRelay) Init() error { func (tr *testRelay) Init() error {
if fn := tr.init; fn != nil { if fn := tr.init; fn != nil {
@@ -49,19 +34,13 @@ func (tr *testRelay) Init() error {
return nil return nil
} }
func (tr *testRelay) OnInitialized(s *Server) {
if fn := tr.onInitialized; fn != nil {
fn(s)
}
}
func (tr *testRelay) OnShutdown(ctx context.Context) { func (tr *testRelay) OnShutdown(ctx context.Context) {
if fn := tr.onShutdown; fn != nil { if fn := tr.onShutdown; fn != nil {
fn(ctx) fn(ctx)
} }
} }
func (tr *testRelay) AcceptEvent(e *nostr.Event) bool { func (tr *testRelay) AcceptEvent(ctx context.Context, e *nostr.Event) bool {
if fn := tr.acceptEvent; fn != nil { if fn := tr.acceptEvent; fn != nil {
return fn(e) return fn(e)
} }
@@ -70,9 +49,9 @@ func (tr *testRelay) AcceptEvent(e *nostr.Event) bool {
type testStorage struct { type testStorage struct {
init func() error init func() error
queryEvents func(*nostr.Filter) ([]nostr.Event, error) queryEvents func(context.Context, *nostr.Filter) (chan *nostr.Event, error)
deleteEvent func(id string, pubkey string) error deleteEvent func(ctx context.Context, id string, pubkey string) error
saveEvent func(*nostr.Event) error saveEvent func(context.Context, *nostr.Event) error
} }
func (st *testStorage) Init() error { func (st *testStorage) Init() error {
@@ -82,23 +61,23 @@ func (st *testStorage) Init() error {
return nil return nil
} }
func (st *testStorage) QueryEvents(f *nostr.Filter) ([]nostr.Event, error) { func (st *testStorage) QueryEvents(ctx context.Context, f *nostr.Filter) (chan *nostr.Event, error) {
if fn := st.queryEvents; fn != nil { if fn := st.queryEvents; fn != nil {
return fn(f) return fn(ctx, f)
} }
return nil, nil return nil, nil
} }
func (st *testStorage) DeleteEvent(id string, pubkey string) error { func (st *testStorage) DeleteEvent(ctx context.Context, id string, pubkey string) error {
if fn := st.deleteEvent; fn != nil { if fn := st.deleteEvent; fn != nil {
return fn(id, pubkey) return fn(ctx, id, pubkey)
} }
return nil return nil
} }
func (st *testStorage) SaveEvent(e *nostr.Event) error { func (st *testStorage) SaveEvent(ctx context.Context, e *nostr.Event) error {
if fn := st.saveEvent; fn != nil { if fn := st.saveEvent; fn != nil {
return fn(e) return fn(ctx, e)
} }
return nil return nil
} }