What do you think about also accepting a `callable` for the `background_data` parameter? That way we could provide a default random generator based on the input, with something like:
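```python
from typing import Callable, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module

# Explainer and Convolution are the project's own classes (imports omitted).
class DeepLiftShap(Explainer):
    SUPPORTED_MODULES = [Convolution]

    def __init__(
        self,
        model: Module,
        background_data: Optional[Union[Tensor, Callable[[Tensor, Tensor], Tensor]]] = None,
    ):
        if background_data is None:
            # Default: uniform random background shaped like the input x.
            background_data = lambda x, y: torch.rand_like(x)
        elif torch.is_tensor(background_data):
            # Bind to a new name; closing over background_data would return the lambda itself.
            fixed_background = background_data
            background_data = lambda x, y: fixed_background
        self.background_data = background_data
```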
Originally posted by @enver1323 in #151 (comment)
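The normalization step in the proposal can also be sketched as a standalone helper that does not depend on the project's `Explainer` base class. This is a minimal sketch under assumptions: `resolve_background` and `BackgroundFn` are hypothetical names, and it dispatches on Python's built-in `callable()` instead of `torch.is_tensor`, so any non-callable value (a tensor, or a plain list in the usage below) is treated as a fixed background. In `DeepLiftShap`, `default` would be `lambda x, y: torch.rand_like(x)`.

```python
from typing import Any, Callable, Optional

# Hypothetical signature of a background generator: (inputs x, targets y) -> background.
BackgroundFn = Callable[[Any, Any], Any]


def resolve_background(background_data: Optional[Any], default: BackgroundFn) -> BackgroundFn:
    """Hypothetical helper: turn None, a fixed value, or a callable into a callable."""
    if background_data is None:
        # No background given: fall back to the default generator.
        return default
    if callable(background_data):
        # Already a generator; use it unchanged.
        return background_data
    # Fixed background (e.g. a tensor): ignore the inputs and always return it.
    return lambda x, y: background_data


# Usage with a deterministic plain-Python default so the output is predictable.
zeros = lambda x, y: [0.0] * len(x)
print(resolve_background(None, zeros)([1, 2, 3], None))  # [0.0, 0.0, 0.0]
print(resolve_background([5, 5], zeros)([1, 2], None))   # [5, 5]
```

Because the fixed value is a function parameter that is never rebound, the returned lambda does not hit the late-binding pitfall of reassigning `background_data` inside `__init__`. One caveat of dispatching on `callable()`: an `nn.Module` is callable, so it would be used as a generator rather than as fixed data.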