Skip to content

Commit

Permalink
Make WavWriter inherit from SampleProcessorBase.
Browse files Browse the repository at this point in the history
  - Currently it is implemented to consume all samples; any samples into it cannot be retrieved.
  - Move `PushFrame` to `PushFrameDerived`, which was recently migrated to be compatible with the base-related behavior.
  - `SampleProcessorBase::Flush` must only be called once.
    - Take this opportunity to close the header, which allows the convenience of being able to read back the file, before the writer is destroyed.
    - For convenience and usability, if `SampleProcessorBase::Flush` was not called, then the destructor still finalizes the header (old behavior).
  - Some member variables are now represented at the base level.
  - b/389111191: Now that this is a `SampleProcessorBase` we can start migrating to hold a generic post-processor to be used after rendering.

PiperOrigin-RevId: 715022459
  • Loading branch information
jwcullen committed Jan 13, 2025
1 parent 0c88d6e commit fc78e50
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 34 deletions.
1 change: 1 addition & 0 deletions iamf/cli/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ cc_library(
srcs = ["wav_writer.cc"],
hdrs = ["wav_writer.h"],
deps = [
":sample_processor_base",
"//iamf/common:macros",
"//iamf/common:obu_util",
"@com_google_absl//absl/base:nullability",
Expand Down
50 changes: 47 additions & 3 deletions iamf/cli/tests/wav_writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,23 @@ TEST(PushFrame, WriteChannelWithTooFewSamplesFails) {
EXPECT_FALSE(wav_writer->PushFrame(absl::MakeConstSpan(samples)).ok());
}

TEST(PushFrame, ConsumesInputSamples) {
auto wav_writer =
WavWriter::Create(GetAndCleanupOutputFileName(".wav"), kNumChannels,
kSampleRateHz, kBitDepth16, kMaxInputSamplesPerFrame);
ASSERT_NE(wav_writer, nullptr);
constexpr int kNumSamples = 3;
const std::vector<std::vector<int32_t>> samples(
kNumSamples, std::vector<int32_t>(kNumChannels, kSampleValue));

EXPECT_THAT(wav_writer->PushFrame(absl::MakeConstSpan(samples)), IsOk());

// The writer consumes all input samples, so
// `SampleProcessorBase::GetOutputSamplesAsSpan` will always return an empty
// span.
EXPECT_TRUE(wav_writer->GetOutputSamplesAsSpan().empty());
}

TEST(DeprecatedWritePcmSamples,
DeprecatedWriteIntegerSamplesSucceedsWithoutHeader) {
auto wav_writer =
Expand Down Expand Up @@ -311,15 +328,17 @@ TEST(WavWriterTest,
EXPECT_EQ(wav_reader.buffers_, kExpectedSamples);
}

TEST(WavWriterTest, Output16BitWavFileHasCorrectDataWithPushFrame) {
TEST(WavWriterTest,
Output16BitWavFileHasCorrectDataWithPushFrameAfterDestruction) {
const std::string output_file_path(GetAndCleanupOutputFileName(".wav"));
const std::vector<std::vector<int32_t>> kExpectedSamples = {
{0x01000000}, {0x03020000}, {0x05040000},
{0x07060000}, {0x09080000}, {0x0b0a0000}};
constexpr int kNumSamplesPerFrame = 6;
{
// Create the writer in a small scope. It should be destroyed before
// checking the results.
// Create the writer in a small scope. The user can safely omit the call the
// `Flush()` method, but then they must wait until the writer is destroyed,
// to read the finalized header.
auto wav_writer =
WavWriter::Create(output_file_path, kNumChannels, kSampleRateHz,
kBitDepth16, kMaxInputSamplesPerFrame);
Expand All @@ -335,6 +354,31 @@ TEST(WavWriterTest, Output16BitWavFileHasCorrectDataWithPushFrame) {
EXPECT_EQ(wav_reader.buffers_, kExpectedSamples);
}

TEST(WavWriterTest, Output16BitWavFileHasCorrectDataWithPushFrameAfterFlush) {
const std::string output_file_path(GetAndCleanupOutputFileName(".wav"));
const std::vector<std::vector<int32_t>> kExpectedSamples = {
{0x01000000}, {0x03020000}, {0x05040000},
{0x07060000}, {0x09080000}, {0x0b0a0000}};
constexpr int kNumSamplesPerFrame = 6;

auto wav_writer =
WavWriter::Create(output_file_path, kNumChannels, kSampleRateHz,
kBitDepth16, kMaxInputSamplesPerFrame);
ASSERT_NE(wav_writer, nullptr);
EXPECT_THAT(wav_writer->PushFrame(absl::MakeConstSpan(kExpectedSamples)),
IsOk());
// Instead of waiting for the destructor to call `Flush()`, the user can call
// `Flush()` explicitly, to signal the wav header (including the total number
// of samples) to be finalized.
EXPECT_THAT(wav_writer->Flush(), IsOk());

auto wav_reader =
CreateWavReaderExpectOk(output_file_path, kNumSamplesPerFrame);
EXPECT_EQ(wav_reader.remaining_samples(), kNumSamplesPerFrame);
EXPECT_TRUE(wav_reader.ReadFrame());
EXPECT_EQ(wav_reader.buffers_, kExpectedSamples);
}

TEST(WavWriterTest,
Output24BitWavFileHasCorrectDataWithDeprecatedWritePcmSamples) {
const std::string output_file_path(GetAndCleanupOutputFileName(".wav"));
Expand Down
55 changes: 38 additions & 17 deletions iamf/cli/wav_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "iamf/cli/sample_processor_base.h"
#include "iamf/common/macros.h"
#include "iamf/common/obu_util.h"
#include "src/dsp/write_wav_file.h"
Expand All @@ -40,6 +41,10 @@ namespace {
constexpr int kAudioToTactileResultFailure = 0;
constexpr int kAudioToTactileResultSuccess = 1;

// This class is implemented to consume all samples without producing output
// samples.
constexpr size_t kMaxOutputSamplesPerFrame = 0;

// Write samples for all channels.
absl::Status WriteSamplesInternal(absl::Nullable<FILE*> file,
size_t num_channels, int bit_depth,
Expand Down Expand Up @@ -130,6 +135,24 @@ absl::Status WriteSamplesInternal(absl::Nullable<FILE*> file,
write_sample_result));
}

void MaybeFinalizeFile(size_t sample_rate_hz, size_t num_channels,
auto& wav_header_writer, FILE*& file,
size_t& total_samples_written) {
if (file == nullptr) {
return;
}

// Finalize the temporary header based on the total number of samples written
// and close the file.
if (wav_header_writer) {
std::fseek(file, 0, SEEK_SET);
wav_header_writer(file, total_samples_written, sample_rate_hz,
num_channels);
}
std::fclose(file);
file = nullptr;
}

} // namespace

std::unique_ptr<WavWriter> WavWriter::Create(const std::string& wav_filename,
Expand Down Expand Up @@ -181,21 +204,12 @@ std::unique_ptr<WavWriter> WavWriter::Create(const std::string& wav_filename,
}

WavWriter::~WavWriter() {
if (file_ == nullptr) {
return;
}

// Finalize the temporary header based on the total number of samples written
// and close the file.
if (wav_header_writer_) {
std::fseek(file_, 0, SEEK_SET);
wav_header_writer_(file_, total_samples_written_, sample_rate_hz_,
num_channels_);
}
std::fclose(file_);
// Finalize the header, in case the user did not call `Flush()`.
MaybeFinalizeFile(sample_rate_hz_, num_channels_, wav_header_writer_, file_,
total_samples_written_);
}

absl::Status WavWriter::PushFrame(
absl::Status WavWriter::PushFrameDerived(
absl::Span<const std::vector<int32_t>> time_channel_samples) {
// Flatten down the serialized PCM for compatibility with the internal
// `WriteSamplesInternal` function.
Expand All @@ -221,13 +235,20 @@ absl::Status WavWriter::PushFrame(
}

return WriteSamplesInternal(file_, num_channels_, bit_depth_,
num_samples_per_frame_, samples_as_pcm,
max_input_samples_per_frame_, samples_as_pcm,
total_samples_written_);
}

absl::Status WavWriter::FlushDerived() {
// No more samples are coming, finalize the header and close the file.
MaybeFinalizeFile(sample_rate_hz_, num_channels_, wav_header_writer_, file_,
total_samples_written_);
return absl::OkStatus();
}

absl::Status WavWriter::WritePcmSamples(const std::vector<uint8_t>& buffer) {
return WriteSamplesInternal(file_, num_channels_, bit_depth_,
num_samples_per_frame_, buffer,
max_input_samples_per_frame_, buffer,
total_samples_written_);
}

Expand All @@ -241,10 +262,10 @@ WavWriter::WavWriter(const std::string& filename_to_remove, int num_channels,
int sample_rate_hz, int bit_depth,
size_t num_samples_per_frame, FILE* file,
WavHeaderWriter wav_header_writer)
: num_channels_(num_channels),
: SampleProcessorBase(num_samples_per_frame, num_channels,
kMaxOutputSamplesPerFrame),
sample_rate_hz_(sample_rate_hz),
bit_depth_(bit_depth),
num_samples_per_frame_(num_samples_per_frame),
total_samples_written_(0),
file_(file),
filename_to_remove_(filename_to_remove),
Expand Down
41 changes: 27 additions & 14 deletions iamf/cli/wav_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
#include "absl/functional/any_invocable.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "iamf/cli/sample_processor_base.h"

namespace iamf_tools {

class WavWriter {
/*!\brief Write samples to a wav (or pcm) file, then consumes the samples. */
class WavWriter : public SampleProcessorBase {
public:
/*!\brief Factory function to create a `WavWriter`.
*
Expand Down Expand Up @@ -54,16 +56,6 @@ class WavWriter {
/*!\brief Returns the bit-depth.*/
int bit_depth() const { return bit_depth_; }

/*!\brief Writes samples to the wav file.
*
* There must be the same number of samples for each channel.
*
* \param time_channel_samples Samples to push arranged in (time, channel).
* \return `absl::OkStatus()` on success. A specific status on failure.
*/
absl::Status PushFrame(
absl::Span<const std::vector<int32_t>> time_channel_samples);

/*!\brief Writes samples to the wav file.
*
* There must be an integer number of samples and the number of samples %
Expand All @@ -74,7 +66,7 @@ class WavWriter {
* padding.
* \return `absl::OkStatus()` on success. A specific status on failure.
*/
[[deprecated(("Use `PushFrame` instead."))]]
[[deprecated(("Use `SampleProcessorBase::PushFrame` instead."))]]
absl::Status WritePcmSamples(const std::vector<uint8_t>& buffer);

/*!\brief Aborts the write process and deletes the wav file.*/
Expand All @@ -99,10 +91,31 @@ class WavWriter {
int sample_rate_hz, int bit_depth, size_t num_samples_per_frame,
FILE* file, WavHeaderWriter wav_header_writer);

const size_t num_channels_;
/*!\brief Writes samples to the wav file and consumes them.
*
* Since the samples are consumed, the
* `SampleProcessorBase::GetOutputSamplesAsSpan` method will always return an
* empty span.
*
* There must be the same number of samples for each channel.
*
* \param time_channel_samples Samples to push arranged in (time, channel).
* \return `absl::OkStatus()` on success. A specific status on failure.
*/
absl::Status PushFrameDerived(
absl::Span<const std::vector<int32_t>> time_channel_samples) override;

/*!\brief Signals that no more samples will be pushed.
*
* After calling `Flush()`, it is invalid to call `PushFrame()`
* or `Flush()` again.
*
* \return `absl::OkStatus()` on success. A specific status on failure.
*/
absl::Status FlushDerived() override;

const size_t sample_rate_hz_;
const size_t bit_depth_;
const size_t num_samples_per_frame_;
size_t total_samples_written_;
FILE* file_;
const std::string filename_to_remove_;
Expand Down

0 comments on commit fc78e50

Please sign in to comment.