package routing

import (
	"fmt"
	"sync"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcutil"
	"github.com/go-errors/errors"
	"github.com/lightningnetwork/lnd/channeldb"
	"github.com/lightningnetwork/lnd/channeldb/models"
	"github.com/lightningnetwork/lnd/htlcswitch"
	"github.com/lightningnetwork/lnd/lntypes"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/record"
	"github.com/lightningnetwork/lnd/routing/route"
	"github.com/lightningnetwork/lnd/routing/shards"
	"github.com/stretchr/testify/mock"
)

type mockPaymentAttemptDispatcherOld struct {
	onPayment func(firstHop lnwire.ShortChannelID) ([32]byte, error)
	results   map[uint64]*htlcswitch.PaymentResult

	sync.Mutex
}

var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcherOld)(nil)

func (m *mockPaymentAttemptDispatcherOld) SendHTLC(
	firstHop lnwire.ShortChannelID, pid uint64,
	_ *lnwire.UpdateAddHTLC) error {

	if m.onPayment == nil {
		return nil
	}

	var result *htlcswitch.PaymentResult
	preimage, err := m.onPayment(firstHop)
	if err != nil {
		rtErr, ok := err.(htlcswitch.ClearTextError)
		if !ok {
			return err
		}
		result = &htlcswitch.PaymentResult{
			Error: rtErr,
		}
	} else {
		result = &htlcswitch.PaymentResult{Preimage: preimage}
	}

	m.Lock()
	if m.results == nil {
		m.results = make(map[uint64]*htlcswitch.PaymentResult)
	}

	m.results[pid] = result
	m.Unlock()

	return nil
}

func (m *mockPaymentAttemptDispatcherOld) GetAttemptResult(paymentID uint64,
	_ lntypes.Hash, _ htlcswitch.ErrorDecrypter) (
	<-chan *htlcswitch.PaymentResult, error) {

	c := make(chan *htlcswitch.PaymentResult, 1)

	m.Lock()
	res, ok := m.results[paymentID]
	m.Unlock()

	if !ok {
		return nil, htlcswitch.ErrPaymentIDNotFound
	}
	c <- res

	return c, nil

}
func (m *mockPaymentAttemptDispatcherOld) CleanStore(
	map[uint64]struct{}) error {

	return nil
}

func (m *mockPaymentAttemptDispatcherOld) setPaymentResult(
	f func(firstHop lnwire.ShortChannelID) ([32]byte, error)) {

	m.onPayment = f
}

type mockPaymentSessionSourceOld struct {
	routes       []*route.Route
	routeRelease chan struct{}
}

var _ PaymentSessionSource = (*mockPaymentSessionSourceOld)(nil)

func (m *mockPaymentSessionSourceOld) NewPaymentSession(
	_ *LightningPayment) (PaymentSession, error) {

	return &mockPaymentSessionOld{
		routes:  m.routes,
		release: m.routeRelease,
	}, nil
}

func (m *mockPaymentSessionSourceOld) NewPaymentSessionForRoute(
	preBuiltRoute *route.Route) PaymentSession {
	return nil
}

func (m *mockPaymentSessionSourceOld) NewPaymentSessionEmpty() PaymentSession {
	return &mockPaymentSessionOld{}
}

type mockMissionControlOld struct {
	MissionControl
}

var _ MissionController = (*mockMissionControlOld)(nil)

func (m *mockMissionControlOld) ReportPaymentFail(
	paymentID uint64, rt *route.Route,
	failureSourceIdx *int, failure lnwire.FailureMessage) (
	*channeldb.FailureReason, error) {

	// Report a permanent failure if this is an error caused
	// by incorrect details.
	if failure.Code() == lnwire.CodeIncorrectOrUnknownPaymentDetails {
		reason := channeldb.FailureReasonPaymentDetails
		return &reason, nil
	}

	return nil, nil
}

func (m *mockMissionControlOld) ReportPaymentSuccess(paymentID uint64,
	rt *route.Route) error {

	return nil
}

