diff --git a/libavutil/tx_template.c b/libavutil/tx_template.c index 5274133ec4..747731a06d 100644 --- a/libavutil/tx_template.c +++ b/libavutil/tx_template.c @@ -754,20 +754,34 @@ static av_cold int TX_NAME(ff_tx_fft_init)(AVTXContext *s, return 0; } +static av_cold int TX_NAME(ff_tx_fft_inplace_small_init)(AVTXContext *s, + const FFTXCodelet *cd, + uint64_t flags, + FFTXCodeletOptions *opts, + int len, int inv, + const void *scale) +{ + if (!(s->tmp = av_malloc(len*sizeof(*s->tmp)))) + return AVERROR(ENOMEM); + flags &= ~AV_TX_INPLACE; + return TX_NAME(ff_tx_fft_init)(s, cd, flags, opts, len, inv, scale); +} + static void TX_NAME(ff_tx_fft)(AVTXContext *s, void *_dst, void *_src, ptrdiff_t stride) { TXComplex *src = _src; - TXComplex *dst = _dst; + TXComplex *dst1 = s->flags & AV_TX_INPLACE ? s->tmp : _dst; + TXComplex *dst2 = _dst; int *map = s->sub[0].map; int len = s->len; /* Compilers can't vectorize this anyway without assuming AVX2, which they * generally don't, at least without -march=native -mtune=native */ for (int i = 0; i < len; i++) - dst[i] = src[map[i]]; + dst1[i] = src[map[i]]; - s->fn[0](&s->sub[0], dst, dst, stride); + s->fn[0](&s->sub[0], dst2, dst1, stride); } static void TX_NAME(ff_tx_fft_inplace)(AVTXContext *s, void *_dst, @@ -807,6 +821,19 @@ static const FFTXCodelet TX_NAME(ff_tx_fft_def) = { .prio = FF_TX_PRIO_BASE, }; +static const FFTXCodelet TX_NAME(ff_tx_fft_inplace_small_def) = { + .name = TX_NAME_STR("fft_inplace_small"), + .function = TX_NAME(ff_tx_fft), + .type = TX_TYPE(FFT), + .flags = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | AV_TX_INPLACE, + .factors[0] = TX_FACTOR_ANY, + .min_len = 2, + .max_len = 65536, + .init = TX_NAME(ff_tx_fft_inplace_small_init), + .cpu_flags = FF_TX_CPU_FLAGS_ALL, + .prio = FF_TX_PRIO_BASE - 256, +}; + static const FFTXCodelet TX_NAME(ff_tx_fft_inplace_def) = { .name = TX_NAME_STR("fft_inplace"), .function = TX_NAME(ff_tx_fft_inplace), @@ -1638,6 +1665,7 @@ const FFTXCodelet * const TX_NAME(ff_tx_codelet_list)[] = { /* Standalone transforms */ &TX_NAME(ff_tx_fft_def), &TX_NAME(ff_tx_fft_inplace_def), + &TX_NAME(ff_tx_fft_inplace_small_def), &TX_NAME(ff_tx_fft_pfa_3xM_def), &TX_NAME(ff_tx_fft_pfa_5xM_def), &TX_NAME(ff_tx_fft_pfa_7xM_def),