support NIP-45 (#58)

This commit is contained in:
mattn 2023-05-17 19:54:56 +09:00 committed by GitHub
parent c4a678da1e
commit 639c210661
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 340 additions and 48 deletions

View File

@ -179,6 +179,65 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
ok, message := AddEvent(ctx, s.relay, evt)
ws.WriteJSON([]interface{}{"OK", evt.ID, ok, message})
case "COUNT":
counter, ok := store.(EventCounter)
if !ok {
notice = "restricted: this relay does not support NIP-45"
return
}
var id string
json.Unmarshal(request[1], &id)
if id == "" {
notice = "COUNT has no <id>"
return
}
total := int64(0)
filters := make(nostr.Filters, len(request)-2)
for i, filterReq := range request[2:] {
if err := json.Unmarshal(filterReq, &filters[i]); err != nil {
notice = "failed to decode filter"
return
}
filter := &filters[i]
// prevent kind-4 events from being returned to unauthed users,
// only when authentication is a thing
if _, ok := s.relay.(Auther); ok {
if slices.Contains(filter.Kinds, 4) {
senders := filter.Authors
receivers, _ := filter.Tags["p"]
switch {
case ws.authed == "":
// not authenticated
notice = "restricted: this relay does not serve kind-4 to unauthenticated users, does your client implement NIP-42?"
return
case len(senders) == 1 && len(receivers) < 2 && (senders[0] == ws.authed):
// allowed filter: ws.authed is sole sender (filter specifies one or all receivers)
case len(receivers) == 1 && len(senders) < 2 && (receivers[0] == ws.authed):
// allowed filter: ws.authed is sole receiver (filter specifies one or all senders)
default:
// restricted filter: do not return any events,
// even if other elements in filters array were not restricted).
// client should know better.
notice = "restricted: authenticated user does not have authorization for requested filters."
return
}
}
}
count, err := counter.CountEvents(ctx, filter)
if err != nil {
s.Log.Errorf("store: %v", err)
continue
}
total += count
}
ws.WriteJSON([]interface{}{"COUNT", id, map[string]int64{"count": total}})
setListener(id, ws, filters)
case "REQ":
var id string
json.Unmarshal(request[1], &id)
@ -312,6 +371,11 @@ func (s *Server) HandleNIP11(w http.ResponseWriter, r *http.Request) {
if _, ok := s.relay.(Auther); ok {
supportedNIPs = append(supportedNIPs, 42)
}
if storage, ok := s.relay.(Storage); ok && storage != nil {
if _, ok = storage.(EventCounter); ok {
supportedNIPs = append(supportedNIPs, 45)
}
}
info := nip11.RelayInformationDocument{
Name: s.relay.Name(),

View File

@ -90,3 +90,7 @@ type AdvancedSaver interface {
BeforeSave(context.Context, *nostr.Event)
AfterSave(*nostr.Event)
}
type EventCounter interface {
CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error)
}

View File

@ -2,11 +2,12 @@ package relayer
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/gobwas/ws/wsutil"
"github.com/nbd-wtf/go-nostr"
)
@ -83,7 +84,10 @@ func TestServerShutdownWebsocket(t *testing.T) {
// wait for the client to receive a "connection close"
time.Sleep(1 * time.Second)
err = client.ConnectionError
if _, ok := err.(*websocket.CloseError); !ok {
t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err)
if e := errors.Unwrap(err); e != nil {
err = e
}
if _, ok := err.(wsutil.ClosedError); !ok {
t.Errorf("client.ConnextionError: %v (%T); want wsutil.ClosedError", err, err)
}
}

View File

@ -15,7 +15,7 @@ import (
func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) {
ch = make(chan *nostr.Event)
query, params, err := queryEventsSql(filter)
query, params, err := queryEventsSql(filter, false)
if err != nil {
return nil, err
}
@ -44,7 +44,20 @@ func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter)
return ch, nil
}
func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
func (b PostgresBackend) CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) {
query, params, err := queryEventsSql(filter, true)
if err != nil {
return 0, err
}
var count int64
if err = b.DB.QueryRow(query, params...).Scan(&count); err != nil && err != sql.ErrNoRows {
return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
}
return count, nil
}
func queryEventsSql(filter *nostr.Filter, doCount bool) (string, []any, error) {
var conditions []string
var params []any
@ -165,11 +178,20 @@ func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
params = append(params, filter.Limit)
}
query := sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
id, pubkey, created_at, kind, tags, content, sig
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
var query string
if doCount {
query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
COUNT(*)
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
} else {
query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
id, pubkey, created_at, kind, tags, content, sig
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
}
return query, params, nil
}

