diff --git a/TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_PTY.py b/TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_PTY.py
index b4705db..5f7b433 100644
--- a/TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_PTY.py
+++ b/TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_PTY.py
@@ -49,6 +49,7 @@ def __init__(
         method: str = "GRAD",
         mode: str = "time",
         tsr: bool = True,
+        normalize: bool = True,
         device: str = "cpu",
     ) -> None:
         """Initialization
@@ -59,9 +60,10 @@ def __init__(
             method str: Saliency Methode to be used
             mode str: Second dimension 'time'->`(1,time,feat)` or 'feat'->`(1,feat,time)`
         """
-        super().__init__(model, NumTimeSteps, NumFeatures, method, mode)
+        super().__init__(model, NumTimeSteps, NumFeatures, method, mode, normalize)
         self.method = method
         self.tsr = tsr
+        # self.normalize = normalize
         if method == "GRAD":
             self.Grad = Saliency(model)
         elif method == "IG":
@@ -105,8 +107,11 @@ def explain(self, item: np.ndarray, labels: int, TSR=None, **kwargs):
         idx = 0
         item = np.array(item.tolist())  # , dtype=np.float64)
         input = torch.from_numpy(item)
-
-        input = input.reshape(-1, self.NumTimeSteps, self.NumFeatures).to(self.device)
+        if self.mode == "feat":
+            input = np.swapaxes(input, -1, -2)
+        # if self.mode == "time":
+        #     input = input.reshape(-1, self.NumTimeSteps, self.NumFeatures).to(self.device)
+
         input = Variable(input, volatile=False, requires_grad=True)

         batch_size = input.shape[0]
@@ -123,7 +128,7 @@ def explain(self, item: np.ndarray, labels: int, TSR=None, **kwargs):
         )
         # input = samples.reshape(-1, args.NumTimeSteps, args.NumFeatures).to(device)
         if self.mode == "feat":
-            input = input.reshape(-1, self.NumFeatures, self.NumTimeSteps)
+            input = np.swapaxes(input, -1, -2)
         if "baseline_single" in kwargs.keys():
             baseline_single = kwargs["baseline_single"]
         else:
@@ -217,13 +222,14 @@ def explain(self, item: np.ndarray, labels: int, TSR=None, **kwargs):
             # print('TSR Saliency', TSR_saliency.shape)
             return TSR_saliency
         else:
