mirror of
https://github.com/fiatjaf/khatru.git
synced 2025-03-18 05:42:19 +01:00
support NIP-45 (#58)
This commit is contained in:
parent
c4a678da1e
commit
639c210661
64
handlers.go
64
handlers.go
@ -179,6 +179,65 @@ func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
ok, message := AddEvent(ctx, s.relay, evt)
|
||||
ws.WriteJSON([]interface{}{"OK", evt.ID, ok, message})
|
||||
|
||||
case "COUNT":
|
||||
counter, ok := store.(EventCounter)
|
||||
if !ok {
|
||||
notice = "restricted: this relay does not support NIP-45"
|
||||
return
|
||||
}
|
||||
|
||||
var id string
|
||||
json.Unmarshal(request[1], &id)
|
||||
if id == "" {
|
||||
notice = "COUNT has no <id>"
|
||||
return
|
||||
}
|
||||
|
||||
total := int64(0)
|
||||
filters := make(nostr.Filters, len(request)-2)
|
||||
for i, filterReq := range request[2:] {
|
||||
if err := json.Unmarshal(filterReq, &filters[i]); err != nil {
|
||||
notice = "failed to decode filter"
|
||||
return
|
||||
}
|
||||
|
||||
filter := &filters[i]
|
||||
|
||||
// prevent kind-4 events from being returned to unauthed users,
|
||||
// only when authentication is a thing
|
||||
if _, ok := s.relay.(Auther); ok {
|
||||
if slices.Contains(filter.Kinds, 4) {
|
||||
senders := filter.Authors
|
||||
receivers, _ := filter.Tags["p"]
|
||||
switch {
|
||||
case ws.authed == "":
|
||||
// not authenticated
|
||||
notice = "restricted: this relay does not serve kind-4 to unauthenticated users, does your client implement NIP-42?"
|
||||
return
|
||||
case len(senders) == 1 && len(receivers) < 2 && (senders[0] == ws.authed):
|
||||
// allowed filter: ws.authed is sole sender (filter specifies one or all receivers)
|
||||
case len(receivers) == 1 && len(senders) < 2 && (receivers[0] == ws.authed):
|
||||
// allowed filter: ws.authed is sole receiver (filter specifies one or all senders)
|
||||
default:
|
||||
// restricted filter: do not return any events,
|
||||
// even if other elements in filters array were not restricted).
|
||||
// client should know better.
|
||||
notice = "restricted: authenticated user does not have authorization for requested filters."
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
count, err := counter.CountEvents(ctx, filter)
|
||||
if err != nil {
|
||||
s.Log.Errorf("store: %v", err)
|
||||
continue
|
||||
}
|
||||
total += count
|
||||
}
|
||||
|
||||
ws.WriteJSON([]interface{}{"COUNT", id, map[string]int64{"count": total}})
|
||||
setListener(id, ws, filters)
|
||||
case "REQ":
|
||||
var id string
|
||||
json.Unmarshal(request[1], &id)
|
||||
@ -312,6 +371,11 @@ func (s *Server) HandleNIP11(w http.ResponseWriter, r *http.Request) {
|
||||
if _, ok := s.relay.(Auther); ok {
|
||||
supportedNIPs = append(supportedNIPs, 42)
|
||||
}
|
||||
if storage, ok := s.relay.(Storage); ok && storage != nil {
|
||||
if _, ok = storage.(EventCounter); ok {
|
||||
supportedNIPs = append(supportedNIPs, 45)
|
||||
}
|
||||
}
|
||||
|
||||
info := nip11.RelayInformationDocument{
|
||||
Name: s.relay.Name(),
|
||||
|
@ -90,3 +90,7 @@ type AdvancedSaver interface {
|
||||
BeforeSave(context.Context, *nostr.Event)
|
||||
AfterSave(*nostr.Event)
|
||||
}
|
||||
|
||||
type EventCounter interface {
|
||||
CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error)
|
||||
}
|
||||
|
@ -2,11 +2,12 @@ package relayer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/gobwas/ws/wsutil"
|
||||
"github.com/nbd-wtf/go-nostr"
|
||||
)
|
||||
|
||||
@ -83,7 +84,10 @@ func TestServerShutdownWebsocket(t *testing.T) {
|
||||
// wait for the client to receive a "connection close"
|
||||
time.Sleep(1 * time.Second)
|
||||
err = client.ConnectionError
|
||||
if _, ok := err.(*websocket.CloseError); !ok {
|
||||
t.Errorf("client.ConnextionError: %v (%T); want websocket.CloseError", err, err)
|
||||
if e := errors.Unwrap(err); e != nil {
|
||||
err = e
|
||||
}
|
||||
if _, ok := err.(wsutil.ClosedError); !ok {
|
||||
t.Errorf("client.ConnextionError: %v (%T); want wsutil.ClosedError", err, err)
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) {
|
||||
ch = make(chan *nostr.Event)
|
||||
|
||||
query, params, err := queryEventsSql(filter)
|
||||
query, params, err := queryEventsSql(filter, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -44,7 +44,20 @@ func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
|
||||
func (b PostgresBackend) CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) {
|
||||
query, params, err := queryEventsSql(filter, true)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err = b.DB.QueryRow(query, params...).Scan(&count); err != nil && err != sql.ErrNoRows {
|
||||
return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func queryEventsSql(filter *nostr.Filter, doCount bool) (string, []any, error) {
|
||||
var conditions []string
|
||||
var params []any
|
||||
|
||||
@ -165,11 +178,20 @@ func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
|
||||
params = append(params, filter.Limit)
|
||||
}
|
||||
|
||||
query := sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
|
||||
id, pubkey, created_at, kind, tags, content, sig
|
||||
FROM event WHERE `+
|
||||
strings.Join(conditions, " AND ")+
|
||||
" ORDER BY created_at DESC LIMIT ?")
|
||||
var query string
|
||||
if doCount {
|
||||
query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
|
||||
COUNT(*)
|
||||
FROM event WHERE `+
|
||||
strings.Join(conditions, " AND ")+
|
||||
" ORDER BY created_at DESC LIMIT ?")
|
||||
} else {
|
||||
query = sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
|
||||
id, pubkey, created_at, kind, tags, content, sig
|
||||
FROM event WHERE `+
|
||||
strings.Join(conditions, " AND ")+
|
||||
" ORDER BY created_at DESC LIMIT ?")
|
||||
}
|
||||
|
||||
return query, params, nil
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ func TestQueryEventsSql(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query, params, err := queryEventsSql(tt.filter)
|
||||
query, params, err := queryEventsSql(tt.filter, false)
|
||||
assert.Equal(t, tt.err, err)
|
||||
if err != nil {
|
||||
return
|
||||
@ -188,3 +188,162 @@ func strSlice(n int) []string {
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
func TestCountEventsSql(t *testing.T) {
|
||||
var tests = []struct {
|
||||
name string
|
||||
filter *nostr.Filter
|
||||
query string
|
||||
params []any
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "empty filter",
|
||||
filter: &nostr.Filter{},
|
||||
query: "SELECT COUNT(*) FROM event WHERE true ORDER BY created_at DESC LIMIT $1",
|
||||
params: []any{100},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "ids filter",
|
||||
filter: &nostr.Filter{
|
||||
IDs: []string{"083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294"},
|
||||
},
|
||||
query: `SELECT COUNT(*)
|
||||
FROM event
|
||||
WHERE (id LIKE '083ec57f36a7b39ab98a57bedab4f85355b2ee89e4b205bed58d7c3ef9edd294%')
|
||||
ORDER BY created_at DESC LIMIT $1`,
|
||||
params: []any{100},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "kind filter",
|
||||
filter: &nostr.Filter{
|
||||
Kinds: []int{1, 2, 3},
|
||||
},
|
||||
query: `SELECT COUNT(*)
|
||||
FROM event
|
||||
WHERE kind IN(1,2,3)
|
||||
ORDER BY created_at DESC LIMIT $1`,
|
||||
params: []any{100},
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "authors filter",
|
||||
filter: &nostr.Filter{
|
||||
Authors: []string{"7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229"},
|
||||
},
|
||||
query: `SELECT COUNT(*)
|
||||
FROM event
|
||||
WHERE (pubkey LIKE '7bdef7bdebb8721f77927d0e77c66059360fa62371fdf15f3add93923a613229%')
|
||||
ORDER BY created_at DESC LIMIT $1`,
|
||||
params: []any{100},
|
||||
err: nil,
|
||||
},
|
||||
// errors
|
||||
{
|
||||
name: "nil filter",
|
||||
filter: nil,
|
||||
query: "",
|
||||
params: nil,
|
||||
err: fmt.Errorf("filter cannot be null"),
|
||||
},
|
||||
{
|
||||
name: "too many ids",
|
||||
filter: &nostr.Filter{
|
||||
IDs: strSlice(501),
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid ids",
|
||||
filter: &nostr.Filter{
|
||||
IDs: []string{"stuff"},
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "too many authors",
|
||||
filter: &nostr.Filter{
|
||||
Authors: strSlice(501),
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid authors",
|
||||
filter: &nostr.Filter{
|
||||
Authors: []string{"stuff"},
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "too many kinds",
|
||||
filter: &nostr.Filter{
|
||||
Kinds: intSlice(11),
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "no kinds",
|
||||
filter: &nostr.Filter{
|
||||
Kinds: []int{},
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "tags of empty array",
|
||||
filter: &nostr.Filter{
|
||||
Tags: nostr.TagMap{
|
||||
"#e": []string{},
|
||||
},
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
{
|
||||
name: "too many tag values",
|
||||
filter: &nostr.Filter{
|
||||
Tags: nostr.TagMap{
|
||||
"#e": strSlice(11),
|
||||
},
|
||||
},
|
||||
query: "",
|
||||
params: nil,
|
||||
// REVIEW: should return error
|
||||
err: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query, params, err := queryEventsSql(tt.filter, true)
|
||||
assert.Equal(t, tt.err, err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, clean(tt.query), clean(query))
|
||||
assert.Equal(t, tt.params, params)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -4,29 +4,72 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/nbd-wtf/go-nostr"
|
||||
)
|
||||
|
||||
func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) {
|
||||
ch = make(chan *nostr.Event)
|
||||
|
||||
query, params, err := queryEventsSql(filter, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := b.DB.Query(query, params...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer rows.Close()
|
||||
defer close(ch)
|
||||
for rows.Next() {
|
||||
var evt nostr.Event
|
||||
var timestamp int64
|
||||
err := rows.Scan(&evt.ID, &evt.PubKey, ×tamp,
|
||||
&evt.Kind, &evt.Tags, &evt.Content, &evt.Sig)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
evt.CreatedAt = nostr.Timestamp(timestamp)
|
||||
ch <- &evt
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (b SQLite3Backend) CountEvents(ctx context.Context, filter *nostr.Filter) (int64, error) {
|
||||
query, params, err := queryEventsSql(filter, true)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var count int64
|
||||
err = b.DB.QueryRow(query, params...).Scan(&count)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func queryEventsSql(filter *nostr.Filter, doCount bool) (string, []any, error) {
|
||||
var conditions []string
|
||||
var params []any
|
||||
|
||||
if filter == nil {
|
||||
err = errors.New("filter cannot be null")
|
||||
return
|
||||
return "", nil, fmt.Errorf("filter cannot be null")
|
||||
}
|
||||
|
||||
if filter.IDs != nil {
|
||||
if len(filter.IDs) > 500 {
|
||||
// too many ids, fail everything
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
likeids := make([]string, 0, len(filter.IDs))
|
||||
@ -41,7 +84,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
|
||||
}
|
||||
if len(likeids) == 0 {
|
||||
// ids being [] mean you won't get anything
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
conditions = append(conditions, "("+strings.Join(likeids, " OR ")+")")
|
||||
}
|
||||
@ -49,7 +92,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
|
||||
if filter.Authors != nil {
|
||||
if len(filter.Authors) > 500 {
|
||||
// too many authors, fail everything
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
likekeys := make([]string, 0, len(filter.Authors))
|
||||
@ -64,7 +107,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
|
||||
}
|
||||
if len(likekeys) == 0 {
|
||||
// authors being [] mean you won't get anything
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
conditions = append(conditions, "("+strings.Join(likekeys, " OR ")+")")
|
||||
}
|
||||
@ -72,12 +115,12 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
|
||||
if filter.Kinds != nil {
|
||||
if len(filter.Kinds) > 10 {
|
||||
// too many kinds, fail everything
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
if len(filter.Kinds) == 0 {
|
||||
// kinds being [] mean you won't get anything
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
// no sql injection issues since these are ints
|
||||
inkinds := make([]string, len(filter.Kinds))
|
||||
@ -91,7 +134,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
|
||||
for _, values := range filter.Tags {
|
||||
if len(values) == 0 {
|
||||
// any tag set to [] is wrong
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// add these tags to the query
|
||||
@ -99,7 +142,7 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
|
||||
|
||||
if len(tagQuery) > 10 {
|
||||
// too many tags, fail everything
|
||||
return
|
||||
return "", nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,32 +177,20 @@ func (b SQLite3Backend) QueryEvents(ctx context.Context, filter *nostr.Filter) (
|
||||
params = append(params, filter.Limit)
|
||||
}
|
||||
|
||||
query := b.DB.Rebind(`SELECT
|
||||
id, pubkey, created_at, kind, tags, content, sig
|
||||
FROM event WHERE ` +
|
||||
strings.Join(conditions, " AND ") +
|
||||
" ORDER BY created_at DESC LIMIT ?")
|
||||
|
||||
rows, err := b.DB.Query(query, params...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
|
||||
var query string
|
||||
if doCount {
|
||||
query = sqlx.Rebind(sqlx.BindType("sqlite3"), `SELECT
|
||||
COUNT(*)
|
||||
FROM event WHERE `+
|
||||
strings.Join(conditions, " AND ")+
|
||||
" ORDER BY created_at DESC LIMIT ?")
|
||||
} else {
|
||||
query = sqlx.Rebind(sqlx.BindType("sqlite3"), `SELECT
|
||||
id, pubkey, created_at, kind, tags, content, sig
|
||||
FROM event WHERE `+
|
||||
strings.Join(conditions, " AND ")+
|
||||
" ORDER BY created_at DESC LIMIT ?")
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer rows.Close()
|
||||
defer close(ch)
|
||||
for rows.Next() {
|
||||
var evt nostr.Event
|
||||
var timestamp int64
|
||||
err := rows.Scan(&evt.ID, &evt.PubKey, ×tamp,
|
||||
&evt.Kind, &evt.Tags, &evt.Content, &evt.Sig)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
evt.CreatedAt = nostr.Timestamp(timestamp)
|
||||
ch <- &evt
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
return query, params, nil
|
||||
}
|
||||
|
@ -52,6 +52,7 @@ type testStorage struct {
|
||||
queryEvents func(context.Context, *nostr.Filter) (chan *nostr.Event, error)
|
||||
deleteEvent func(ctx context.Context, id string, pubkey string) error
|
||||
saveEvent func(context.Context, *nostr.Event) error
|
||||
countEvents func(context.Context, *nostr.Filter) (int64, error)
|
||||
}
|
||||
|
||||
func (st *testStorage) Init() error {
|
||||
@ -81,3 +82,10 @@ func (st *testStorage) SaveEvent(ctx context.Context, e *nostr.Event) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *testStorage) CountEvents(ctx context.Context, f *nostr.Filter) (int64, error) {
|
||||
if fn := st.countEvents; fn != nil {
|
||||
return fn(ctx, f)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user