
Commit

Reduce RAM usage of spatial_error_propagation by optimizing subsampling of pairwise distances (#672)
rhugonnet authored Dec 14, 2024
1 parent 29a8cf4 commit 55bb902
Showing 2 changed files with 30 additions and 40 deletions.
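In short, the commit stops building the full pairwise distance matrix over all points (previously obtained with pdist/squareform) and instead computes distances between all points and only the random subsample, using cdist. Below is a minimal standalone sketch of that idea with illustrative sizes; the variable names mirror the diff, but this is not the xdem function itself:

import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(42)
n, subsample = 100_000, 100
coords = rng.normal(size=(n, 2))

# Random subset of points used for one side of the double covariance sum
rand_points = rng.choice(n, size=subsample, replace=False)
sub_coords = coords[rand_points, :]

# Shape (n, subsample) instead of (n, n): memory scales with the subsample size,
# not with the square of the number of points
pds_matrix = cdist(coords, sub_coords, "euclidean")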
34 changes: 19 additions & 15 deletions tests/test_spatialstats.py
@@ -16,7 +16,7 @@
import xdem
from xdem import examples
from xdem._typing import NDArrayf
from xdem.spatialstats import EmpiricalVariogramKArgs, nmad
from xdem.spatialstats import EmpiricalVariogramKArgs, neff_hugonnet_approx, nmad

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -1066,31 +1066,23 @@ def test_neff_exact_and_approx_hugonnet(self) -> None:
)

# Check that the function runs with default parameters
# t0 = time.time()
neff_exact = xdem.spatialstats.neff_exact(
coords=coords, errors=errors, params_variogram_model=params_variogram_model
)
# t1 = time.time()

# Check that the non-vectorized version gives the same result
neff_exact_nv = xdem.spatialstats.neff_exact(
coords=coords, errors=errors, params_variogram_model=params_variogram_model, vectorized=False
)
# t2 = time.time()
assert neff_exact == pytest.approx(neff_exact_nv, rel=0.001)

# Check that the vectorized version is faster (vectorized for about 250 points here)
# assert (t1 - t0) < (t2 - t1)

# Check that the approximation function runs with default parameters, sampling 100 out of 250 samples
# t3 = time.time()
neff_approx = xdem.spatialstats.neff_hugonnet_approx(
neff_approx = neff_hugonnet_approx(
coords=coords, errors=errors, params_variogram_model=params_variogram_model, subsample=100, random_state=42
)
# t4 = time.time()

# Check that the non-vectorized version gives the same result, sampling 100 out of 250 samples
neff_approx_nv = xdem.spatialstats.neff_hugonnet_approx(
neff_approx_nv = neff_hugonnet_approx(
coords=coords,
errors=errors,
params_variogram_model=params_variogram_model,
@@ -1101,13 +1093,25 @@ def test_neff_exact_and_approx_hugonnet(self) -> None:

assert neff_approx == pytest.approx(neff_approx_nv, rel=0.001)

# Check that the approximation version is faster within 30% error
# TODO: find a more robust way to test time for CI
# assert (t4 - t3) < (t1 - t0)

# Check that the approximation is about the same as the original estimate within 10%
assert neff_approx == pytest.approx(neff_exact, rel=0.1)

# Check that the approximation works even on large dataset without creating memory errors
# 100,000 points squared (pairwise) should use more than 64GB of RAM without subsample
rng = np.random.default_rng(42)
coords = rng.normal(size=(100000, 2))
errors = rng.normal(size=(100000))
# This uses a subsample of 100, so should run just fine despite the large size
neff_approx_nv = neff_hugonnet_approx(
coords=coords,
errors=errors,
params_variogram_model=params_variogram_model,
subsample=100,
vectorized=True,
random_state=42,
)
assert neff_approx_nv is not None

def test_number_effective_samples(self) -> None:
"""Test that the wrapper function for neff functions behaves correctly and that output values are robust"""

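The RAM figure quoted in the new test comment is easy to verify: with float64 values (8 bytes each), the matrices involved have roughly the following footprints. This is a quick back-of-the-envelope check, not part of the commit:

n, s = 100_000, 100
condensed_gb = n * (n - 1) // 2 * 8 / 1e9   # ~40 GB for pdist's condensed distances
square_gb = n * n * 8 / 1e9                 # ~80 GB once expanded with squareform
subsampled_mb = n * s * 8 / 1e6             # ~80 MB for the (n, subsample) cdist matrix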
36 changes: 11 additions & 25 deletions xdem/spatialstats.py
@@ -42,7 +42,7 @@
from scipy.interpolate import RegularGridInterpolator, griddata
from scipy.optimize import curve_fit
from scipy.signal import fftconvolve
from scipy.spatial.distance import pdist, squareform
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.stats import binned_statistic, binned_statistic_2d, binned_statistic_dd
from skimage.draw import disk

@@ -2248,48 +2248,34 @@ def neff_hugonnet_approx(
# Get spatial correlation function from variogram parameters
rho = correlation_from_variogram(params_variogram_model)

# Get number of points and pairwise distance compacted matrix from scipy.pdist
# Get number of points and pairwise distance matrix from scipy.cdist
n = len(coords)
pds = pdist(coords)

# At maximum, the number of subsamples has to be equal to number of points
subsample = min(subsample, n)

# Get random subset of points for one of the sums
rand_points = rng.choice(n, size=subsample, replace=False)

# Subsample coordinates in 1D before computing pairwise distances
sub_coords = coords[rand_points, :]
sub_errors = errors[rand_points]
pds_matrix = cdist(coords, sub_coords, "euclidean")

# Now we compute the double covariance sum
# Either using for-loop-version
if not vectorized:
var = 0.0
for ind_sub in range(subsample):
for j in range(n):

i = rand_points[ind_sub]
# For index calculation of the pairwise distance,
# see https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html
if i == j:
d = 0
elif i < j:
ind = n * i + j - ((i + 2) * (i + 1)) // 2
d = pds[ind]
else:
ind = n * j + i - ((j + 2) * (j + 1)) // 2
d = pds[ind]

for i in range(pds_matrix.shape[0]):
for j in range(pds_matrix.shape[1]):
d = pds_matrix[i, j]
var += rho(d) * errors[i] * errors[j] # type: ignore

# Or vectorized version
else:
# We subset the points used in one dimension, for errors and pairwise distances computed
errors_sub = errors[rand_points]
pds_matrix = squareform(pds)
pds_matrix_sub = pds_matrix[:, rand_points]
# Vectorized calculation
var = np.sum(
errors.reshape((-1, 1))
@ errors_sub.reshape((1, -1))
* rho(pds_matrix_sub.flatten()).reshape(pds_matrix_sub.shape)
errors.reshape((-1, 1)) @ sub_errors.reshape((1, -1)) * rho(pds_matrix.flatten()).reshape(pds_matrix.shape)
)

# The number of effective sample is the fraction of total sill by squared standard error
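For reference, the vectorized expression above is just the outer product of the full error vector and the subsampled error vector, weighted elementwise by the correlation of each pairwise distance. A small self-contained check of that equivalence on toy data; rho here is a placeholder correlation function, not the one xdem derives from the fitted variogram:

import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
n, s = 50, 10
coords = rng.normal(size=(n, 2))
errors = np.abs(rng.normal(size=n))
rand_points = rng.choice(n, size=s, replace=False)
sub_errors = errors[rand_points]
pds_matrix = cdist(coords, coords[rand_points, :], "euclidean")

def rho(d):
    # Placeholder correlation decaying with distance
    return np.exp(-d)

# Explicit double sum over all points x subsampled points
var_loop = sum(
    rho(pds_matrix[i, j]) * errors[i] * sub_errors[j]
    for i in range(n)
    for j in range(s)
)

# Vectorized equivalent using the outer product of the two error vectors
var_vec = np.sum(errors.reshape((-1, 1)) @ sub_errors.reshape((1, -1)) * rho(pds_matrix))

assert np.isclose(var_loop, var_vec)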