func (m *mockMissionControlOld) GetProbability(fromNode, toNode route.Vertex,
	amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 {

	return 0
}

type mockPaymentSessionOld struct {
	routes []*route.Route

	// release is a channel that optionally blocks requesting a route
	// from our mock payment channel. If this value is nil, we will just
	// release the route automatically.
	release chan struct{}
}

var _ PaymentSession = (*mockPaymentSessionOld)(nil)

func (m *mockPaymentSessionOld) RequestRoute(_, _ lnwire.MilliSatoshi,
	_, height uint32) (*route.Route, error) {

	if m.release != nil {
		m.release <- struct{}{}
	}

	if len(m.routes) == 0 {
		return nil, errNoPathFound
	}

	r := m.routes[0]
	m.routes = m.routes[1:]

	return r, nil
}

func (m *mockPaymentSessionOld) UpdateAdditionalEdge(_ *lnwire.ChannelUpdate,
	_ *btcec.PublicKey, _ *models.CachedEdgePolicy) bool {

	return false
}

func (m *mockPaymentSessionOld) GetAdditionalEdgePolicy(_ *btcec.PublicKey,
	_ uint64) *models.CachedEdgePolicy {

	return nil
}

type mockPayerOld struct {
	sendResult    chan error
	paymentResult chan *htlcswitch.PaymentResult
	quit          chan struct{}
}

var _ PaymentAttemptDispatcher = (*mockPayerOld)(nil)

func (m *mockPayerOld) SendHTLC(_ lnwire.ShortChannelID,
	paymentID uint64,
	_ *lnwire.UpdateAddHTLC) error {

	select {
	case res := <-m.sendResult:
		return res
	case <-m.quit:
		return fmt.Errorf("test quitting")
	}

}

func (m *mockPayerOld) GetAttemptResult(paymentID uint64, _ lntypes.Hash,
	_ htlcswitch.ErrorDecrypter) (<-chan *htlcswitch.PaymentResult, error) {

	select {
	case res, ok := <-m.paymentResult:
		resChan := make(chan *htlcswitch.PaymentResult, 1)
		if !ok {
			close(resChan)
		} else {
			resChan <- res
		}

		return resChan, nil

	case <-m.quit:
		return nil, fmt.Errorf("test quitting")
	}
}

func (m *mockPayerOld) CleanStore(pids map[uint64]struct{}) error {
	return nil
}

type initArgs struct {
	c *channeldb.PaymentCreationInfo
}

type registerAttemptArgs struct {
	a *channeldb.HTLCAttemptInfo
}

type settleAttemptArgs struct {
	preimg lntypes.Preimage
}

type failAttemptArgs struct {
	reason *channeldb.HTLCFailInfo
}

type failPaymentArgs struct {
	reason channeldb.FailureReason
}

type testPayment struct {
	info     channeldb.PaymentCreationInfo
	attempts []channeldb.HTLCAttempt
}

type mockControlTowerOld struct {
	payments   map[lntypes.Hash]*testPayment
	successful map[lntypes.Hash]struct{}
	failed     map[lntypes.Hash]channeldb.FailureReason

	init            chan initArgs
	registerAttempt chan registerAttemptArgs
	settleAttempt   chan settleAttemptArgs
	failAttempt     chan failAttemptArgs
	failPayment     chan failPaymentArgs
	fetchInFlight   chan struct{}

	sync.Mutex
}

var _ ControlTower = (*mockControlTowerOld)(nil)

func makeMockControlTower() *mockControlTowerOld {
	return &mockControlTowerOld{
		payments:   make(map[lntypes.Hash]*testPayment),
		successful: make(map[lntypes.Hash]struct{}),
		failed:     make(map[lntypes.Hash]channeldb.FailureReason),
	}
}

func (m *mockControlTowerOld) InitPayment(phash lntypes.Hash,
	c *channeldb.PaymentCreationInfo) error {

	if m.init != nil {
		m.init <- initArgs{c}
	}

	m.Lock()
	defer m.Unlock()

	// Don't allow re-init a successful payment.
	if _, ok := m.successful[phash]; ok {
		return channeldb.ErrAlreadyPaid
	}

	_, failed := m.failed[phash]
	_, ok := m.payments[phash]

	// If the payment is known, only allow re-init if failed.
	if ok && !failed {
		return channeldb.ErrPaymentInFlight
	}

	delete(m.failed, phash)
	m.payments[phash] = &testPayment{
		info: *c,
	}

	return nil
}

func (m *mockControlTowerOld) DeleteFailedAttempts(phash lntypes.Hash) error {
	p, ok := m.payments[phash]
	if !ok {
		return channeldb.ErrPaymentNotInitiated
	}

	var inFlight bool
	for _, a := range p.attempts {
		if a.Settle != nil {
			continue
		}

		if a.Failure != nil {
			continue
		}

		inFlight = true
	}

	if inFlight {
		return channeldb.ErrPaymentInFlight
	}

	return nil
}

func (m *mockControlTowerOld) RegisterAttempt(phash lntypes.Hash,
	a *channeldb.HTLCAttemptInfo) error {

	if m.registerAttempt != nil {
		m.registerAttempt <- registerAttemptArgs{a}
	}

	m.Lock()
	defer m.Unlock()

	// Lookup payment.
	p, ok := m.payments[phash]
	if !ok {
		return channeldb.ErrPaymentNotInitiated
	}

	var inFlight bool
	for _, a := range p.attempts {
		if a.Settle != nil {
			continue
		}

		if a.Failure != nil {
			continue
		}

		inFlight = true
	}

	// Cannot register attempts for successful or failed payments.
	_, settled := m.successful[phash]
	_, failed := m.failed[phash]

	if settled || failed {
		return channeldb.ErrPaymentTerminal
	}

	if settled && !inFlight {
		return channeldb.ErrPaymentAlreadySucceeded
	}

	if failed && !inFlight {
		return channeldb.ErrPaymentAlreadyFailed
	}

	// Add attempt to payment.
	p.attempts = append(p.attempts, channeldb.HTLCAttempt{
		HTLCAttemptInfo: *a,
	})
	m.payments[phash] = p

	return nil
}

func (m *mockControlTowerOld) SettleAttempt(phash lntypes.Hash,
	pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
	*channeldb.HTLCAttempt, error) {

	if m.settleAttempt != nil {
		m.settleAttempt <- settleAttemptArgs{settleInfo.Preimage}
	}

	m.Lock()
	defer m.Unlock()

	// Only allow setting attempts if the payment is known.
	p, ok := m.payments[phash]
	if !ok {
		return nil, channeldb.ErrPaymentNotInitiated
	}

	// Find the attempt with this pid, and set the settle info.
	for i, a := range p.attempts {
		if a.AttemptID != pid {
			continue
		}

		if a.Settle != nil {
			return nil, channeldb.ErrAttemptAlreadySettled
		}
		if a.Failure != nil {
			return nil, channeldb.ErrAttemptAlreadyFailed
		}

		p.attempts[i].Settle = settleInfo

		// Mark the payment successful on first settled attempt.
		m.successful[phash] = struct{}{}
		return &channeldb.HTLCAttempt{
			Settle: settleInfo,
		}, nil
	}

	return nil, fmt.Errorf("pid not found")
}

func (m *mockControlTowerOld) FailAttempt(phash lntypes.Hash, pid uint64,
	failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {

	if m.failAttempt != nil {
		m.failAttempt <- failAttemptArgs{failInfo}
	}

	m.Lock()
	defer m.Unlock()

	// Only allow failing attempts if the payment is known.
	p, ok := m.payments[phash]
	if !ok {
		return nil, channeldb.ErrPaymentNotInitiated
	}

	// Find the attempt with this pid, and set the failure info.
	for i, a := range p.attempts {
		if a.AttemptID != pid {
			continue
		}

		if a.Settle != nil {
			return nil, channeldb.ErrAttemptAlreadySettled
		}
		if a.Failure != nil {
			return nil, channeldb.ErrAttemptAlreadyFailed
		}

		p.attempts[i].Failure = failInfo
		return &channeldb.HTLCAttempt{
			Failure: failInfo,
		}, nil
	}

	return nil, fmt.Errorf("pid not found")
}

func (m *mockControlTowerOld) FailPayment(phash lntypes.Hash,
	reason channeldb.FailureReason) error {

	m.Lock()
	defer m.Unlock()

	if m.failPayment != nil {
		m.failPayment <- failPaymentArgs{reason}
	}

	// Payment must be known.
	if _, ok := m.payments[phash]; !ok {
		return channeldb.ErrPaymentNotInitiated
	}

	m.failed[phash] = reason

	return nil
}

func (m *mockControlTowerOld) FetchPayment(phash lntypes.Hash) (
	dbMPPayment, error) {

	m.Lock()
	defer m.Unlock()

	return m.fetchPayment(phash)
}

func (m *mockControlTowerOld) fetchPayment(phash lntypes.Hash) (
	*channeldb.MPPayment, error) {

	p, ok := m.payments[phash]
	if !ok {
		return nil, channeldb.ErrPaymentNotInitiated
	}

	mp := &channeldb.MPPayment{
		Info: &p.info,
	}

	reason, ok := m.failed[phash]
	if ok {
		mp.FailureReason = &reason
	}

	// Return a copy of the current attempts.
	mp.HTLCs = append(mp.HTLCs, p.attempts...)

	if err := mp.SetState(); err != nil {
		return nil, err
	}

	return mp, nil
}

func (m *mockControlTowerOld) FetchInFlightPayments() (
	[]*channeldb.MPPayment, error) {

	if m.fetchInFlight != nil {
		m.fetchInFlight <- struct{}{}
	}

	m.Lock()
	defer m.Unlock()

	// In flight are all payments not successful or failed.
	var fl []*channeldb.MPPayment
	for hash := range m.payments {
		if _, ok := m.successful[hash]; ok {
			continue
		}
		if _, ok := m.failed[hash]; ok {
			continue
		}

		mp, err := m.fetchPayment(hash)
		if err != nil {
			return nil, err
		}

		fl = append(fl, mp)
	}

	return fl, nil
}

func (m *mockControlTowerOld) SubscribePayment(paymentHash lntypes.Hash) (
	ControlTowerSubscriber, error) {

	return nil, errors.New("not implemented")
}

func (m *mockControlTowerOld) SubscribeAllPayments() (
	ControlTowerSubscriber, error) {

	return nil, errors.New("not implemented")
}

type mockPaymentAttemptDispatcher struct {
	mock.Mock
}

var _ PaymentAttemptDispatcher = (*mockPaymentAttemptDispatcher)(nil)

func (m *mockPaymentAttemptDispatcher) SendHTLC(firstHop lnwire.ShortChannelID,
	pid uint64, htlcAdd *lnwire.UpdateAddHTLC) error {

	args := m.Called(firstHop, pid, htlcAdd)
	return args.Error(0)
}

func (m *mockPaymentAttemptDispatcher) GetAttemptResult(attemptID uint64,
	paymentHash lntypes.Hash, deobfuscator htlcswitch.ErrorDecrypter) (
	<-chan *htlcswitch.PaymentResult, error) {

	args := m.Called(attemptID, paymentHash, deobfuscator)

	resultChan := args.Get(0)
	if resultChan == nil {
		return nil, args.Error(1)
	}

	return args.Get(0).(chan *htlcswitch.PaymentResult), args.Error(1)
}

func (m *mockPaymentAttemptDispatcher) CleanStore(
	keepPids map[uint64]struct{}) error {

	args := m.Called(keepPids)
	return args.Error(0)
}

type mockPaymentSessionSource struct {
	mock.Mock
}

var _ PaymentSessionSource = (*mockPaymentSessionSource)(nil)

func (m *mockPaymentSessionSource) NewPaymentSession(
	payment *LightningPayment) (PaymentSession, error) {

	args := m.Called(payment)
	return args.Get(0).(PaymentSession), args.Error(1)
}

func (m *mockPaymentSessionSource) NewPaymentSessionForRoute(
	preBuiltRoute *route.Route) PaymentSession {

	args := m.Called(preBuiltRoute)
	return args.Get(0).(PaymentSession)
}

func (m *mockPaymentSessionSource) NewPaymentSessionEmpty() PaymentSession {
	args := m.Called()
	return args.Get(0).(PaymentSession)
}

type mockMissionControl struct {
	mock.Mock
}

var _ MissionController = (*mockMissionControl)(nil)

func (m *mockMissionControl) ReportPaymentFail(
	paymentID uint64, rt *route.Route,
	failureSourceIdx *int, failure lnwire.FailureMessage) (
	*channeldb.FailureReason, error) {

	args := m.Called(paymentID, rt, failureSourceIdx, failure)

	// Type assertion on nil will fail, so we check and return here.
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}

	return args.Get(0).(*channeldb.FailureReason), args.Error(1)
}

func (m *mockMissionControl) ReportPaymentSuccess(paymentID uint64,
	rt *route.Route) error {

	args := m.Called(paymentID, rt)
	return args.Error(0)
}

func (m *mockMissionControl) GetProbability(fromNode, toNode route.Vertex,
	amt lnwire.MilliSatoshi, capacity btcutil.Amount) float64 {

	args := m.Called(fromNode, toNode, amt, capacity)
	return args.Get(0).(float64)
}

type mockPaymentSession struct {
	mock.Mock
}

var _ PaymentSession = (*mockPaymentSession)(nil)

func (m *mockPaymentSession) RequestRoute(maxAmt, feeLimit lnwire.MilliSatoshi,
	activeShards, height uint32) (*route.Route, error) {

	args := m.Called(maxAmt, feeLimit, activeShards, height)

	// Type assertion on nil will fail, so we check and return here.
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}

	return args.Get(0).(*route.Route), args.Error(1)
}

