From 2812a90fd32877c7d6d725bb44e11626a28fa520 Mon Sep 17 00:00:00 2001
From: Tim Adams
Date: Mon, 11 Nov 2024 15:07:15 +0100
Subject: [PATCH] feat: add persistence to shap plots and add unit test

---
Review notes (text between the '---' separator and the diff is ignored by
'git am'):
- The new plot_shap_discrimination signature is annotated with
  pandas.DataFrame, matching the plot_correlations and
  plot_categorical_feature signatures shown in the hunk context, so the
  annotation does not depend on a 'pd' alias being imported in
  syndat/visualization.py.
- Unchanged context shows the function still writes a 'label' column into
  the caller's 'real' and 'synthetic' frames, and plt.show() is called
  unconditionally after the optional save; neither is altered here.

 syndat/visualization.py     | 25 ++++++++++++++++++++++---
 tests/test_visualization.py | 30 ++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_visualization.py

diff --git a/syndat/visualization.py b/syndat/visualization.py
index 8d7c42c..fe91c50 100644
--- a/syndat/visualization.py
+++ b/syndat/visualization.py
@@ -69,8 +69,18 @@ def plot_correlations(real: pandas.DataFrame, synthetic: pandas.DataFrame, store
         fig = ax.get_figure()
         fig.savefig(store_destination + "/" + names[idx] + '.png', bbox_inches="tight")
 
-def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame) -> None:
-    # Assuming 'real' and 'synthetic_no_dp' are your datasets and are pandas DataFrames
+def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame, save_path: str = None) -> None:
+    """
+    Generates a SHAP summary plot to illustrate the discrimination between real and synthetic datasets
+    using a Random Forest classifier.
+
+    :param real: The real data
+    :param synthetic: The synthetic data
+    :param save_path: Path to the file where the resulting plot should be saved. If None, the plot will not be saved.
+
+    :return: None
+    """
+    # Assuming 'real' and 'synthetic' are your datasets and are pandas DataFrames
     # Add a label column to each dataset
     real['label'] = 1
     synthetic['label'] = 0
@@ -99,7 +109,16 @@ def plot_shap_discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame
     shap_values = explainer.shap_values(X_test)
 
     # Plot SHAP summary
-    shap.summary_plot(shap_values[1], X_test)
+    plt.figure()
+    shap.summary_plot(shap_values[1], X_test, show=False)
+
+    # Save the plot if save_path is specified
+    if save_path:
+        plt.savefig(save_path, bbox_inches='tight')
+        print(f"Plot saved to {save_path}")
+
+    # Show the plot
+    plt.show()
 
 
 def plot_categorical_feature(feature: str, real_data: pandas.DataFrame, synthetic_data: pandas.DataFrame) -> None:
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
new file mode 100644
index 0000000..b69c754
--- /dev/null
+++ b/tests/test_visualization.py
@@ -0,0 +1,30 @@
+import unittest
+import pandas as pd
+import numpy as np
+import os
+
+from syndat import plot_shap_discrimination
+
+
+class TestPlotShapDiscrimination(unittest.TestCase):
+
+    def setUp(self):
+        # Create sample data for testing
+        self.real = pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"feature_{i}" for i in range(5)])
+        self.synthetic = pd.DataFrame(np.random.normal(size=(100, 5)), columns=[f"feature_{i}" for i in range(5)])
+
+        # Define the path where the plot will be temporarily saved
+        self.save_path = "test_shap_plot.png"
+
+    def test_plot_shap_discrimination(self):
+        # Call the function with test data and save_path
+        plot_shap_discrimination(self.real, self.synthetic, save_path=self.save_path)
+
+        # Check if the plot file was created
+        self.assertTrue(os.path.exists(self.save_path), "SHAP plot file was not created.")
+
+    def tearDown(self):
+        # Remove the file if it exists after the test
+        if os.path.exists(self.save_path):
+            os.remove(self.save_path)
+