diff --git a/start.go b/start.go index dfd4224..0895bbb 100644 --- a/start.go +++ b/start.go @@ -34,7 +34,6 @@ type Server struct { // outputting to stderr. Log Logger - addr string relay Relay // keep a connection reference to all connected clients for Server.Shutdown @@ -42,6 +41,7 @@ type Server struct { clients map[*websocket.Conn]struct{} // in case you call Server.Start + Addr string httpServer *http.Server } @@ -83,13 +83,14 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) Start(host string, port int) error { +func (s *Server) Start(host string, port int, started ...chan bool) error { addr := net.JoinHostPort(host, strconv.Itoa(port)) ln, err := net.Listen("tcp", addr) if err != nil { return err } + s.Addr = ln.Addr().String() s.httpServer = &http.Server{ Handler: cors.Default().Handler(s), Addr: addr, @@ -98,6 +99,11 @@ func (s *Server) Start(host string, port int) error { IdleTimeout: 30 * time.Second, } + // notify caller that we're starting + for _, started := range started { + close(started) + } + if err := s.httpServer.Serve(ln); err == http.ErrServerClosed { return nil } else if err != nil { diff --git a/start_test.go b/start_test.go index 137ed3c..3a23f18 100644 --- a/start_test.go +++ b/start_test.go @@ -12,38 +12,28 @@ import ( func TestServerStartShutdown(t *testing.T) { var ( - serverHost string inited bool storeInited bool shutdown bool ) - ready := make(chan struct{}) rl := &testRelay{ name: "test server start", init: func() error { inited = true return nil }, - onInitialized: func(s *Server) { - serverHost = s.Addr() - close(ready) - }, onShutdown: func(context.Context) { shutdown = true }, storage: &testStorage{ init: func() error { storeInited = true; return nil }, }, } - srv := NewServer("127.0.0.1:0", rl) + srv, _ := NewServer(rl) + ready := make(chan bool) done := make(chan error) - go func() { done <- srv.Start(); close(done) }() + go func() { done <- srv.Start("127.0.0.1", 0, ready); close(done) }() + <-ready // verify everything's initialized - select { - case <-ready: - // continue - case <-time.After(time.Second): - t.Fatal("srv.Start too long to initialize") - } if !inited { t.Error("didn't call testRelay.init") } @@ -52,16 +42,14 @@ func TestServerStartShutdown(t *testing.T) { } // check that http requests are served - if _, err := http.Get("http://" + serverHost); err != nil { - t.Errorf("GET %s: %v", serverHost, err) + if _, err := http.Get("http://" + srv.Addr); err != nil { + t.Errorf("GET %s: %v", srv.Addr, err) } // verify server shuts down ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() - if err := srv.Shutdown(ctx); err != nil { - t.Errorf("srv.Shutdown: %v", err) - } + srv.Shutdown(ctx) if !shutdown { t.Error("didn't call testRelay.onShutdown") } @@ -82,7 +70,7 @@ func TestServerShutdownWebsocket(t *testing.T) { // connect a client to it ctx1, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - client, err := nostr.RelayConnect(ctx1, "ws://"+srv.Addr()) + client, err := nostr.RelayConnect(ctx1, "ws://"+srv.Addr) if err != nil { t.Fatalf("nostr.RelayConnectContext: %v", err) } @@ -90,17 +78,12 @@ func TestServerShutdownWebsocket(t *testing.T) { // now, shut down the server ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() - if err := srv.Shutdown(ctx2); err != nil { - t.Errorf("srv.Shutdown: %v", err) - } + srv.Shutdown(ctx2) // wait for the client to receive a "connection close" - select { - case err := <-client.ConnectionError: - if _, ok := err.(*websocket.CloseError); !ok { - t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err) - } - case <-time.After(2 * time.Second): - t.Error("client took too long to disconnect") + time.Sleep(1 * time.Second) + err = client.ConnectionError + if _, ok := err.(*websocket.CloseError); !ok { + t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err) } } diff --git a/util_test.go b/util_test.go index f1fff91..ec0c385 100644 --- a/util_test.go +++ b/util_test.go @@ -3,44 +3,29 @@ package relayer import ( "context" "testing" - "time" "github.com/nbd-wtf/go-nostr" ) func startTestRelay(t *testing.T, tr *testRelay) *Server { t.Helper() - ready := make(chan struct{}) - - onInitializedFn := tr.onInitialized - tr.onInitialized = func(s *Server) { - close(ready) - if onInitializedFn != nil { - onInitializedFn(s) - } - } - srv := NewServer("127.0.0.1:0", tr) - go srv.Start() - - select { - case <-ready: - case <-time.After(time.Second): - t.Fatal("server took too long to start up") - } + srv, _ := NewServer(tr) + started := make(chan bool) + go srv.Start("127.0.0.1", 0, started) + <-started return srv } type testRelay struct { - name string - storage Storage - init func() error - onInitialized func(*Server) - onShutdown func(context.Context) - acceptEvent func(*nostr.Event) bool + name string + storage Storage + init func() error + onShutdown func(context.Context) + acceptEvent func(*nostr.Event) bool } -func (tr *testRelay) Name() string { return tr.name } -func (tr *testRelay) Storage() Storage { return tr.storage } +func (tr *testRelay) Name() string { return tr.name } +func (tr *testRelay) Storage(context.Context) Storage { return tr.storage } func (tr *testRelay) Init() error { if fn := tr.init; fn != nil { @@ -49,19 +34,13 @@ func (tr *testRelay) Init() error { return nil } -func (tr *testRelay) OnInitialized(s *Server) { - if fn := tr.onInitialized; fn != nil { - fn(s) - } -} - func (tr *testRelay) OnShutdown(ctx context.Context) { if fn := tr.onShutdown; fn != nil { fn(ctx) } } -func (tr *testRelay) AcceptEvent(e *nostr.Event) bool { +func (tr *testRelay) AcceptEvent(ctx context.Context, e *nostr.Event) bool { if fn := tr.acceptEvent; fn != nil { return fn(e) } @@ -70,9 +49,9 @@ func (tr *testRelay) AcceptEvent(e *nostr.Event) bool { type testStorage struct { init func() error - queryEvents func(*nostr.Filter) ([]nostr.Event, error) - deleteEvent func(id string, pubkey string) error - saveEvent func(*nostr.Event) error + queryEvents func(context.Context, *nostr.Filter) (chan *nostr.Event, error) + deleteEvent func(ctx context.Context, id string, pubkey string) error + saveEvent func(context.Context, *nostr.Event) error } func (st *testStorage) Init() error { @@ -82,23 +61,23 @@ func (st *testStorage) Init() error { return nil } -func (st *testStorage) QueryEvents(f *nostr.Filter) ([]nostr.Event, error) { +func (st *testStorage) QueryEvents(ctx context.Context, f *nostr.Filter) (chan *nostr.Event, error) { if fn := st.queryEvents; fn != nil { - return fn(f) + return fn(ctx, f) } return nil, nil } -func (st *testStorage) DeleteEvent(id string, pubkey string) error { +func (st *testStorage) DeleteEvent(ctx context.Context, id string, pubkey string) error { if fn := st.deleteEvent; fn != nil { - return fn(id, pubkey) + return fn(ctx, id, pubkey) } return nil } -func (st *testStorage) SaveEvent(e *nostr.Event) error { +func (st *testStorage) SaveEvent(ctx context.Context, e *nostr.Event) error { if fn := st.saveEvent; fn != nil { - return fn(e) + return fn(ctx, e) } return nil }