package fn

import (
	"context"
	"sync"
	"sync/atomic"
	"time"
)

// DefaultTimeout is the timeout WithTimeoutCG applies to a derived context
// when no custom timeout is requested.
var DefaultTimeout = 30 * time.Second

// ContextGuard derives contexts from caller-supplied parents and ties them to
// a single shutdown signal. It also exposes a wait group that tracks the
// lifetime of the contexts it derives, so callers can block until they have
// all finished.
type ContextGuard struct {
	// mu guards cancelFns and orders context registration against the
	// shutdown performed by Quit.
	mu sync.Mutex

	// wg counts live contexts derived by Create, plus anything registered
	// explicitly through WgAdd.
	wg sync.WaitGroup

	// quit is closed exactly once, by Quit, to broadcast shutdown.
	quit chan struct{}

	// stopped makes Quit run its shutdown logic only once.
	stopped sync.Once

	// id is an incrementing counter whose values key the entries of
	// cancelFns.
	id atomic.Uint32

	// cancelFns maps an id to the cancel function of a non-blocking
	// context that must be cancelled when Quit runs. Quit resets it to
	// nil. mu must be held whenever it is read or written.
	cancelFns map[uint32]context.CancelFunc
}

// NewContextGuard returns a ready-to-use ContextGuard whose quit channel is
// open and whose cancel-function registry is empty.
func NewContextGuard() *ContextGuard {
	guard := new(ContextGuard)
	guard.quit = make(chan struct{})
	guard.cancelFns = make(map[uint32]context.CancelFunc)

	return guard
}

// Quit triggers shutdown. It cancels every registered non-blocking context
// derived through Create, discards the registry and closes the quit channel.
// Only the first call has any effect; later calls return immediately.
func (g *ContextGuard) Quit() {
	shutdown := func() {
		g.mu.Lock()
		defer g.mu.Unlock()

		// Cancel while holding the lock so that Create cannot register
		// a new entry part-way through shutdown.
		for _, cancelFn := range g.cancelFns {
			cancelFn()
		}

		// A nil map is fine here: Create checks g.quit under the same
		// lock and never writes to the registry once it is closed, while
		// lookups and deletes on a nil map are no-ops.
		g.cancelFns = nil

		close(g.quit)
	}

	g.stopped.Do(shutdown)
}

// Done exposes the guard's quit channel in receive-only form. The channel is
// closed by Quit.
func (g *ContextGuard) Done() <-chan struct{} {
	return g.quit
}

// WgAdd adjusts the guard's internal wait group counter by delta.
func (g *ContextGuard) WgAdd(delta int) {
	g.wg.Add(delta)
}

// WgDone lowers the guard's internal wait group counter by one.
func (g *ContextGuard) WgDone() {
	// Equivalent to g.wg.Done(), which is defined as Add(-1).
	g.wg.Add(-1)
}

// WgWait blocks until the guard's internal wait group counter drops to zero.
func (g *ContextGuard) WgWait() {
	g.wg.Wait()
}

// ctxGuardOptions collects the settings that ContextGuardOption values apply
// before Create derives a new context.
type ctxGuardOptions struct {
	// blocking exempts the derived context from cancellation by Quit.
	blocking bool

	// withTimeout makes Create use context.WithTimeout instead of
	// context.WithCancel; timeout is the duration handed to it.
	withTimeout bool
	timeout     time.Duration
}

// ContextGuardOption mutates a ctxGuardOptions value. Any number of them can
// be passed to Create to tune the context it derives.
type ContextGuardOption func(*ctxGuardOptions)

// WithBlockingCG marks the derived context as blocking. Quit will not cancel
// such a context, so the guard's wait group stays raised until the context
// ends on its own or the caller cancels it. Use it for important tasks that
// should hold up shutdown.
func WithBlockingCG() ContextGuardOption {
	setBlocking := func(opts *ctxGuardOptions) {
		opts.blocking = true
	}

	return setBlocking
}

// WithCustomTimeoutCG makes Create derive the context via context.WithTimeout
// using the given timeout. The resulting context ends in one of three ways:
// when the parent ends, when the timeout elapses, or when Quit is called. The
// last case does not apply if WithBlockingCG is also supplied.
func WithCustomTimeoutCG(timeout time.Duration) ContextGuardOption {
	return func(opts *ctxGuardOptions) {
		opts.timeout = timeout
		opts.withTimeout = true
	}
}

// WithTimeoutCG behaves like WithCustomTimeoutCG with DefaultTimeout as the
// timeout. DefaultTimeout is read when Create applies the option, not when
// WithTimeoutCG itself is called.
func WithTimeoutCG() ContextGuardOption {
	return func(opts *ctxGuardOptions) {
		WithCustomTimeoutCG(DefaultTimeout)(opts)
	}
}

