Skip to content

Commit

Permalink
[losses] Modify mi_loss() to enable num_samples to limit max when usi…
Browse files Browse the repository at this point in the history
…ng sample_ratio
  • Loading branch information
aschuh-hf committed Oct 23, 2023
1 parent d3957ce commit b2a9322
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/deepali/losses/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,11 +1038,11 @@ def mi_loss(

# Random image samples, optionally weighted by mask
if sample_ratio is not None:
if num_samples is not None:
raise ValueError("mi_loss() 'num_samples' and 'sample_ratio' are mutually exclusive")
if sample_ratio <= 0 or sample_ratio > 1:
raise ValueError("mi_loss() 'sample_ratio' must be in open-closed interval (0, 1]")
num_samples = max(1, int(sample_ratio * target.shape[2:].numel()))
if num_samples is None:
num_samples = target.shape[2:].numel()
num_samples = min(max(1, int(sample_ratio * target.shape[2:].numel())), num_samples)
if num_samples is not None:
input, target = rand_sample([input, target], num_samples, mask=mask, replacement=True)
elif mask is not None:
Expand All @@ -1066,12 +1066,12 @@ def parzen_window_fn(x: Tensor) -> Tensor:

# calculate joint histogram
hist_joint = pw_input.bmm(pw_target.transpose(1, 2)) # (N, #bins, #bins)
hist_norm = hist_joint.flatten(start_dim=1, end_dim=-1).sum(dim=1) + 1e-5
hist_norm = hist_joint.flatten(start_dim=1, end_dim=-1).sum(dim=1).add_(1e-5)

# joint and marginal distributions
p_joint = hist_joint / hist_norm.view(-1, 1, 1) # (N, #bins, #bins) / (N, 1, 1)
p_input = torch.sum(p_joint, dim=2)
p_target = torch.sum(p_joint, dim=1)
p_input = p_joint.sum(dim=2)
p_target = p_joint.sum(dim=1)

# calculate entropy
ent_input = -torch.sum(p_input * torch.log(p_input + 1e-5), dim=1) # (N,1)
Expand Down

0 comments on commit b2a9322

Please sign in to comment.