func (m *mockPaymentSession) UpdateAdditionalEdge(msg *lnwire.ChannelUpdate,
	pubKey *btcec.PublicKey, policy *models.CachedEdgePolicy) bool {

	args := m.Called(msg, pubKey, policy)
	return args.Bool(0)
}

func (m *mockPaymentSession) GetAdditionalEdgePolicy(pubKey *btcec.PublicKey,
	channelID uint64) *models.CachedEdgePolicy {

	args := m.Called(pubKey, channelID)
	return args.Get(0).(*models.CachedEdgePolicy)
}

type mockControlTower struct {
	mock.Mock
}

var _ ControlTower = (*mockControlTower)(nil)

func (m *mockControlTower) InitPayment(phash lntypes.Hash,
	c *channeldb.PaymentCreationInfo) error {

	args := m.Called(phash, c)
	return args.Error(0)
}

func (m *mockControlTower) DeleteFailedAttempts(phash lntypes.Hash) error {
	args := m.Called(phash)
	return args.Error(0)
}

func (m *mockControlTower) RegisterAttempt(phash lntypes.Hash,
	a *channeldb.HTLCAttemptInfo) error {

	args := m.Called(phash, a)
	return args.Error(0)
}

func (m *mockControlTower) SettleAttempt(phash lntypes.Hash,
	pid uint64, settleInfo *channeldb.HTLCSettleInfo) (
	*channeldb.HTLCAttempt, error) {

	args := m.Called(phash, pid, settleInfo)

	attempt := args.Get(0)
	if attempt == nil {
		return nil, args.Error(1)
	}

	return attempt.(*channeldb.HTLCAttempt), args.Error(1)
}