View File

@ -157,7 +157,7 @@ func TestQueryEventsSql(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query, params, err := queryEventsSql(tt.filter)
query, params, err := queryEventsSql(tt.filter, false)
assert.Equal(t, tt.err, err)
if err != nil {
return
@ -188,3 +188,162 @@ func strSlice(n int) []string {
}
return slice
}
func TestCountEventsSql(t *testing.T) {
var tests = []struct {
name string
filter *nostr.Filter
query string
params []any
err error
}{
{
name: "empty filter",
filter: &nostr.Filter{},
query: "SELECT COUNT(*) FROM event WHERE true ORDER BY created_at DESC LIMIT $1",
params: []any{100},
err: nil,
},
{
name: "ids filter",
filter: &nostr.Filter{
IDs: []string{"083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294"},
},
query: `SELECT COUNT(*)
FROM event
WHERE (id LIKE '083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294%')
ORDER BY created_at DESC LIMIT $1`,
params: []any{100},
err: nil,
},
{
name: "kind filter",
filter: &nostr.Filter{
Kinds: []int{1, 2, 3},
},
query: `SELECT COUNT(*)
FROM event
WHERE kind IN(1,2,3)
ORDER BY created_at DESC LIMIT $1`,
params: []any{100},
err: nil,
},
{
name: "authors filter",
filter: &nostr.Filter{
Authors: []string{"7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229"},
},
query: `SELECT COUNT(*)
FROM event
WHERE (pubkey LIKE '7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229%')
ORDER BY created_at DESC LIMIT $1`,
params: []any{100},
err: nil,
},
// errors
{
name: "nil filter",
filter: nil,
query: "",
params: nil,
err: fmt.Errorf("filter cannot be null"),
},
{
name: "too many ids",
filter: &nostr.Filter{
IDs: strSlice(501),
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "invalid ids",
filter: &nostr.Filter{
IDs: []string{"stuff"},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "too many authors",
filter: &nostr.Filter{
Authors: strSlice(501),
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "invalid authors",
filter: &nostr.Filter{
Authors: []string{"stuff"},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "too many kinds",
filter: &nostr.Filter{
Kinds: intSlice(11),
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "no kinds",
filter: &nostr.Filter{
Kinds: []int{},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "tags of empty array",
filter: &nostr.Filter{
Tags: nostr.TagMap{
"#e": []string{},
},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
{
name: "too many tag values",
filter: &nostr.Filter{
Tags: nostr.TagMap{
"#e": strSlice(11),
},
},
query: "",
params: nil,
// REVIEW: should return error
err: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query, params, err := queryEventsSql(tt.filter, true)
assert.Equal(t, tt.err, err)
if err != nil {
return
}
assert.Equal(t, clean(tt.query), clean(query))
assert.Equal(t, tt.params, params)
})
}
}

View File

@ -4,29 +4,72 @@ import (
"context"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
"github.com/nbd-wtf/go-nostr"
)
func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) {
ch = make(chan *nostr.Event)
query, params, err := queryEventsSql(filter, false)
if err != nil {
return nil, err
}
rows, err := b.DB.Query(query, params...)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
}
go func() {
defer rows.Close()
defer close(ch)
for rows.Next() {
var evt nostr.Event
var timestamp int64
err := rows.Scan(&evt.ID, &evt.PubKey, &timestamp,
&evt.Kind, &evt.Tags, &evt.Content, &evt.Sig)
if err != nil {
return
}
evt.CreatedAt = nostr.Timestamp(timestamp)
ch <- &evt
}
}()
return ch, nil
}
func (b SQLite3Backend) CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) {
query, params, err := queryEventsSql(filter, true)
if err != nil {
return 0, err
}
var count int64
err = b.DB.QueryRow(query, params...).Scan(&count)
if err != nil && err != sql.ErrNoRows {
return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
}
return count, nil
}
func queryEventsSql(filter *nostr.Filter, doCount bool) (string, []any, error) {
var conditions []string
var params []any
if filter == nil {
err = errors.New("filter cannot be null")
return
return "", nil, fmt.Errorf("filter cannot be null")
}
if filter.IDs != nil {
if len(filter.IDs) > 500 {
// too many ids, fail everything
return
return "", nil, nil
}
likeids := make([]string, 0, len(filter.IDs))
@ -41,7 +84,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
}
if len(likeids) == 0 {
// ids being [] mean you won't get anything
return
return "", nil, nil
}
conditions = append(conditions, "("+strings.Join(likeids, " OR ")+")")
}
@ -49,7 +92,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
if filter.Authors != nil {
if len(filter.Authors) > 500 {
// too many authors, fail everything
return
return "", nil, nil
}
likekeys := make([]string, 0, len(filter.Authors))
@ -64,7 +107,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
}
if len(likekeys) == 0 {
// authors being [] mean you won't get anything
return
return "", nil, nil
}
conditions = append(conditions, "("+strings.Join(likekeys, " OR ")+")")
}
@ -72,12 +115,12 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
if filter.Kinds != nil {
if len(filter.Kinds) > 10 {
// too many kinds, fail everything
return
return "", nil, nil
}
if len(filter.Kinds) == 0 {
// kinds being [] mean you won't get anything
return
return "", nil, nil
}
// no sql injection issues since these are ints
inkinds := make([]string, len(filter.Kinds))
@ -91,7 +134,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
for _, values := range filter.Tags {
if len(values) == 0 {
// any tag set to [] is wrong
return
return "", nil, nil
}
// add these tags to the query
@ -99,7 +142,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
if len(tagQuery) > 10 {
// too many tags, fail everything
return
return "", nil, nil
}
}
@ -134,32 +177,20 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
params = append(params, filter.Limit)
}
query := b.DB.Rebind(`SELECT
id, pubkey, created_at, kind, tags, content, sig
FROM event WHERE ` +
strings.Join(conditions, " AND ") +
" ORDER BY created_at DESC LIMIT ?")
rows, err := b.DB.Query(query, params...)
if err != nil && err != sql.ErrNoRows {
return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
var query string
if doCount {
query = sqlx.Rebind(sqlx.BindType("sqlite3"), `SELECT
COUNT(*)
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
} else {
query = sqlx.Rebind(sqlx.BindType("sqlite3"), `SELECT
id, pubkey, created_at, kind, tags, content, sig
FROM event WHERE `+
strings.Join(conditions, " AND ")+
" ORDER BY created_at DESC LIMIT ?")
}
go func() {
defer rows.Close()
defer close(ch)
for rows.Next() {
var evt nostr.Event
var timestamp int64
err := rows.Scan(&evt.ID, &evt.PubKey, &timestamp,
&evt.Kind, &evt.Tags, &evt.Content, &evt.Sig)
if err != nil {
return
}
evt.CreatedAt = nostr.Timestamp(timestamp)
ch <- &evt
}
}()
return ch, nil
return query, params, nil
}

View File

@ -52,6 +52,7 @@ type testStorage struct {
queryEvents func(context.Context, *nostr.Filter) (chan *nostr.Event, error)
deleteEvent func(ctx context.Context, id string, pubkey string) error
saveEvent func(context.Context, *nostr.Event) error
countEvents func(context.Context, *nostr.Filter) (int64, error)
}
func (st *testStorage) Init() error {
@ -81,3 +82,10 @@ func (st *testStorage) SaveEvent(ctx context.Context, e *nostr.Event) error {
}
return nil
}
func (st *testStorage) CountEvents(ctx context.Context, f *nostr.Filter) (int64, error) {
if fn := st.countEvents; fn != nil {
return fn(ctx, f)
}
return 0, nil
}