package amp

import (
	"crypto/rand"
	"encoding/binary"
	"fmt"
	"sync"

	"github.com/lightningnetwork/lnd/lntypes"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/record"
	"github.com/lightningnetwork/lnd/routing/shards"
)

// Shard is an implementation of the shards.PaymentShards interface specific
// to AMP payments.
type Shard struct {
	child *Child
	mpp   *record.MPP
	amp   *record.AMP
}

// A compile time check to ensure Shard implements the shards.PaymentShard
// interface.
var _ shards.PaymentShard = (*Shard)(nil)

// Hash returns the hash used for the HTLC representing this AMP shard.
func (s *Shard) Hash() lntypes.Hash {
	return s.child.Hash
}

// MPP returns any extra MPP records that should be set for the final hop on
// the route used by this shard.
func (s *Shard) MPP() *record.MPP {
	return s.mpp
}

// AMP returns any extra AMP records that should be set for the final hop on
// the route used by this shard.
func (s *Shard) AMP() *record.AMP {
	return s.amp
}

// ShardTracker is an implementation of the shards.ShardTracker interface
// that is able to generate payment shards according to the AMP splitting
// algorithm. It can be used to generate new hashes to use for HTLCs, and also
// cancel shares used for failed payment shards.
type ShardTracker struct {
	setID       [32]byte
	paymentAddr [32]byte
	totalAmt    lnwire.MilliSatoshi

	sharer Sharer

	shards map[uint64]*Child
	sync.Mutex
}

// A compile time check to ensure ShardTracker implements the
// shards.ShardTracker interface.
var _ shards.ShardTracker = (*ShardTracker)(nil)

// NewShardTracker creates a new shard tracker to use for AMP payments. The
// root shard, setID, payment address and total amount must be correctly set in
// order for the TLV options to include with each shard to be created
// correctly.
func NewShardTracker(root, setID, payAddr [32]byte,
	totalAmt lnwire.MilliSatoshi) *ShardTracker {

	// Create a new seed sharer from this root.
	rootShare := Share(root)
	rootSharer := SeedSharerFromRoot(&rootShare)

	return &ShardTracker{
		setID:       setID,
		paymentAddr: payAddr,
		totalAmt:    totalAmt,
		sharer:      rootSharer,
		shards:      make(map[uint64]*Child),
	}
}

// NewShard registers a new attempt with the ShardTracker and returns a
// new shard representing this attempt. This attempt's shard should be canceled
// if it ends up not being used by the overall payment, i.e. if the attempt
// fails.
func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard,
	error) {

	s.Lock()
	defer s.Unlock()

	// Use a random child index.
	var childIndex [4]byte
	if _, err := rand.Read(childIndex[:]); err != nil {
		return nil, err
	}
	idx := binary.BigEndian.Uint32(childIndex[:])

	// Depending on whether we are requesting the last shard or not, either
	// split the current share into two, or get a Child directly from the
	// current sharer.
	var child *Child
	if last {
		child = s.sharer.Child(idx)

		// If this was the last shard, set the current share to the
		// zero share to indicate we cannot split it further.
		s.sharer = s.sharer.Zero()
	} else {
		left, sharer, err := s.sharer.Split()
		if err != nil {
			return nil, err
		}

		s.sharer = sharer
		child = left.Child(idx)
	}

	// Track the new child and return the shard.
	s.shards[pid] = child

	mpp := record.NewMPP(s.totalAmt, s.paymentAddr)
	amp := record.NewAMP(
		child.ChildDesc.Share, s.setID, child.ChildDesc.Index,
	)

	return &Shard{
		child: child,
		mpp:   mpp,
		amp:   amp,
	}, nil
}

// CancelShard cancel's the shard corresponding to the given attempt ID.
func (s *ShardTracker) CancelShard(pid uint64) error {
	s.Lock()
	defer s.Unlock()

	c, ok := s.shards[pid]
	if !ok {
		return fmt.Errorf("pid not found")
	}
	delete(s.shards, pid)

	// Now that we are canceling this shard, we XOR the share back into our
	// current share.
	s.sharer = s.sharer.Merge(c)
	return nil
}

// GetHash retrieves the hash used by the shard of the given attempt ID. This
// will return an error if the attempt ID is unknown.
func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) {
	s.Lock()
	defer s.Unlock()

	c, ok := s.shards[pid]
	if !ok {
		return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+
			"not found", pid)
	}

	return c.Hash, nil
}