Skip to content

Commit

Permalink
Merge pull request #60 from fzi-forschungszentrum-informatik/LEFTIST
Browse files Browse the repository at this point in the history
Leftist
  • Loading branch information
JHoelli authored Mar 6, 2024
2 parents b576faf + 3af5311 commit 7b675dd
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 118 deletions.
185 changes: 91 additions & 94 deletions TSInterpret/InterpretabilityModels/Saliency/SaliencyMethods_PTY.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
method: str = "GRAD",
mode: str = "time",
tsr: bool = True,
normalize:bool=True,
device: str = "cpu",
) -> None:
"""Initialization
Expand All @@ -59,9 +60,10 @@ def __init__(
method str: Saliency Methode to be used
mode str: Second dimension 'time'->`(1,time,feat)` or 'feat'->`(1,feat,time)`
"""
super().__init__(model, NumTimeSteps, NumFeatures, method, mode)
super().__init__(model, NumTimeSteps, NumFeatures, method, mode,normalize)
self.method = method
self.tsr = tsr
#self.normalize=normalize
if method == "GRAD":
self.Grad = Saliency(model)
elif method == "IG":
Expand Down Expand Up @@ -105,8 +107,11 @@ def explain(self, item: np.ndarray, labels: int, TSR=None, **kwargs):
idx = 0
item = np.array(item.tolist()) # , dtype=np.float64)
input = torch.from_numpy(item)

input = input.reshape(-1, self.NumTimeSteps, self.NumFeatures).to(self.device)
if self.mode=='feat':
input = np.swapaxes(input, -1, -2)
#if self.mode =='time':
# input = input.reshape(-1, self.NumTimeSteps, self.NumFeatures).to(self.device)

input = Variable(input, volatile=False, requires_grad=True)

batch_size = input.shape[0]
Expand All @@ -123,7 +128,7 @@ def explain(self, item: np.ndarray, labels: int, TSR=None, **kwargs):
)
# input = samples.reshape(-1, args.NumTimeSteps, args.NumFeatures).to(device)
if self.mode == "feat":
input = input.reshape(-1, self.NumFeatures, self.NumTimeSteps)
input = np.swapaxes(input, -1, -2)
if "baseline_single" in kwargs.keys():
baseline_single = kwargs["baseline_single"]
else:
Expand Down Expand Up @@ -217,13 +222,14 @@ def explain(self, item: np.ndarray, labels: int, TSR=None, **kwargs):
# print('TSR Saliency', TSR_saliency.shape)
return TSR_saliency
else:
# print('TSR', TSR)
# TODO attributions does not exist for SVS and Fo
rescaledGrad[
idx : idx + batch_size, :, :
] = self._givenAttGetRescaledSaliency(attributions)
# print('Rescaled', rescaledGrad.shape)
return rescaledGrad[0]
if self.normalize:

rescaledGrad[
idx : idx + batch_size, :, :
] = self._givenAttGetRescaledSaliency(attributions)
return rescaledGrad[0]
else:
return attributions.detach().numpy()[0]