func (m *mockControlTower) FailAttempt(phash lntypes.Hash, pid uint64,
	failInfo *channeldb.HTLCFailInfo) (*channeldb.HTLCAttempt, error) {

	args := m.Called(phash, pid, failInfo)

	attempt := args.Get(0)
	if attempt == nil {
		return nil, args.Error(1)
	}

	return args.Get(0).(*channeldb.HTLCAttempt), args.Error(1)
}

func (m *mockControlTower) FailPayment(phash lntypes.Hash,
	reason channeldb.FailureReason) error {

	args := m.Called(phash, reason)
	return args.Error(0)
}

func (m *mockControlTower) FetchPayment(phash lntypes.Hash) (
	dbMPPayment, error) {

	args := m.Called(phash)

	// Type assertion on nil will fail, so we check and return here.
	if args.Get(0) == nil {
		return nil, args.Error(1)
	}

	payment := args.Get(0).(*mockMPPayment)
	return payment, args.Error(1)
}

func (m *mockControlTower) FetchInFlightPayments() (
	[]*channeldb.MPPayment, error) {

	args := m.Called()
	return args.Get(0).([]*channeldb.MPPayment), args.Error(1)
}

func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) (
	ControlTowerSubscriber, error) {

	args := m.Called(paymentHash)
	return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}

