diff --git a/src/deepali/losses/functional.py b/src/deepali/losses/functional.py
index 0bff564..aa00444 100644
--- a/src/deepali/losses/functional.py
+++ b/src/deepali/losses/functional.py
@@ -1038,11 +1038,11 @@ def mi_loss(
     # Random image samples, optionally weighted by mask
     if sample_ratio is not None:
-        if num_samples is not None:
-            raise ValueError("mi_loss() 'num_samples' and 'sample_ratio' are mutually exclusive")
         if sample_ratio <= 0 or sample_ratio > 1:
             raise ValueError("mi_loss() 'sample_ratio' must be in open-closed interval (0, 1]")
-        num_samples = max(1, int(sample_ratio * target.shape[2:].numel()))
+        if num_samples is None:
+            num_samples = target.shape[2:].numel()
+        num_samples = min(max(1, int(sample_ratio * target.shape[2:].numel())), num_samples)
     if num_samples is not None:
         input, target = rand_sample([input, target], num_samples, mask=mask, replacement=True)
     elif mask is not None:
@@ -1066,12 +1066,12 @@ def parzen_window_fn(x: Tensor) -> Tensor:
     # calculate joint histogram
     hist_joint = pw_input.bmm(pw_target.transpose(1, 2))  # (N, #bins, #bins)
-    hist_norm = hist_joint.flatten(start_dim=1, end_dim=-1).sum(dim=1) + 1e-5
+    hist_norm = hist_joint.flatten(start_dim=1, end_dim=-1).sum(dim=1).add_(1e-5)
     # joint and marginal distributions
     p_joint = hist_joint / hist_norm.view(-1, 1, 1)  # (N, #bins, #bins) / (N, 1, 1)
-    p_input = torch.sum(p_joint, dim=2)
-    p_target = torch.sum(p_joint, dim=1)
+    p_input = p_joint.sum(dim=2)
+    p_target = p_joint.sum(dim=1)
     # calculate entropy
     ent_input = -torch.sum(p_input * torch.log(p_input + 1e-5), dim=1)  # (N,1)
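
Illustration only, not part of the patch: a minimal standalone sketch of how the revised mi_loss() sampling logic resolves the effective number of random samples once 'num_samples' and 'sample_ratio' are no longer mutually exclusive. The helper name resolve_num_samples is hypothetical; in the patch this logic is inlined in mi_loss().

# Hypothetical helper mirroring the patched mi_loss() sampling logic.
from typing import Optional

import torch


def resolve_num_samples(
    target: torch.Tensor,
    num_samples: Optional[int] = None,
    sample_ratio: Optional[float] = None,
) -> Optional[int]:
    if sample_ratio is not None:
        if sample_ratio <= 0 or sample_ratio > 1:
            raise ValueError("'sample_ratio' must be in open-closed interval (0, 1]")
        # Default cap is the total number of spatial points; otherwise use the caller-provided cap.
        if num_samples is None:
            num_samples = target.shape[2:].numel()
        # Take the requested fraction of spatial points, clamped to at least one sample
        # and at most the cap.
        num_samples = min(max(1, int(sample_ratio * target.shape[2:].numel())), num_samples)
    return num_samples


# Example: a (N, C, X, Y, Z) image with 20 * 20 * 20 = 8000 spatial points.
target = torch.zeros(1, 1, 20, 20, 20)
assert resolve_num_samples(target, sample_ratio=0.1) == 800          # 10% of 8000
assert resolve_num_samples(target, num_samples=500, sample_ratio=0.1) == 500  # capped by num_samples
assert resolve_num_samples(target, num_samples=500) == 500           # no ratio: value passes through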