def _getTwoStepRescaling(
self,
Expand All @@ -241,14 +247,10 @@ def _getTwoStepRescaling(
timeGrad = np.zeros((1, sequence_length))
inputGrad = np.zeros((input_size, 1))
newGrad = np.zeros((input_size, sequence_length))
# print("has Sliding Window", hasSliding_window_shapes)
if self.mode == "time":
newGrad = np.swapaxes(newGrad, -1, -2)
# print(input.shape)
# print('mode timw')
# print('inüut1',input)
# input = np.swapaxes(input,-1,-2)#.reshape(-1, sequence_length, input_size)
# print('inüut1',input)

#if self.mode == "time":
# newGrad = np.swapaxes(newGrad, -1, -2)


if hasBaseline is None:
ActualGrad = (
Expand Down Expand Up @@ -286,24 +288,21 @@ def _getTwoStepRescaling(
.data.cpu()
.numpy()
)
# if self.mode == "time":
# ActualGrad = ActualGrad.reshape(-1, input_size, sequence_length)

if self.mode == "time":
input = np.swapaxes(
input, -1, -2
) # input.reshape(-1, input_size, sequence_length)
)
for t in range(sequence_length):
newInput = input.clone()
# if newInput.shape[-1] == self.NumTimeSteps:
# print('A')

newInput[:, :, t] = assignment
# else:
# print('B')
# newInput[:, t,:] = assignment

if self.mode == "time":
newInput = np.swapaxes(
newInput, -1, -2
) # .reshape(-1, sequence_length, input_size)
)
if hasBaseline is None:
timeGrad_perTime = (
self.Grad.attribute(newInput, target=TestingLabel)
Expand All @@ -323,7 +322,7 @@ def _getTwoStepRescaling(
.numpy()
)
elif hasSliding_window_shapes is not None:
# print("HAS SLIDING WINDOW")

timeGrad_perTime = (
self.Grad.attribute(
newInput,
Expand All @@ -342,89 +341,87 @@ def _getTwoStepRescaling(
.data.cpu()
.numpy()
)
# import sys
# sys.exit(1)

timeGrad_perTime = np.absolute(ActualGrad - timeGrad_perTime)
if self.mode == "time":
timeGrad_perTime = np.swapaxes(timeGrad_perTime, -1, -2) # .reshape(
# -1, input_size, sequence_length
# )
timeGrad_perTime = np.swapaxes(timeGrad_perTime, -1, -2)
timeGrad[:, t] = np.sum(timeGrad_perTime)

timeContribution = preprocessing.minmax_scale(timeGrad, axis=1)
# print(timeContribution.shape)

meanTime = np.quantile(timeContribution, 0.55)

for t in range(sequence_length):
if timeContribution[0, t] > meanTime:
for c in range(input_size):
newInput = input.clone()
newInput[:, c, t] = assignment
if self.mode == "time":
newInput = np.swapaxes(
newInput, -1, -2
) # .reshape(-1, sequence_length, input_size)
if input_size>1:
for t in range(sequence_length):
print('TIME CONR',timeContribution[0, t])
if timeContribution[0, t] > meanTime:
for c in range(input_size):
newInput = input.clone()
newInput[:, c, t] = assignment
if self.mode == "time":
newInput = np.swapaxes(
newInput, -1, -2
) # .reshape(-1, sequence_length, input_size)

if hasBaseline is None:
inputGrad_perInput = (
self.Grad.attribute(newInput, target=TestingLabel)
.data.cpu()
.numpy()
)
else:
if hasFeatureMask is not None:
inputGrad_perInput = (
self.Grad.attribute(
newInput,
baselines=hasBaseline,
target=TestingLabel,
feature_mask=hasFeatureMask,
)
.data.cpu()
.numpy()
)
elif hasSliding_window_shapes is not None:
if hasBaseline is None:
inputGrad_perInput = (
self.Grad.attribute(
newInput,
sliding_window_shapes=hasSliding_window_shapes,
baselines=hasBaseline,
target=TestingLabel,
)
self.Grad.attribute(newInput, target=TestingLabel)
.data.cpu()
.numpy()
)
else:
inputGrad_perInput = (
self.Grad.attribute(
newInput, baselines=hasBaseline, target=TestingLabel
if hasFeatureMask is not None:
inputGrad_perInput = (
self.Grad.attribute(
newInput,
baselines=hasBaseline,
target=TestingLabel,
feature_mask=hasFeatureMask,
)
.data.cpu()
.numpy()
)
elif hasSliding_window_shapes is not None:
inputGrad_perInput = (
self.Grad.attribute(
newInput,
sliding_window_shapes=hasSliding_window_shapes,
baselines=hasBaseline,
target=TestingLabel,
)
.data.cpu()
.numpy()
)
else:
inputGrad_perInput = (
self.Grad.attribute(
newInput, baselines=hasBaseline, target=TestingLabel
)
.data.cpu()
.numpy()
)
.data.cpu()
.numpy()
)

inputGrad_perInput = np.absolute(ActualGrad - inputGrad_perInput)
inputGrad_perInput = np.swapaxes(
inputGrad_perInput, -1, -2
) # .reshape(
# -1, input_size, sequence_length
# )
inputGrad[c, :] = np.sum(inputGrad_perInput)
featureContribution = preprocessing.minmax_scale(inputGrad, axis=0)
inputGrad_perInput = np.absolute(ActualGrad - inputGrad_perInput)
inputGrad_perInput = np.swapaxes(
inputGrad_perInput, -1, -2
) # .reshape(
# -1, input_size, sequence_length
# )
inputGrad[c, :] = np.sum(inputGrad_perInput)
featureContribution = preprocessing.minmax_scale(inputGrad, axis=0)

else:
featureContribution = np.ones((input_size, 1)) * 0.1
# print('FC',featureContribution)
# newGrad = newGrad#.reshape(input_size, sequence_length)
if self.mode == "time":
# newGrad = newGrad.reshape(sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
for c in range(input_size):
newGrad[c, t] = timeContribution[0, t] * featureContribution[c, 0]
if self.mode == "time":
# newGrad = newGrad.reshape(sequence_length, input_size)
newGrad = np.swapaxes(newGrad, -1, -2)
else:
featureContribution = np.ones((input_size, 1)) * 0.1

for c in range(input_size):
newGrad[c, t] = timeContribution[0, t] * featureContribution[c, 0]

else:
newGrad=timeContribution

if self.mode == "time":
newGrad = np.swapaxes(newGrad, -1, -2)
return newGrad

def _givenAttGetRescaledSaliency(self, attributions, isTensor=True):
Expand Down
69 changes: 47 additions & 22 deletions TSInterpret/InterpretabilityModels/Saliency/Saliency_Base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
NumFeatures: int,
method: str = "GRAD",
mode: str = "time",
normalize:bool=True,
) -> None:
"""
Arguments:
Expand All @@ -35,11 +36,13 @@ def __init__(
NumFeatures int: number of features.
method str: Saliency Method to be used.
mode str: Second dimension 'time'->`(1,time,feat)` or 'feat'->`(1,feat,time)`.
normalize bool: Wheather or not to normalize the results
"""
super().__init__(model, mode)
self.NumTimeSteps = NumTimeSteps
self.NumFeatures = NumFeatures
self.method = method
self.normalize=normalize

def explain(self):
raise NotImplementedError("Don't use the base CF class directly")
Expand All @@ -56,6 +59,7 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
save str: Path to save figure.
"""
plt.style.use("classic")
print(self.normalize)
i = 0
if self.mode == "time":
print("time mode")
Expand All @@ -75,8 +79,8 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
cbar=True,
ax=ax011,
yticklabels=False,
vmin=0,
vmax=1,
#vmin=0,
#vmax=1,
)
elif len(item[0]) == 1:
# if only onedimensional input
Expand All @@ -85,15 +89,24 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
)
# cbar_ax = fig.add_axes([.91, .3, .03, .4])
axn012 = axn.twinx()
sns.heatmap(
exp.reshape(1, -1),
fmt="g",
cmap="viridis",
ax=axn,
yticklabels=False,
vmin=0,
vmax=1,
)
if self.normalize:
sns.heatmap(
exp.reshape(1, -1),
fmt="g",
cmap="viridis",
ax=axn,
yticklabels=False,
vmin=0,
vmax=1,
)
else:
sns.heatmap(
exp.reshape(1, -1),
fmt="g",
cmap="viridis",
ax=axn,
yticklabels=False,
)
sns.lineplot(
x=range(0, len(item[0][0].reshape(-1))),
y=item[0][0].flatten(),
Expand All @@ -114,18 +127,30 @@ def plot(self, item, exp, figsize=(6.4, 4.8), heatmap=False, save=None):
# ax012.append(ax011[i].twinx())
# ax011[i].set_facecolor("#440154FF")
axn012 = axn[i].twinx()
if self.normalize:

sns.heatmap(
exp[i].reshape(1, -1),
fmt="g",
cmap="viridis",
cbar=i == 0,
cbar_ax=None if i else cbar_ax,
ax=axn[i],
yticklabels=False,
vmin=0,
vmax=1,
)
else:
sns.heatmap(
exp[i].reshape(1, -1),
fmt="g",
cmap="viridis",
cbar=i == 0,
cbar_ax=None if i else cbar_ax,
ax=axn[i],
yticklabels=False,
)

sns.heatmap(
exp[i].reshape(1, -1),
fmt="g",
cmap="viridis",
cbar=i == 0,
cbar_ax=None if i else cbar_ax,
ax=axn[i],
yticklabels=False,
vmin=0,
vmax=1,
)
sns.lineplot(
x=range(0, len(channel.reshape(-1))),
y=channel.flatten(),
Expand Down
Loading

0 comments on commit 7b675dd

Please sign in to comment.