func (m *mockControlTower) SubscribeAllPayments() (
	ControlTowerSubscriber, error) {

	args := m.Called()
	return args.Get(0).(ControlTowerSubscriber), args.Error(1)
}

type mockMPPayment struct {
	mock.Mock
}

var _ dbMPPayment = (*mockMPPayment)(nil)

func (m *mockMPPayment) GetState() *channeldb.MPPaymentState {
	args := m.Called()
	return args.Get(0).(*channeldb.MPPaymentState)
}

func (m *mockMPPayment) GetStatus() channeldb.PaymentStatus {
	args := m.Called()
	return args.Get(0).(channeldb.PaymentStatus)
}

func (m *mockMPPayment) Terminated() bool {
	args := m.Called()

	return args.Bool(0)
}

func (m *mockMPPayment) NeedWaitAttempts() (bool, error) {
	args := m.Called()
	return args.Bool(0), args.Error(1)
}

func (m *mockMPPayment) GetHTLCs() []channeldb.HTLCAttempt {
	args := m.Called()
	return args.Get(0).([]channeldb.HTLCAttempt)
}

func (m *mockMPPayment) InFlightHTLCs() []channeldb.HTLCAttempt {
	args := m.Called()
	return args.Get(0).([]channeldb.HTLCAttempt)
}

