Skip to content

Commit

Permalink
identity_streaming_fft: print PSNR for reconstruction
Browse files Browse the repository at this point in the history
  • Loading branch information
szabadka committed Feb 26, 2024
1 parent 6b78881 commit e647cd8
Showing 1 changed file with 34 additions and 9 deletions.
43 changes: 34 additions & 9 deletions speaker_experiments/identity_sliding_fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ double BarkFreq(double v) {
}
}

static constexpr int64_t kBlockSize = 32768;
static constexpr int64_t kBlockSize = 1 << 15;
static const int kHistorySize = (1 << 18);
static const int kHistoryMask = kHistorySize - 1;

Expand Down Expand Up @@ -168,6 +168,22 @@ class TaskExecutor {
std::atomic<size_t> next_task_{0};
};

double SquareError(const double* input_history, const double* output,
size_t num_channels, size_t total, size_t output_len) {
double res = 0.0;
for (size_t i = 0; i < output_len; ++i) {
int input_ix = i + total;
size_t histo_ix = num_channels * (input_ix & kHistoryMask);
for (size_t c = 0; c < num_channels; ++ c) {
double in = input_history[histo_ix + c];
double out = output[num_channels * i + c];
double diff = in - out;
res += diff * diff;
}
}
return res;
}

template <typename In, typename Out>
void Process(
In& input_stream, Out& output_stream,
Expand Down Expand Up @@ -197,13 +213,16 @@ void Process(
TaskExecutor pool(40, num_channels);

start_progress();
int64_t total = 0;
int64_t total_in = 0;
int64_t total_out = 0;
bool done = false;
double err = 0.0;
while (!done) {
int64_t read = input_stream.readf(input.data(), kBlockSize);
memset(input.data() + read, 0, input.size() - read);
size_t bytes_read = num_channels * read * sizeof(float);
memset(input.data() + bytes_read, 0, input.size() - bytes_read);
for (int i = 0; i < read; ++i) {
int input_ix = i + total;
int input_ix = i + total_in;
size_t histo_ix = num_channels * (input_ix & kHistoryMask);
for (size_t c = 0; c < num_channels; ++c) {
history[histo_ix + c] = input[num_channels * i + c];
Expand All @@ -213,11 +232,11 @@ void Process(
done = true;
read = max_delay;
}
int64_t output_len = total < max_delay ?
std::max<int64_t>(0, read - (max_delay - total)) :
int64_t output_len = total_in < max_delay ?
std::max<int64_t>(0, read - (max_delay - total_in)) :
read;

pool.Execute(kNumRotators, read, total, max_delay, history.data(),
pool.Execute(kNumRotators, read, total_in, max_delay, history.data(),
rot.data());

std::fill(output.begin(), output.end(), 0);
Expand All @@ -228,9 +247,15 @@ void Process(
}
}
output_stream.writef(output.data(), output_len);
total += read;
set_progress(total);
err += SquareError(history.data(), output.data(), num_channels, total_out,
output_len);
total_in += read;
total_out += output_len;
set_progress(total_in);
}
err /= total_out;
double psnr = -10.0 * std::log(err) / std::log(10.0);
fprintf(stderr, "MSE: %f PSNR: %f\n", err, psnr);
}

} // namespace
Expand Down

0 comments on commit e647cd8

Please sign in to comment.