-            # print('TSR', TSR)
-            # TODO attributions does not exist for SVS and Fo
-            rescaledGrad[
-                idx : idx + batch_size, :, :
-            ] = self._givenAttGetRescaledSaliency(attributions)
-            # print('Rescaled', rescaledGrad.shape)
-            return rescaledGrad[0]
+            if self.normalize:
+                rescaledGrad[
+                    idx : idx + batch_size, :, :
+                ] = self._givenAttGetRescaledSaliency(attributions)
+                return rescaledGrad[0]
+            else:
+                return attributions.detach().numpy()[0]

     def _getTwoStepRescaling(
         self,
@@ -241,14 +247,10 @@ def _getTwoStepRescaling(
         timeGrad = np.zeros((1, sequence_length))
         inputGrad = np.zeros((input_size, 1))
         newGrad = np.zeros((input_size, sequence_length))
-        # print("has Sliding Window", hasSliding_window_shapes)
-        if self.mode == "time":
-            newGrad = np.swapaxes(newGrad, -1, -2)
-            # print(input.shape)
-            # print('mode timw')
-            # print('inüut1',input)
-            # input = np.swapaxes(input,-1,-2)#.reshape(-1, sequence_length, input_size)
-            # print('inüut1',input)
+
+        # if self.mode == "time":
+        #     newGrad = np.swapaxes(newGrad, -1, -2)
+
         if hasBaseline is None:
             ActualGrad = (
@@ -286,24 +288,21 @@ def _getTwoStepRescaling(
             .data.cpu()
             .numpy()
         )
-        # if self.mode == "time":
-        #     ActualGrad = ActualGrad.reshape(-1, input_size, sequence_length)
+
         if self.mode == "time":
             input = np.swapaxes(
                 input, -1, -2
-            )  # input.reshape(-1, input_size, sequence_length)
+            )
         for t in range(sequence_length):
             newInput = input.clone()
-            # if newInput.shape[-1] == self.NumTimeSteps:
-            # print('A')
+
             newInput[:, :, t] = assignment
             # else:
-            # print('B')
-            # newInput[:, t,:] = assignment
+
             if self.mode == "time":
                 newInput = np.swapaxes(
                     newInput, -1, -2
-                )  # .reshape(-1, sequence_length, input_size)
+                )
             if hasBaseline is None:
                 timeGrad_perTime = (
                     self.Grad.attribute(newInput, target=TestingLabel)
@@ -323,7 +322,7 @@ def _getTwoStepRescaling(
                     .numpy()
                 )
             elif hasSliding_window_shapes is not None:
-                # print("HAS SLIDING WINDOW")
+
                 timeGrad_perTime = (
                     self.Grad.attribute(
                         newInput,
@@ -342,89 +341,87 @@ def _getTwoStepRescaling(
                     .data.cpu()
                     .numpy()
                 )
-            # import sys
-            # sys.exit(1)
             timeGrad_perTime = np.absolute(ActualGrad - timeGrad_perTime)
             if self.mode == "time":
-                timeGrad_perTime = np.swapaxes(timeGrad_perTime, -1, -2)  # .reshape(
-                #     -1, input_size, sequence_length
-                # )
+                timeGrad_perTime = np.swapaxes(timeGrad_perTime, -1, -2)
             timeGrad[:, t] = np.sum(timeGrad_perTime)
         timeContribution = preprocessing.minmax_scale(timeGrad, axis=1)
-        # print(timeContribution.shape)
+
         meanTime = np.quantile(timeContribution, 0.55)
-        for t in range(sequence_length):
-            if timeContribution[0, t] > meanTime:
-                for c in range(input_size):
-                    newInput = input.clone()
-                    newInput[:, c, t] = assignment
-                    if self.mode == "time":
-                        newInput = np.swapaxes(
-                            newInput, -1, -2
-                        )  # .reshape(-1, sequence_length, input_size)
-
-                    if hasBaseline is None:
-                        inputGrad_perInput = (
-                            self.Grad.attribute(newInput, target=TestingLabel)
-                            .data.cpu()
-                            .numpy()
-                        )
-                    else:
-                        if hasFeatureMask is not None:
-                            inputGrad_perInput = (
-                                self.Grad.attribute(
-                                    newInput,
-                                    baselines=hasBaseline,
-                                    target=TestingLabel,
-                                    feature_mask=hasFeatureMask,
-                                )
-                                .data.cpu()
-                                .numpy()
-                            )
-                        elif hasSliding_window_shapes is not None:
-                            inputGrad_perInput = (
-                                self.Grad.attribute(
-                                    newInput,
-                                    sliding_window_shapes=hasSliding_window_shapes,
-                                    baselines=hasBaseline,
-                                    target=TestingLabel,
-                                )
-                                .data.cpu()
-                                .numpy()
-                            )
-                        else:
-                            inputGrad_perInput = (
-                                self.Grad.attribute(
-                                    newInput, baselines=hasBaseline, target=TestingLabel
-                                )
-                                .data.cpu()
-                                .numpy()
-                            )
-                    inputGrad_perInput = np.absolute(ActualGrad - inputGrad_perInput)
-                    inputGrad_perInput = np.swapaxes(
-                        inputGrad_perInput, -1, -2
-                    )  # .reshape(
-                    #     -1, input_size, sequence_length
-                    # )
-                    inputGrad[c, :] = np.sum(inputGrad_perInput)
-                featureContribution = preprocessing.minmax_scale(inputGrad, axis=0)
-            else:
-                featureContribution = np.ones((input_size, 1)) * 0.1
-            # print('FC',featureContribution)
-            # newGrad = newGrad#.reshape(input_size, sequence_length)
-            if self.mode == "time":
-                # newGrad = newGrad.reshape(sequence_length, input_size)
-                newGrad = np.swapaxes(newGrad, -1, -2)
-            for c in range(input_size):
-                newGrad[c, t] = timeContribution[0, t] * featureContribution[c, 0]
-            if self.mode == "time":
-                # newGrad = newGrad.reshape(sequence_length, input_size)
-                newGrad = np.swapaxes(newGrad, -1, -2)
+        if input_size > 1:
+            for t in range(sequence_length):
+                print('TIME CONR', timeContribution[0, t])
+                if timeContribution[0, t] > meanTime:
+                    for c in range(input_size):
+                        newInput = input.clone()
+                        newInput[:, c, t] = assignment
+                        if self.mode == "time":
+                            newInput = np.swapaxes(
+                                newInput, -1, -2
+                            )  # .reshape(-1, sequence_length, input_size)
+
+                        if hasBaseline is None:
+                            inputGrad_perInput = (
+                                self.Grad.attribute(newInput, target=TestingLabel)
+                                .data.cpu()
+                                .numpy()
+                            )
+                        else:
+                            if hasFeatureMask is not None:
+                                inputGrad_perInput = (
+                                    self.Grad.attribute(
+                                        newInput,
+                                        baselines=hasBaseline,
+                                        target=TestingLabel,
+                                        feature_mask=hasFeatureMask,
+                                    )
+                                    .data.cpu()
+                                    .numpy()
+                                )
+                            elif hasSliding_window_shapes is not None:
+                                inputGrad_perInput = (
+                                    self.Grad.attribute(
+                                        newInput,
+                                        sliding_window_shapes=hasSliding_window_shapes,
+                                        baselines=hasBaseline,
+                                        target=TestingLabel,
+                                    )
+                                    .data.cpu()
+                                    .numpy()
+                                )
+                            else:
+                                inputGrad_perInput = (
+                                    self.Grad.attribute(
+                                        newInput, baselines=hasBaseline, target=TestingLabel
+                                    )
+                                    .data.cpu()
+                                    .numpy()
+                                )
+                        inputGrad_perInput = np.absolute(ActualGrad - inputGrad_perInput)
+                        inputGrad_perInput = np.swapaxes(
+                            inputGrad_perInput, -1, -2
+                        )  # .reshape(
+                        #     -1, input_size, sequence_length
+                        # )
+                        inputGrad[c, :] = np.sum(inputGrad_perInput)
+                    featureContribution = preprocessing.minmax_scale(inputGrad, axis=0)
+                else:
+                    featureContribution = np.ones((input_size, 1)) * 0.1
+
+                for c in range(input_size):
+                    newGrad[c, t] = timeContribution[0, t] * featureContribution[c, 0]
+        else:
+            newGrad = timeContribution
+
+        if self.mode == "time":
+            newGrad = np.swapaxes(newGrad, -1, -2)
         return newGrad

     def _givenAttGetRescaledSaliency(self, attributions, isTensor=True):
diff --git a/TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py b/TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py
index 57a151c..06b4c99 100644
--- a/TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py
+++ b/TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py
@@ -27,6 +27,7 @@ def __init__(
         NumFeatures: int,
         method: str = "GRAD",
         mode: str = "time",
+        normalize: bool = True,
     ) -> None:
         """
         Arguments:
@@ -35,11 +36,13 @@ def __init__(
             NumFeatures int: number of features.
            method str: Saliency Method to be used.
             mode str: Second dimension 'time'->`(1,time,feat)` or 'feat'->`(1,feat,time)`.
+            normalize bool: Whether or not to normalize the results.
         """
         super().__init__(model, mode)
         self.NumTimeSteps = NumTimeSteps
         self.NumFeatures = NumFeatures
         self.method = method
+        self.normalize = normalize

     def explain(self):
         raise NotImplementedError("Don't use the base CF class directly")
@@ -56,6 +59,7 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
            save str: Path to save figure.
        """
        plt.style.use("classic")
+       print(self.normalize)
        i = 0
        if self.mode == "time":
            print("time mode")
@@ -75,8 +79,8 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
                cbar=True,
                ax=ax011,
                yticklabels=False,
-               vmin=0,
-               vmax=1,
+               # vmin=0,
+               # vmax=1,
            )
        elif len(item[0]) == 1:
            # if only onedimensional input
@@ -85,15 +89,24 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
            )
            # cbar_ax = fig.add_axes([.91, .3, .03, .4])
            axn012 = axn.twinx()
-           sns.heatmap(
-               exp.reshape(1, -1),
-               fmt="g",
-               cmap="viridis",
-               ax=axn,
-               yticklabels=False,
-               vmin=0,
-               vmax=1,
-           )
+           if self.normalize:
+               sns.heatmap(
+                   exp.reshape(1, -1),
+                   fmt="g",
+                   cmap="viridis",
+                   ax=axn,
+                   yticklabels=False,
+                   vmin=0,
+                   vmax=1,
+               )
+           else:
+               sns.heatmap(
+                   exp.reshape(1, -1),
+                   fmt="g",
+                   cmap="viridis",
+                   ax=axn,
+                   yticklabels=False,
+               )
            sns.lineplot(
                x=range(0, len(item[0][0].reshape(-1))),
                y=item[0][0].flatten(),
@@ -114,18 +127,30 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
                # ax012.append(ax011[i].twinx())
                # ax011[i].set_facecolor("#440154FF")
                axn012 = axn[i].twinx()
-               sns.heatmap(
-                   exp[i].reshape(1, -1),
-                   fmt="g",
-                   cmap="viridis",
-                   cbar=i == 0,
-                   cbar_ax=None if i else cbar_ax,
-                   ax=axn[i],
-                   yticklabels=False,
-                   vmin=0,
-                   vmax=1,
-               )
+               if self.normalize:
+                   sns.heatmap(
+                       exp[i].reshape(1, -1),
+                       fmt="g",
+                       cmap="viridis",
+                       cbar=i == 0,
+                       cbar_ax=None if i else cbar_ax,
+                       ax=axn[i],
+                       yticklabels=False,
+                       vmin=0,
+                       vmax=1,
+                   )
+               else:
+                   sns.heatmap(
+                       exp[i].reshape(1, -1),
+                       fmt="g",
+                       cmap="viridis",
+                       cbar=i == 0,
+                       cbar_ax=None if i else cbar_ax,
+                       ax=axn[i],
+                       yticklabels=False,
+                   )
                sns.lineplot(
                    x=range(0, len(channel.reshape(-1))),
                    y=channel.flatten(),
diff --git a/TSInterpret/InterpretabilityModels/Saliency/TSR.py b/TSInterpret/InterpretabilityModels/Saliency/TSR.py
index 5c03304..3bb0180 100644
--- a/TSInterpret/InterpretabilityModels/Saliency/TSR.py
+++ b/TSInterpret/InterpretabilityModels/Saliency/TSR.py
@@ -12,7 +12,7 @@ class TSR:
    """

    def __new__(
-        self, model, NumTimeSteps, NumFeatures, method="GRAD", mode="time", device="cpu"
+        self, model, NumTimeSteps, NumFeatures, method="GRAD", mode="time", device="cpu", normalize=True, tsr=True
    ):
        """Initialization
        Arguments:
@@ -30,11 +30,13 @@ def __new__(
                method=method,
                mode=mode,
                device=device,
+               normalize=normalize,
+               tsr=tsr,
            )
        elif isinstance(model, tensorflow.keras.Model):
            return Saliency_TF(
-               model, NumTimeSteps, NumFeatures, method=method, mode=mode
+               model, NumTimeSteps, NumFeatures, method=method, mode=mode, tsr=tsr
            )
        else:
            raise NotImplementedError(