// Create derives a cancellable child of ctx that is tracked by the guard's
// wait group. The options choose between a plain cancellable child and one
// with a timeout, and decide whether Quit may cancel the child.
//
// Two cases skip tracking:
//   - If ctx is already done, it is returned unchanged with a no-op cancel
//     function.
//   - If Quit has already run and the child is non-blocking, the child is
//     returned already cancelled.
func (g *ContextGuard) Create(ctx context.Context,
	options ...ContextGuardOption) (context.Context, context.CancelFunc) {

	// A parent that is already done needs no wrapping at all.
	select {
	case <-ctx.Done():
		return ctx, func() {}
	default:
	}

	var opts ctxGuardOptions
	for _, apply := range options {
		apply(&opts)
	}

	// Hold the lock for the rest of the call so that registering the child
	// cannot interleave with Quit.
	g.mu.Lock()
	defer g.mu.Unlock()

	var (
		child  context.Context
		cancel context.CancelFunc
	)
	switch {
	case opts.withTimeout:
		child, cancel = context.WithTimeout(ctx, opts.timeout)
	default:
		child, cancel = context.WithCancel(ctx)
	}

	// Blocking children are only tracked; Quit never cancels them.
	if opts.blocking {
		g.ctxBlocking(child)

		return child, cancel
	}

	select {
	case <-g.quit:
		// Shutdown already happened: hand back a cancelled child rather
		// than registering it.
		cancel()
	default:
		cancel = g.ctxQuitUnsafe(child, cancel)
	}

	return child, cancel
}

// ctxQuitUnsafe registers a non-blocking context with the guard so that Quit
// will cancel it, and tracks it in the wait group until it is done. It stores
// the raw cancel function and returns a wrapped version that removes the
// stored entry before calling it. This function does not block.
//
// Once ctx is done for any reason, the wrapped cancel function runs
// automatically so the entry does not linger in cancelFns until Quit. The
// possible reasons are the returned cancel function, Quit, parent
// cancellation, or an expired timeout. This matches the standard library,
// where a firing timer or a cancelled parent releases a child's resources
// even if its CancelFunc is never called.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context,
	cancel context.CancelFunc) context.CancelFunc {

	cancel = g.addCancelFnUnsafe(cancel)

	g.wg.Add(1)

	// AfterFunc runs the callback in its own goroutine. Taking the mutex
	// inside the wrapped cancel therefore cannot deadlock with our caller
	// (which holds it right now) or with Quit (which holds it while
	// cancelling). If the entry is already gone, e.g. because Quit cleared
	// the registry, the wrapped cancel simply returns.
	//
	// We don't have to wait on g.quit here: g.quit can be closed only in
	// the Quit method, which also cancels the context we are waiting for.
	context.AfterFunc(ctx, func() {
		cancel()
		g.wg.Done()
	})

	return cancel
}

// ctxBlocking raises the guard's wait group counter for ctx and lowers it
// again once ctx is done. It does not block: the decrement is scheduled with
// context.AfterFunc. The context is not registered for cancellation by Quit.
//
// NOTE: Create calls this with the ContextGuard's mutex held; the body itself
// only touches the wait group.
func (g *ContextGuard) ctxBlocking(ctx context.Context) {
	g.wg.Add(1)

	// The method value binds &g.wg here; AfterFunc invokes it in its own
	// goroutine once ctx is done.
	context.AfterFunc(ctx, g.wg.Done)
}

// addCancelFnUnsafe stores cancel in the guard's registry under a fresh id. It
// returns a call-back that removes that entry and then invokes cancel.
//
// The id counter is a uint32 and wraps around after 2^32 registrations. After
// a wrap, a long-lived context registered earlier could share its id with a
// new one. Overwriting its entry would drop the old cancel function, so Quit
// would never cancel that context and anything waiting on the wait group
// could hang. Ids still present in the registry are therefore skipped.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
func (g *ContextGuard) addCancelFnUnsafe(
	cancel context.CancelFunc) context.CancelFunc {

	id := g.id.Add(1)
	for {
		if _, taken := g.cancelFns[id]; !taken {
			break
		}
		id = g.id.Add(1)
	}
	g.cancelFns[id] = cancel

	return g.cancelCtxFn(id)
}

// cancelCtxFn builds the cancel call-back for the registry entry with the
// given id. The call-back does two things:
//   - It takes the entry out of cancelFns under the mutex.
//   - If the entry was still present, it invokes it after releasing the
//     mutex, so the cancel function never runs with the lock held.
//
// If the entry is missing, the call-back does nothing. That covers a repeated
// call and a call after Quit has cleared the registry.
func (g *ContextGuard) cancelCtxFn(id uint32) context.CancelFunc {
	return func() {
		g.mu.Lock()
		stored, found := g.cancelFns[id]
		if found {
			delete(g.cancelFns, id)
		}
		g.mu.Unlock()

		if found {
			stored()
		}
	}
}