diff --git a/asyncbuffer/buffer.go b/asyncbuffer/buffer.go index 89c5045c..abdc023e 100644 --- a/asyncbuffer/buffer.go +++ b/asyncbuffer/buffer.go @@ -58,13 +58,14 @@ var chunkPool = sync.Pool{ // AsyncBuffer is a wrapper around io.Reader that reads data in chunks // in background and allows reading from synchronously. type AsyncBuffer struct { - r io.ReadCloser // Upstream reader + r io.ReadCloser // Upstream reader + dataLen int // Expected length of the data in r, <= 0 means unknown length chunks []*byteChunk // References to the chunks read from the upstream reader mu sync.RWMutex // Mutex on chunks slice - err atomic.Value // Error that occurred during reading - len atomic.Int64 // Total length of the data read + err atomic.Value // Error that occurred during reading + bytesRead atomic.Int64 // Total length of the data read finished atomic.Bool // Indicates that the buffer has finished reading closed atomic.Bool // Indicates that the buffer was closed @@ -78,9 +79,14 @@ type AsyncBuffer struct { // New creates a new AsyncBuffer that reads from the given io.ReadCloser in background // and closes it when finished. -func New(r io.ReadCloser, finishFn ...context.CancelFunc) *AsyncBuffer { +// +// r - io.ReadCloser to read data from +// dataLen - expected length of the data in r, <= 0 means unknown length +// finishFn - optional functions to call when the buffer is finished reading +func New(r io.ReadCloser, dataLen int, finishFn ...context.CancelFunc) *AsyncBuffer { ab := &AsyncBuffer{ r: r, + dataLen: dataLen, paused: NewLatch(), chunkCond: NewCond(), finishFn: finishFn, @@ -102,7 +108,8 @@ func (ab *AsyncBuffer) callFinishFn() { }) } -// addChunk adds a new chunk to the AsyncBuffer, increments len and signals that a chunk is ready +// addChunk adds a new chunk to the AsyncBuffer, increments bytesRead +// and signals that a chunk is ready func (ab *AsyncBuffer) addChunk(chunk *byteChunk) { ab.mu.Lock() defer ab.mu.Unlock() @@ -115,7 +122,7 @@ func (ab *AsyncBuffer) addChunk(chunk *byteChunk) { // Store the chunk, increase chunk size, increase length of the data read ab.chunks = append(ab.chunks, chunk) - ab.len.Add(int64(len(chunk.data))) + ab.bytesRead.Add(int64(len(chunk.data))) ab.chunkCond.Tick() } @@ -132,14 +139,26 @@ func (ab *AsyncBuffer) readChunks() { logrus.WithField("source", "asyncbuffer.AsyncBuffer.readChunks").Warningf("error closing upstream reader: %s", err) } + if ab.bytesRead.Load() < int64(ab.dataLen) && ab.err.Load() == nil { + // If the reader has finished reading and we have not read enough data, + // set err to io.ErrUnexpectedEOF + ab.err.Store(io.ErrUnexpectedEOF) + } + ab.callFinishFn() }() + r := ab.r.(io.Reader) + if ab.dataLen > 0 { + // If the data length is known, we read only that much data + r = io.LimitReader(r, int64(ab.dataLen)) + } + // Stop reading if the reader is closed for !ab.closed.Load() { // In case we are trying to read data beyond threshold and we are paused, // wait for pause to be released. - if ab.len.Load() >= pauseThreshold { + if ab.bytesRead.Load() >= pauseThreshold { ab.paused.Wait() // If the reader has been closed while waiting, we can stop reading @@ -157,9 +176,9 @@ func (ab *AsyncBuffer) readChunks() { } // Read data into the chunk's buffer - // There is no way to guarantee that ab.r.Read will abort on context cancellation, + // There is no way to guarantee that r.Read will abort on context cancellation, // unfortunately, this is how golang works. - n, err := ioutil.TryReadFull(ab.r, chunk.buf) + n, err := ioutil.TryReadFull(r, chunk.buf) // If it's not the EOF, we need to store the error if err != nil && err != io.EOF { @@ -214,7 +233,7 @@ func (ab *AsyncBuffer) offsetAvailable(off int64) (bool, error) { // In case the offset falls within the already read chunks, we can return immediately, // even if error has occurred in the future - if off < ab.len.Load() { + if off < ab.bytesRead.Load() { return true, nil } @@ -267,7 +286,7 @@ func (ab *AsyncBuffer) Wait() (int, error) { // In case the reader is finished reading, we can return immediately if ab.finished.Load() { - return int(ab.len.Load()), ab.Error() + return int(ab.bytesRead.Load()), ab.Error() } // Lock until the next chunk is ready @@ -296,7 +315,7 @@ func (ab *AsyncBuffer) Error() error { // (eg. offset is beyond the end of the stream). func (ab *AsyncBuffer) readChunkAt(p []byte, off int64) int { // If the chunk is not available, we return 0 - if off >= ab.len.Load() { + if off >= ab.bytesRead.Load() { return 0 } diff --git a/asyncbuffer/buffer_test.go b/asyncbuffer/buffer_test.go index 10c4859a..6fc98c8c 100644 --- a/asyncbuffer/buffer_test.go +++ b/asyncbuffer/buffer_test.go @@ -109,7 +109,7 @@ func generateSourceData(t *testing.T, size int) ([]byte, io.ReadSeekCloser) { func TestAsyncBufferReadAt(t *testing.T) { // Let's use source buffer which is 4.5 chunks long source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize) - asyncBuffer := New(bytesReader) + asyncBuffer := New(bytesReader, -1) defer asyncBuffer.Close() asyncBuffer.Wait() // Wait for all chunks to be read since we're going to read all data @@ -169,7 +169,7 @@ func TestAsyncBufferReadAt(t *testing.T) { // TestAsyncBufferRead tests reading from AsyncBuffer using ReadAt method func TestAsyncBufferReadAtSmallBuffer(t *testing.T) { source, bytesReader := generateSourceData(t, 20) - asyncBuffer := New(bytesReader) + asyncBuffer := New(bytesReader, -1) defer asyncBuffer.Close() // First, let's read all the data @@ -199,7 +199,7 @@ func TestAsyncBufferReader(t *testing.T) { source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize) // Create an AsyncBuffer with the byte slice - asyncBuffer := New(bytesReader) + asyncBuffer := New(bytesReader, -1) defer asyncBuffer.Close() // Let's wait for all chunks to be read @@ -267,7 +267,7 @@ func TestAsyncBufferClose(t *testing.T) { _, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize) // Create an AsyncBuffer with the byte slice - asyncBuffer := New(bytesReader) + asyncBuffer := New(bytesReader, -1) reader1 := asyncBuffer.Reader() reader2 := asyncBuffer.Reader() @@ -294,7 +294,7 @@ func TestAsyncBufferReadAtErrAtSomePoint(t *testing.T) { // Let's use source buffer which is 4.5 chunks long source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize) slowReader := &erraticReader{reader: bytesReader, failAt: chunkSize*3 + 5} // fails at last chunk - asyncBuffer := New(slowReader) + asyncBuffer := New(slowReader, -1) defer asyncBuffer.Close() // Let's wait for all chunks to be read @@ -327,7 +327,7 @@ func TestAsyncBufferReadAsync(t *testing.T) { // Let's use source buffer which is 4.5 chunks long source, bytesReader := generateSourceData(t, chunkSize*3) blockingReader := newBlockingReader(bytesReader) - asyncBuffer := New(blockingReader) + asyncBuffer := New(blockingReader, -1) defer asyncBuffer.Close() // flush the first chunk to allow reading @@ -367,11 +367,54 @@ func TestAsyncBufferReadAsync(t *testing.T) { assert.Equal(t, 0, n) } +// TestAsyncBufferWithDataLenAndExactReaderSize tests that AsyncBuffer doesn't +// return an error when the expected data length is set and matches the reader size +func TestAsyncBufferWithDataLenAndExactReaderSize(t *testing.T) { + source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize) + asyncBuffer := New(bytesReader, len(source)) + defer asyncBuffer.Close() + + // Let's wait for all chunks to be read + size, err := asyncBuffer.Wait() + require.NoError(t, err, "AsyncBuffer failed to wait for all chunks") + assert.Equal(t, len(source), size) +} + +// TestAsyncBufferWithDataLenAndShortReaderSize tests that AsyncBuffer returns +// io.ErrUnexpectedEOF when the expected data length is set and the reader size +// is shorter than the expected data length +func TestAsyncBufferWithDataLenAndShortReaderSize(t *testing.T) { + source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize) + asyncBuffer := New(bytesReader, len(source)+100) // 100 bytes more than the source + defer asyncBuffer.Close() + + // Let's wait for all chunks to be read + size, err := asyncBuffer.Wait() + require.Equal(t, len(source), size) + require.ErrorIs(t, err, io.ErrUnexpectedEOF, + "AsyncBuffer should return io.ErrUnexpectedEOF when data length is longer than reader size") +} + +// TestAsyncBufferWithDataLenAndLongerReaderSize tests that AsyncBuffer doesn't +// read more data than specified by the expected data length and doesn't return an error +// when the reader size is longer than the expected data length +func TestAsyncBufferWithDataLenAndLongerReaderSize(t *testing.T) { + source, bytesReader := generateSourceData(t, chunkSize*4+halfChunkSize) + asyncBuffer := New(bytesReader, len(source)-100) // 100 bytes less than the source + defer asyncBuffer.Close() + + // Let's wait for all chunks to be read + size, err := asyncBuffer.Wait() + require.NoError(t, err, "AsyncBuffer failed to wait for all chunks") + assert.Equal(t, len(source)-100, size, + "AsyncBuffer should read only the specified amount of data when data length is set") +} + // TestAsyncBufferReadAllCompability tests that ReadAll methods works as expected func TestAsyncBufferReadAllCompability(t *testing.T) { source, err := os.ReadFile("../testdata/test1.jpg") require.NoError(t, err) - asyncBuffer := New(nopSeekCloser{bytes.NewReader(source)}) + asyncBuffer := New(nopSeekCloser{bytes.NewReader(source)}, -1) defer asyncBuffer.Close() b, err := io.ReadAll(asyncBuffer.Reader()) @@ -381,7 +424,7 @@ func TestAsyncBufferReadAllCompability(t *testing.T) { func TestAsyncBufferThreshold(t *testing.T) { _, bytesReader := generateSourceData(t, pauseThreshold*2) - asyncBuffer := New(bytesReader) + asyncBuffer := New(bytesReader, -1) defer asyncBuffer.Close() target := make([]byte, chunkSize) @@ -391,12 +434,12 @@ func TestAsyncBufferThreshold(t *testing.T) { // Ensure that buffer hits the pause threshold require.Eventually(t, func() bool { - return asyncBuffer.len.Load() >= pauseThreshold + return asyncBuffer.bytesRead.Load() >= pauseThreshold }, 300*time.Millisecond, 10*time.Millisecond) // Ensure that buffer never reaches the end of the stream require.Never(t, func() bool { - return asyncBuffer.len.Load() >= pauseThreshold*2-1 + return asyncBuffer.bytesRead.Load() >= pauseThreshold*2-1 }, 300*time.Millisecond, 10*time.Millisecond) // Let's hit the pause threshold @@ -407,7 +450,7 @@ func TestAsyncBufferThreshold(t *testing.T) { // Ensure that buffer never reaches the end of the stream require.Never(t, func() bool { - return asyncBuffer.len.Load() >= pauseThreshold*2-1 + return asyncBuffer.bytesRead.Load() >= pauseThreshold*2-1 }, 300*time.Millisecond, 10*time.Millisecond) // Let's hit the pause threshold @@ -421,13 +464,13 @@ func TestAsyncBufferThreshold(t *testing.T) { // Ensure that buffer hits the end of the stream require.Eventually(t, func() bool { - return asyncBuffer.len.Load() >= pauseThreshold*2 + return asyncBuffer.bytesRead.Load() >= pauseThreshold*2 }, 300*time.Millisecond, 10*time.Millisecond) } func TestAsyncBufferThresholdInstantBeyondAccess(t *testing.T) { _, bytesReader := generateSourceData(t, pauseThreshold*2) - asyncBuffer := New(bytesReader) + asyncBuffer := New(bytesReader, -1) defer asyncBuffer.Close() target := make([]byte, chunkSize) @@ -437,6 +480,6 @@ func TestAsyncBufferThresholdInstantBeyondAccess(t *testing.T) { // Ensure that buffer hits the end of the stream require.Eventually(t, func() bool { - return asyncBuffer.len.Load() >= pauseThreshold*2 + return asyncBuffer.bytesRead.Load() >= pauseThreshold*2 }, 300*time.Millisecond, 10*time.Millisecond) } diff --git a/asyncbuffer/reader.go b/asyncbuffer/reader.go index fcc09da4..5c2df749 100644 --- a/asyncbuffer/reader.go +++ b/asyncbuffer/reader.go @@ -32,9 +32,12 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) { r.pos += offset case io.SeekEnd: - size, err := r.ab.Wait() - if err != nil { - return 0, err + size := r.ab.dataLen + if size <= 0 { + var err error + if size, err = r.ab.Wait(); err != nil { + return 0, err + } } r.pos = int64(size) + offset diff --git a/imagedata/factory.go b/imagedata/factory.go index 7036cb04..4d75ac5c 100644 --- a/imagedata/factory.go +++ b/imagedata/factory.go @@ -150,7 +150,7 @@ func downloadAsync(ctx context.Context, imageURL string, opts DownloadOptions) ( return nil, h, err } - b := asyncbuffer.New(res.Body, opts.DownloadFinished) + b := asyncbuffer.New(res.Body, int(res.ContentLength), opts.DownloadFinished) format, err := imagetype.Detect(b.Reader()) if err != nil {