func (m *mockMPPayment) AllowMoreAttempts() (bool, error) {
	args := m.Called()
	return args.Bool(0), args.Error(1)
}

func (m *mockMPPayment) TerminalInfo() (*channeldb.HTLCAttempt,
	*channeldb.FailureReason) {

	args := m.Called()

	var (
		settleInfo  *channeldb.HTLCAttempt
		failureInfo *channeldb.FailureReason
	)

	settle := args.Get(0)
	if settle != nil {
		settleInfo = settle.(*channeldb.HTLCAttempt)
	}

	reason := args.Get(1)
	if reason != nil {
		failureInfo = reason.(*channeldb.FailureReason)
	}

	return settleInfo, failureInfo
}

type mockLink struct {
	htlcswitch.ChannelLink
	bandwidth         lnwire.MilliSatoshi
	mayAddOutgoingErr error
	ineligible        bool
}

// Bandwidth returns the bandwidth the mock was configured with.
func (m *mockLink) Bandwidth() lnwire.MilliSatoshi {
	return m.bandwidth
}

// EligibleToForward returns the mock's configured eligibility.
func (m *mockLink) EligibleToForward() bool {
	return !m.ineligible
}

// MayAddOutgoingHtlc returns the error configured in our mock.
func (m *mockLink) MayAddOutgoingHtlc(_ lnwire.MilliSatoshi) error {
	return m.mayAddOutgoingErr
}

type mockShardTracker struct {
	mock.Mock
}

var _ shards.ShardTracker = (*mockShardTracker)(nil)

func (m *mockShardTracker) NewShard(attemptID uint64,
	lastShard bool) (shards.PaymentShard, error) {

	args := m.Called(attemptID, lastShard)

	shard := args.Get(0)
	if shard == nil {
		return nil, args.Error(1)
	}

	return shard.(shards.PaymentShard), args.Error(1)
}

func (m *mockShardTracker) GetHash(attemptID uint64) (lntypes.Hash, error) {
	args := m.Called(attemptID)
	return args.Get(0).(lntypes.Hash), args.Error(1)
}

func (m *mockShardTracker) CancelShard(attemptID uint64) error {
	args := m.Called(attemptID)
	return args.Error(0)
}

type mockShard struct {
	mock.Mock
}

var _ shards.PaymentShard = (*mockShard)(nil)

// Hash returns the hash used for the HTLC representing this shard.
func (m *mockShard) Hash() lntypes.Hash {
	args := m.Called()
	return args.Get(0).(lntypes.Hash)
}

// MPP returns any extra MPP records that should be set for the final
// hop on the route used by this shard.
func (m *mockShard) MPP() *record.MPP {
	args := m.Called()

	r := args.Get(0)
	if r == nil {
		return nil
	}

	return r.(*record.MPP)
}

// AMP returns any extra AMP records that should be set for the final
// hop on the route used by this shard.
func (m *mockShard) AMP() *record.AMP {
	args := m.Called()

	r := args.Get(0)
	if r == nil {
		return nil
	}

	return r.(*record.AMP)
}