Files
tenacity/libraries/lib-music-information-retrieval/StftFrameProvider.cpp
Dmitry Makarenko 830acf7838 7470 Fix crash when importing short audio (#9096)
(cherry picked from commit efc959a77f2fbdd25966d70b32da91a384db16ca)
Signed-off-by: Avery King <gperson@disroot.org>
2025-08-11 08:04:11 -07:00

107 lines
3.1 KiB
C++

/* SPDX-License-Identifier: GPL-2.0-or-later */
/*!********************************************************************
Audacity: A Digital Audio Editor
StftFrameProvider.cpp
Matthieu Hodgkinson
**********************************************************************/
#include "StftFrameProvider.h"
#include "MirTypes.h"
#include "MirUtils.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
namespace MIR
{
namespace
{
// Full-circle angle in radians. M_PI is a POSIX/compiler extension rather than
// standard C++, so this relies on <cmath> exposing it on the target toolchain.
// NOTE(review): no use of this constant is visible in this file - confirm
// whether it is still needed.
constexpr double twoPi = 2.0 * M_PI;
// Returns the FFT frame size, in samples, to use for audio at `sampleRate`.
// The reference point is 2048 (2^11) samples at 44.1 kHz; other rates scale
// that size by the power of two nearest to `sampleRate / 44100`.
// `sampleRate` is expected to be positive; non-positive values are not handled.
int GetFrameSize(int sampleRate)
{
   constexpr int referenceExponent = 11; // 2^11 == 2048
   constexpr double referenceRate = 44100.;
   const double octavesFromReference = std::log2(sampleRate / referenceRate);
   const int exponent =
      referenceExponent + static_cast<int>(std::round(octavesFromReference));
   return 1 << exponent;
}
// Returns the (fractional) hop size, in samples, between successive frames.
//
// The hop is chosen as close as possible to 10 ms while making the number of
// frames (`numSamples / hop`) a power of two. This spares us the need for
// resampling when we need to get the autocorrelation of the ODF using an FFT.
//
// Returns 0 when no usable hop size exists: non-positive `sampleRate` or
// `numSamples`, or audio so short that fewer than one frame would result.
double GetHopSize(int sampleRate, long long numSamples)
{
   // Without this guard the ratio below is zero, negative or infinite, log2
   // yields -inf/NaN/inf, and converting such a value to an integer is
   // undefined behaviour.
   if (sampleRate <= 0 || numSamples <= 0)
      return 0;
   const auto idealHopSize = 0.01 * sampleRate;
   // Kept as a double until it is known to be a small non-negative integer.
   const double exponent = std::round(std::log2(numSamples / idealHopSize));
   if (exponent < 0)
      return 0;
   // numFrames == 2^exponent. Built with ldexp rather than `1 << exponent` so
   // that exponents of 31 and above cannot overflow an int shift.
   const double numFrames = std::ldexp(1., static_cast<int>(exponent));
   return static_cast<double>(numSamples) / numFrames;
}
} // namespace
// Builds a frame provider over `audio`, deriving the FFT size, hop size,
// analysis window and total frame count from the reader's sample rate and
// sample count.
// NOTE(review): the initializers of mWindow and mNumFrames read mFftSize and
// mHopSize, so they rely on those members being declared earlier in the class
// - confirm against the header.
StftFrameProvider::StftFrameProvider(const MirAudioReader& audio)
: mAudio { audio }
, mFftSize { GetFrameSize(audio.GetSampleRate()) }
, mHopSize { GetHopSize(audio.GetSampleRate(), audio.GetNumSamples()) }
, mWindow { GetNormalizedHann(mFftSize) }
// A hop size of 0 means GetHopSize found the audio too short to yield a
// frame; in that case no frames are provided at all.
, mNumFrames { mHopSize > 0 ? static_cast<int>(std::round(
audio.GetNumSamples() / mHopSize)) :
0 }
, mNumSamples { audio.GetNumSamples() }
{
// GetHopSize divides the sample count by a power of two, so the resulting
// frame count is expected to be a power of two (or 0 for too-short audio).
assert(mNumFrames == 0 || IsPowOfTwo(mNumFrames));
}
// Writes the next windowed frame of `mFftSize` samples into `frame`.
//
// The audio is treated as circular: the first frame starts before sample 0
// and wraps to the end of the audio, and a frame that runs past the last
// sample continues from sample 0.
//
// Returns false, leaving `frame` untouched, once all `mNumFrames` frames have
// been provided (or if there is no audio to read); returns true otherwise.
bool StftFrameProvider::GetNextFrame(PffftFloatVector& frame)
{
   // The sample-count check also guarantees that the wrap-around loop below
   // terminates.
   if (mNumFramesProvided >= mNumFrames || mNumSamples <= 0)
      return false;
   frame.resize(mFftSize, 0.f);
   const int firstReadPosition = mHopSize - mFftSize;
   // 64-bit position so that very long audio cannot overflow an int.
   long long start =
      std::llround(firstReadPosition + mNumFramesProvided * mHopSize);
   while (start < 0)
      start += mNumSamples;
   const auto end = std::min<long long>(start + mFftSize, mNumSamples);
   const auto numToRead = end - start;
   mAudio.ReadFloats(frame.data(), start, numToRead);
   // It's not impossible that some user drops a file so short that `mFftSize >
   // mNumSamples`. In that case we won't be returning a meaningful
   // STFT, but that's a use case we're not interested in. We just need to make
   // sure we don't crash.
   const auto numRemaining = std::min(mFftSize - numToRead, mNumSamples);
   if (numRemaining > 0)
      mAudio.ReadFloats(frame.data() + numToRead, 0, numRemaining);
   // In that short-audio case the two reads do not cover the whole frame.
   // `resize` only zero-fills elements it appends, so a buffer reused across
   // calls would otherwise keep stale samples from the previous frame.
   const auto numFilled = numToRead + std::max<long long>(numRemaining, 0);
   std::fill(frame.begin() + numFilled, frame.end(), 0.f);
   // Apply the analysis window in place.
   std::transform(
      frame.begin(), frame.end(), mWindow.begin(), frame.begin(),
      [](float sample, float weight) { return sample * weight; });
   ++mNumFramesProvided;
   return true;
}
// Returns the total number of frames GetNextFrame provides before it starts
// returning false. This is 0 when no positive hop size could be found.
int StftFrameProvider::GetNumFrames() const
{
return mNumFrames;
}
// Returns the sample rate reported by the underlying audio reader.
int StftFrameProvider::GetSampleRate() const
{
return mAudio.GetSampleRate();
}
// Returns the number of frames per second, i.e. the audio sample rate divided
// by the hop size, or 0 when the hop size is not positive (no frames).
double StftFrameProvider::GetFrameRate() const
{
   return (mHopSize <= 0) ?
             0. :
             static_cast<double>(mAudio.GetSampleRate()) / mHopSize;
}
// Returns the FFT size, i.e. the number of samples in each frame written by
// GetNextFrame, as computed by GetFrameSize at construction.
int StftFrameProvider::GetFftSize() const
{
return mFftSize;
}
} // namespace MIR