Skip to content

Commit

Permalink
Least_square_direct: cache input instead
Browse files Browse the repository at this point in the history
When computing `prox_L2(v)`, don't cache the Fourier-transformed signal
`v`; instead, cache the Fourier-transformed raw image `b` from the L2-norm
term `|| x - b ||_2^2`.
  • Loading branch information
antonysigma committed Jan 7, 2025
1 parent d114511 commit 8eeb8bd
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 32 deletions.
10 changes: 5 additions & 5 deletions proximal/halide/interface/prox_L2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@ prox_L2_glue(const array_float_t input, const float theta, const array_float_t o
auto input_buf = getHalideBuffer<3>(input);
auto offset_buf = getHalideBuffer<3>(offset);
auto freq_diag_buf = getHalideComplexBuffer<4>(freq_diag);
auto output_buf = getHalideBuffer<3>(output, true);
auto output_buf = getHalideBuffer<3>(output, false);

// A hash to denote when the Fourier transformed offset signal should be
// recomputed. TODO(Antony): It is more practical to marshal the
// proximal.Problem instance hash from Python runtime to here.
const auto offset_buf_hash = reinterpret_cast<uintptr_t>(offset_buf.begin());
const auto input_buf_hash = reinterpret_cast<uintptr_t>(input_buf.begin());

const auto success = least_square_direct(input_buf, theta, offset_buf, freq_diag_buf,
offset_buf_hash, output_buf);
const auto has_error = least_square_direct(input_buf, theta, offset_buf, freq_diag_buf,
input_buf_hash, output_buf);
output_buf.copy_to_host();
return success;
return has_error;
}

} // namespace proximal
Expand Down
70 changes: 43 additions & 27 deletions proximal/halide/src/least_square_direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@ using Halide::BoundaryConditions::repeat_image;

namespace {

enum component_t : int32_t { RE = 0, IM = 1 };

class least_square_direct_gen : public Generator<least_square_direct_gen> {
public:
Input<Buffer<float, 3>> input{"input"};
Input<float> rho{"rho"};
Input<Buffer<float, 3>> offset{"offset"};
Input<Buffer<float, 4>> freq_diag{"freq_diag"};
Input<uint64_t> offset_hash{"offset_hash"};
Input<uint64_t> input_hash{"offset_hash"};

Output<Buffer<float, 3>> output{"output"};

Expand All @@ -43,16 +45,16 @@ class least_square_direct_gen : public Generator<least_square_direct_gen> {
// Forward DFT
Fft2dDesc fwd_desc{};
fwd_desc.parallel = true;
f_input = fft2d_r2c(shifted_input, W, H, target, fwd_desc);
f_input_tmp = fft2d_r2c(shifted_input, W, H, target, fwd_desc);

f_offset_tmp = fft2d_r2c(padded_offset, W, H, target, fwd_desc);
f_input_cached(x, y, k, c) =
memoize_tag(mux(k, {f_input_tmp(x, y, c).re(), f_input_tmp(x, y, c).im()}), input_hash);
f_input(x, y, c) = {f_input_cached(x, y, RE, c), f_input_cached(x, y, IM, c)};

f_offset_cached(x, y, k, c) = memoize_tag(
mux(k, {re(f_offset_tmp(x, y, c)), im(f_offset_tmp(x, y, c))}), offset_hash);
f_offset(x, y, c) = {f_offset_cached(x, y, 0, c), f_offset_cached(x, y, 1, c)};
f_offset = fft2d_r2c(padded_offset, W, H, target, fwd_desc);

// Cast freq_diag from pair<float> to std::complex<float>
diag(x, y, c) = {freq_diag(0, x, y, c), freq_diag(1, x, y, c)};
diag(x, y, c) = {freq_diag(RE, x, y, c), freq_diag(IM, x, y, c)};

if (ignore_offset) {
weighted_average(x, y, c) =
Expand All @@ -64,12 +66,11 @@ class least_square_direct_gen : public Generator<least_square_direct_gen> {

// Inverse DFT
Fft2dDesc inv_desc{};
inv_desc.gain = 1.0f / (W * H);
inv_desc.parallel = true;

inversed = fft2d_c2r(weighted_average, W, H, target, inv_desc);

// Crop the image by the user-defined dimensions
output(x, y, c) = inversed(x, y, c) / (W * H);
output = inversed;
}

void validateDimensions() {
Expand Down Expand Up @@ -125,38 +126,53 @@ class least_square_direct_gen : public Generator<least_square_direct_gen> {
.parallel(y)
.parallel(c);

f_input //
.compute_root()
.parallel(c)
;

if(!ignore_offset) {
f_offset_cached //
.compute_root()
.bound(k, 0, 2)
.unroll(k)
f_offset
.compute_root() //
.parallel(c);

padded_offset
.compute_root() //
.vectorize(x, vfloat)
.parallel(y)
.parallel(c)
.memoize();

f_offset_tmp.compute_at(f_offset_cached, c);
.parallel(c);
}

// Cache the Fourier-transformed input
f_input_cached //
.compute_root()
.bound(k, 0, 2)
.unroll(k)
.vectorize(x, vfloat)
.parallel(y)
.parallel(c)
.memoize();

// Compute FFT only when cache is evicted
f_input_tmp.compute_at(f_input_cached, c);

shifted_input
.compute_at(f_input_cached, Var::outermost()) //
.vectorize(x, vfloat)
.parallel(y)
.parallel(c);
}

private:
// coordinates in the space domain
Var x{"x"}, y{"y"}, c{"c"}, k{"k"};

Func padded_input{"padded_input"};
Func cyclic_input{"cyclic_input"};
Func shifted_input{"shifted_input"};
Func padded_offset{"padded_offset"};
ComplexFunc diag{"diag"};
Func f_offset_cached{"f_offset_cached"};

ComplexFunc f_offset;

Func f_input_cached{"f_input_cached"};
ComplexFunc f_input_tmp;
ComplexFunc f_input{"f_input"};
ComplexFunc f_offset_tmp{"f_offset_tmp"};
ComplexFunc f_offset{"f_offset"};

ComplexFunc weighted_average{"weighted_average"};
Func inversed{"inversed"};
};
Expand Down

0 comments on commit 8eeb8bd

Please sign in to comment.