go-nostr/nip96/nip96.go

138 lines
3.4 KiB
Go

package nip96
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"io"
"mime/multipart"
"net/http"
"strconv"
"github.com/cloudwego/base64x"
jsoniter "github.com/json-iterator/go"
"github.com/nbd-wtf/go-nostr"
)
// Upload uploads a file to the provided req.Host.
//
// The request is first checked with req.Validate. The file (read from
// req.File) and its metadata fields (caption, alt, media_type, content_type,
// no_transform, expiration) are then sent as a single multipart/form-data
// POST. The whole multipart body is assembled in memory before it is sent.
//
// ctx governs the lifetime of the HTTP request: cancelling it, or letting its
// deadline pass, aborts the upload. req.HTTPClient is used when set, otherwise
// http.DefaultClient.
//
// When req.SK is non-empty an Authorization header is attached via
// generateAuthHeader; the SHA-256 of the file content is included in it only
// if req.SignPayload is true.
//
// A 200, 201 or 202 response is decoded as JSON into an UploadResponse. Any
// other status, and any failure while building or sending the request, is
// returned as an error with a nil response.
func Upload(ctx context.Context, req UploadRequest) (*UploadResponse, error) {
	if err := req.Validate(); err != nil {
		return nil, err
	}

	// ctx used to be ignored entirely, so a nil ctx was harmless; keep such
	// callers working instead of failing inside NewRequestWithContext.
	if ctx == nil {
		ctx = context.Background()
	}

	client := http.DefaultClient
	if req.HTTPClient != nil {
		client = req.HTTPClient
	}

	var requestBody bytes.Buffer
	fileHash := sha256.New()
	writer := multipart.NewWriter(&requestBody)

	// Add the file, hashing its content as it is copied into the form.
	fileWriter, err := writer.CreateFormFile("file", req.Filename)
	if err != nil {
		return nil, fmt.Errorf("multipartWriter.CreateFormFile: %w", err)
	}
	if _, err := io.Copy(fileWriter, io.TeeReader(req.File, fileHash)); err != nil {
		return nil, fmt.Errorf("io.Copy: %w", err)
	}

	// Add the other fields. A zero expiration is sent as an empty value.
	expiration := ""
	if req.Expiration != 0 {
		expiration = strconv.FormatInt(int64(req.Expiration), 10)
	}
	fields := [][2]string{
		{"caption", req.Caption},
		{"alt", req.Alt},
		{"media_type", req.MediaType},
		{"content_type", req.ContentType},
		{"no_transform", fmt.Sprintf("%t", req.NoTransform)},
		{"expiration", expiration},
	}
	for _, field := range fields {
		if err := writer.WriteField(field[0], field[1]); err != nil {
			return nil, fmt.Errorf("multipartWriter.WriteField(%q): %w", field[0], err)
		}
	}

	// Close writes the trailing boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("multipartWriter.Close: %w", err)
	}

	uploadReq, err := http.NewRequestWithContext(ctx, http.MethodPost, req.Host, &requestBody)
	if err != nil {
		return nil, fmt.Errorf("http.NewRequest: %w", err)
	}
	uploadReq.Header.Set("Content-Type", writer.FormDataContentType())

	if req.SK != "" {
		// Pass the hash along only when the caller asked for the payload to
		// be signed; a nil hash makes generateAuthHeader omit the payload tag.
		payloadHash := fileHash
		if !req.SignPayload {
			payloadHash = nil
		}
		auth, err := generateAuthHeader(req.SK, req.Host, payloadHash)
		if err != nil {
			return nil, fmt.Errorf("generateAuthHeader: %w", err)
		}
		uploadReq.Header.Set("Authorization", auth)
	}

	resp, err := client.Do(uploadReq)
	if err != nil {
		return nil, fmt.Errorf("httpclient.Do: %w", err)
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusRequestEntityTooLarge:
		return nil, fmt.Errorf("File is too large")
	case http.StatusBadRequest:
		return nil, fmt.Errorf("Bad request")
	case http.StatusForbidden:
		return nil, fmt.Errorf("Unauthorized")
	case http.StatusPaymentRequired:
		return nil, fmt.Errorf("Payment required")
	case http.StatusOK, http.StatusCreated, http.StatusAccepted:
		var uploadResp UploadResponse
		if err := jsoniter.NewDecoder(resp.Body).Decode(&uploadResp); err != nil {
			return nil, fmt.Errorf("Error decoding JSON: %w", err)
		}
		return &uploadResp, nil
	default:
		return nil, fmt.Errorf("Unexpected error %v", resp.Status)
	}
}
// generateAuthHeader builds the value of the HTTP Authorization header used
// to authenticate an upload.
//
// It derives the public key from the secret key sk, creates a kind 27235
// event tagged with the target URL ("u" = host) and the HTTP method
// ("method" = "POST"), signs it with sk, serialises it to JSON and returns
// "Nostr " followed by the standard-base64 encoding of that JSON.
//
// When fileHash is non-nil, its current digest is added hex-encoded as a
// "payload" tag; pass nil to omit the tag.
//
// An error is returned if the public key cannot be derived, the event cannot
// be signed, or the event cannot be marshalled.
func generateAuthHeader(sk, host string, fileHash hash.Hash) (string, error) {
	pk, err := nostr.GetPublicKey(sk)
	if err != nil {
		return "", fmt.Errorf("nostr.GetPublicKey: %w", err)
	}

	event := nostr.Event{
		Kind:      27235,
		PubKey:    pk,
		CreatedAt: nostr.Now(),
		Tags: nostr.Tags{
			nostr.Tag{"u", host},
			nostr.Tag{"method", "POST"},
		},
	}
	if fileHash != nil {
		// Sum(nil) returns the digest without resetting the hash state.
		event.Tags = append(event.Tags, nostr.Tag{"payload", hex.EncodeToString(fileHash.Sum(nil))})
	}

	// A failed signature must not be ignored: the header would otherwise
	// carry an unsigned event.
	if err := event.Sign(sk); err != nil {
		return "", fmt.Errorf("event.Sign: %w", err)
	}

	b, err := jsoniter.ConfigFastest.Marshal(event)
	if err != nil {
		return "", fmt.Errorf("json.Marshal: %w", err)
	}

	return "Nostr " + base64x.StdEncoding.EncodeToString(b), nil
}