Skip to content

Commit

Permalink
Add convenience function to plot attribution scores
Browse files Browse the repository at this point in the history
Also had to import numpy, although it's unclear why this is suddenly failing. It's unrelated.

Signed-off-by: Peter Goetz <[email protected]>
  • Loading branch information
petergtz committed Nov 16, 2022
1 parent 7568d95 commit a067703
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -320,23 +320,7 @@
},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def bar_plot_with_uncertainty(median_attribs, uncertainty_attribs, ylabel='Attribution Score', figsize=(8, 3), bwidth=0.8, xticks=None, xticks_rotation=90):\n",
" fig, ax = plt.subplots(figsize=figsize)\n",
" yerr_plus = [uncertainty_attribs[node][1] - median_attribs[node] for node in median_attribs.keys()]\n",
" yerr_minus = [median_attribs[node] - uncertainty_attribs[node][0] for node in median_attribs.keys()]\n",
" plt.bar(median_attribs.keys(), median_attribs.values(), yerr=np.array([yerr_minus, yerr_plus]), ecolor='#1E88E5', color='#ff0d57', width=bwidth)\n",
" plt.xticks(rotation=xticks_rotation)\n",
" plt.ylabel(ylabel)\n",
" ax.spines['right'].set_visible(False)\n",
" ax.spines['top'].set_visible(False)\n",
" if xticks:\n",
" plt.xticks(list(median_attribs.keys()), xticks)\n",
" plt.show()\n",
"\n",
"bar_plot_with_uncertainty(median_attribs, uncertainty_attribs)"
"gcm.util.bar_plot(median_attribs, uncertainty_attribs, 'Attribution Score')"
]
},
{
Expand Down Expand Up @@ -433,6 +417,8 @@
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"median_attribs, uncertainty_attribs = gcm.confidence_intervals(\n",
" lambda : gcm.distribution_change(causal_model,\n",
" normal_data.sample(frac=0.6),\n",
Expand All @@ -441,7 +427,7 @@
" difference_estimation_func=lambda x, y: np.mean(y) - np.mean(x)),\n",
" num_bootstrap_resamples = 10)\n",
"\n",
"bar_plot_with_uncertainty(median_attribs, uncertainty_attribs)"
"gcm.util.bar_plot(median_attribs, uncertainty_attribs, 'Attribution Score')"
]
},
{
Expand Down Expand Up @@ -504,13 +490,13 @@
"outputs": [],
"source": [
"avg_website_latency_before = outlier_data.mean().to_dict()['Website']\n",
"bar_plot_with_uncertainty(dict(before=avg_website_latency_before, after=median_mean_latencies['Website']),\n",
" dict(before=np.array([avg_website_latency_before, avg_website_latency_before]), after=uncertainty_mean_latencies['Website']),\n",
" ylabel='Avg. Website Latency',\n",
" figsize=(3, 2),\n",
" bwidth=0.4,\n",
" xticks=['Before', 'After'],\n",
" xticks_rotation=45)"
"gcm.util.bar_plot(dict(before=avg_website_latency_before, after=median_mean_latencies['Website']),\n",
" dict(before=np.array([avg_website_latency_before, avg_website_latency_before]), after=uncertainty_mean_latencies['Website']),\n",
" ylabel='Avg. Website Latency',\n",
" figure_size=(3, 2),\n",
" bar_width=0.4,\n",
" xticks=['Before', 'After'],\n",
" xticks_rotation=45)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion dowhy/gcm/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .plotting import plot, plot_adjacency_matrix
from .plotting import bar_plot, plot, plot_adjacency_matrix
49 changes: 49 additions & 0 deletions dowhy/gcm/util/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
import pandas as pd
from matplotlib import pyplot
from networkx.drawing import nx_pydot
Expand Down Expand Up @@ -166,3 +167,51 @@ def _plot_causal_graph_networkx(

def _calc_arrow_width(strength: float, max_strength: float):
return 0.2 + 4.0 * float(abs(strength)) / float(max_strength)


def bar_plot(
values: Dict[str, float],
uncertainties: Optional[Dict[str, Tuple[float, float]]] = None,
ylabel: str = "",
filename: Optional[str] = None,
display_plot: bool = True,
figure_size: Optional[List[int]] = None,
bar_width: float = 0.8,
xticks: List[str] = None,
xticks_rotation: int = 90,
) -> None:
"""Convenience function to make a bar plot of the given values with uncertainty bars, if provided. Useful for all
kinds of attribution results (including confidence intervals).
:param values: A dictionary where the keys are the labels and the values are the values to be plotted.
:param uncertainties: A dictionary of attributes to be added to the error bars.
:param ylabel: The label for the y-axis.
:param filename: An optional filename if the output should be plotted into a file.
:param display_plot: Optionally specify if the plot should be displayed or not (default to True).
:param figure_size: The size of the figure to be plotted.
:param bar_width: The width of the bars.
:param xticks: Explicitly specify the labels for the bars on the x-axis.
:param xticks_rotation: Specify the rotation of the labels on the x-axis.
"""
if uncertainties is None:
uncertainties = {node: [values[node], values[node]] for node in values}

figure, ax = pyplot.subplots(figsize=figure_size)
ci_plus = [uncertainties[node][1] - values[node] for node in values.keys()]
ci_minus = [values[node] - uncertainties[node][0] for node in values.keys()]
yerr = np.array([ci_minus, ci_plus])
yerr[abs(yerr) < 10**-7] = 0
pyplot.bar(values.keys(), values.values(), yerr=yerr, ecolor="#1E88E5", color="#ff0d57", width=bar_width)
pyplot.ylabel(ylabel)
pyplot.xticks(rotation=xticks_rotation)

ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
if xticks:
pyplot.xticks(list(uncertainties.keys()), xticks)

if display_plot:
pyplot.show()

if filename is not None:
figure.savefig(filename)

0 comments on commit a067703

Please sign in to comment.