Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/SCAI-BIO/syndat
Browse files Browse the repository at this point in the history
  • Loading branch information
tiadams committed Sep 27, 2024
2 parents 95915bc + a5076cd commit 63ad745
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 57 deletions.
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ import syndat
real = pd.read_csv("real.csv")
synthetic = pd.read_csv("synthetic.csv")

jsd = syndat.quality.jsd(real, synthetic)
auc = syndat.quality.auc(real, synthetic)
norm = syndat.quality.correlation(real, synthetic)
# How similar are the statistical distributions of real and synthetic features
distribution_similarity_score = syndat.scores.distribution(real, synthetic)

# How hard is it for a classifier to discriminate real and synthetic data
discrimination_score = syndat.scores.discrimination(real, synthetic)

# How well are pairwise feature correlations preserved
correlation_score = syndat.scores.correlation(real, synthetic)
```

Scores are defined in a range of 0-100, with a higher score corresponding to better data fidelity.

## Visualization

Visualize real vs. synthetic data distributions and summary statistics for each feature:
Expand Down
4 changes: 2 additions & 2 deletions syndat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from syndat import domain
from syndat import quality
from syndat import visualization
from syndat import scores
from syndat import visualization
63 changes: 42 additions & 21 deletions syndat/quality.py → syndat/scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@

from syndat.domain import AggregationMethod

logger = logging.getLogger(__name__)


def auc(real: pandas.DataFrame, synthetic: pandas.DataFrame, n_folds=5,
drop_na_threshold=0.9, score: bool = True) -> float:
"""
Computes the Differentiation Complexity Score / ROC AUC score of a classifier trained to differentiate between real
Computes the Discrimination Complexity Score / ROC AUC score of a classifier trained to differentiate between real
and synthetic data.
:param real: The real data.
Expand All @@ -27,18 +29,43 @@ def auc(real: pandas.DataFrame, synthetic: pandas.DataFrame, n_folds=5,
:param score: Return result in a normalized score in [0,100]. Default is True.
:return: Differentiation Complexity Score / AUC ROC Score
"""

warnings.warn(
"old_function is deprecated and will be removed in a future version. Please use discrimination_score instead.",
"auc is deprecated and will be removed in a future version. Please use discrimination instead.",
DeprecationWarning,
stacklevel=2
)
return discrimination_score(real, synthetic, n_folds=n_folds, drop_na_threshold=drop_na_threshold, score=score)
return discrimination(real, synthetic, n_folds=n_folds, drop_na_threshold=drop_na_threshold, score=score)


def discrimination_score(real: pandas.DataFrame, synthetic: pandas.DataFrame, n_folds=5,
drop_na_threshold=0.9, score: bool = True) -> float:
def jsd(real: pd.DataFrame, synthetic: pd.DataFrame, aggregate_results: bool = True,
aggregation_method: AggregationMethod = AggregationMethod.AVERAGE, score: bool = True,
n_unique_threshold=10) -> Union[List[float], float]:
"""
Computes the Differentiation Complexity Score / ROC AUC score of a classifier trained to differentiate between real
Computes the feature distribution similarity using the Jensen-Shannon distance of real and synthetic data.
:param real: The real data.
:param synthetic: The synthetic data.
:param aggregate_results: Compute a single aggregated score for all features. Default is True.
:param aggregation_method: How the scores are aggregated. Default is using the median of all feature scores.
:param score: Return result in a normalized score in [0,100]. Default is True.
:param n_unique_threshold: Threshold to determine at which number of unique values bins will span over several
values.
:return: Distribution Similarity / JSD
"""

warnings.warn(
"auc is deprecated and will be removed in a future version. Please use discrimination instead.",
DeprecationWarning,
stacklevel=2
)
return distribution(real, synthetic, aggregate_results, aggregation_method, score, n_unique_threshold)


def discrimination(real: pandas.DataFrame, synthetic: pandas.DataFrame, n_folds=5,
drop_na_threshold=0.9, score: bool = True) -> float:
"""
Computes the Discrimination Complexity Score / ROC AUC score of a classifier trained to differentiate between real
and synthetic data.
:param real: The real data.
Expand All @@ -52,10 +79,10 @@ def discrimination_score(real: pandas.DataFrame, synthetic: pandas.DataFrame, n_
real_filtered, synthetic_filtered = __filter_rows_with_common_categories(real, synthetic)
# check for missing values in real data
real_clean = real_filtered.dropna(thresh=int(drop_na_threshold * len(real_filtered)), axis=1)
logging.info(f'Dropped {real_clean.shape[1] - real_clean.shape[1]} '
logger.info(f'Dropped {real_clean.shape[1] - real_clean.shape[1]} '
f'due to high missingness (threshold is {drop_na_threshold}).')
real_clean = real_clean.dropna()
logging.info(f'Removed {len(real) - len(real_clean)} entries due to missing values.')
logger.info(f'Removed {len(real) - len(real_clean)} entries due to missing values.')
# assert that both real and synthetic have same columns
synthetic_clean = synthetic_filtered[real_clean.columns]
# one-hot-encode categorical columns
Expand All @@ -74,9 +101,9 @@ def discrimination_score(real: pandas.DataFrame, synthetic: pandas.DataFrame, n_
return auc_score


def jsd(real: pd.DataFrame, synthetic: pd.DataFrame, aggregate_results: bool = True,
aggregation_method: AggregationMethod = AggregationMethod.AVERAGE, score: bool = True,
n_unique_threshold=10) -> Union[List[float], float]:
def distribution(real: pd.DataFrame, synthetic: pd.DataFrame, aggregate_results: bool = True,
aggregation_method: AggregationMethod = AggregationMethod.AVERAGE, score: bool = True,
n_unique_threshold=10) -> Union[List[float], float]:
"""
Computes the feature distribution similarity using the Jensen-Shannon distance of real and synthetic data.
Expand All @@ -100,7 +127,7 @@ def jsd(real: pd.DataFrame, synthetic: pd.DataFrame, aggregate_results: bool = T
col_dtype_real = real[col].dtype
col_dtype_synthetic = synthetic[col].dtype
if col_dtype_real != col_dtype_synthetic:
logging.warning(f'Real data at col {col} is dtype {col_dtype_real} but synthetic is {col_dtype_synthetic}. '
logger.warning(f'Real data at col {col} is dtype {col_dtype_real} but synthetic is {col_dtype_synthetic}. '
f'Evaluation will be done based on the assumed data type of the real data.')
synthetic[col] = synthetic[col].astype(col_dtype_real)
# categorical column
Expand Down Expand Up @@ -216,21 +243,15 @@ def __filter_rows_with_common_categories(real: pd.DataFrame, synthetic: pd.DataF
synthetic_categorical_cols = synthetic.select_dtypes(include=['object', 'category']).columns
# Identify common categorical columns
common_categorical_cols = set(real_categorical_cols) & set(synthetic_categorical_cols)
if not common_categorical_cols:
logging.warning("No common categorical columns found. Correlation will be computed on numeric data only.")
# Filter rows with common categories in each column
for col in common_categorical_cols:
real_categories = set(real[col].unique())
synthetic_categories = set(synthetic[col].unique())
common_categories = real_categories & synthetic_categories
if len(real_categories - common_categories) > 0:
logging.warning(
f"Categories {real_categories - common_categories} in column '{col}' "
f"are in real data but not in synthetic data and will be excluded.")
if len(synthetic_categories - common_categories) > 0:
logging.warning(
f"Categories {synthetic_categories - common_categories} in column '{col}' "
f"are in synthetic data but not in real data and will be excluded.")
logger.warning(
f"Categories {real_categories - common_categories} in column '{col}' are in real data but not in "
f"synthetic data. They will not be considered in the score computation.")
# Filter rows to keep only common categories
real = real[real[col].isin(common_categories)]
synthetic = synthetic[synthetic[col].isin(common_categories)]
Expand Down
12 changes: 6 additions & 6 deletions tests/test_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,41 +43,41 @@ def preprocess_categorical_data(self, real_data, synthetic_data):
return real_data_encoded, synthetic_data_encoded

def test_auc_score(self):
auc_score = syndat.quality.auc(self.real_data, self.synthetic_data)
auc_score = syndat.scores.discrimination(self.real_data, self.synthetic_data)
self.assertTrue(isinstance(auc_score, float))
self.assertGreaterEqual(auc_score, 0.0)
self.assertLessEqual(auc_score, 100.0)

def test_auc_score_normalized(self):
auc_score = syndat.quality.auc(self.real_data, self.synthetic_data)
auc_score = syndat.scores.discrimination(self.real_data, self.synthetic_data)
self.assertTrue(isinstance(auc_score, float))
self.assertGreaterEqual(auc_score, 0.0)
self.assertLessEqual(auc_score, 100.0)

def test_auc_score_with_missing_values(self):
# Introduce missing values in real data
self.real_data.iloc[::10, 0] = np.nan # 10% missing data
auc_score = syndat.quality.auc(self.real_data, self.synthetic_data)
auc_score = syndat.scores.discrimination(self.real_data, self.synthetic_data)
self.assertTrue(isinstance(auc_score, float))
self.assertGreaterEqual(auc_score, 0.0)
self.assertLessEqual(auc_score, 100.0)

def test_auc_score_with_missing_values_drop_col(self):
# Introduce missing values in real data
self.real_data.iloc[::2, 0] = np.nan # 50% missing data -> col drop
auc_score = syndat.quality.auc(self.real_data, self.synthetic_data)
auc_score = syndat.scores.discrimination(self.real_data, self.synthetic_data)
self.assertTrue(isinstance(auc_score, float))
self.assertGreaterEqual(auc_score, 0.0)
self.assertLessEqual(auc_score, 100.0)

def test_auc_score_with_custom_folds(self):
auc_score = syndat.quality.auc(self.real_data, self.synthetic_data, n_folds=5)
auc_score = syndat.scores.discrimination(self.real_data, self.synthetic_data, n_folds=5)
self.assertTrue(isinstance(auc_score, float))
self.assertGreaterEqual(auc_score, 0.0)
self.assertLessEqual(auc_score, 100.0)

def test_auc_score_with_categorical_data(self):
auc_score = syndat.quality.auc(self.real_data_cat, self.synthetic_data_cat)
auc_score = syndat.scores.discrimination(self.real_data_cat, self.synthetic_data_cat)
self.assertTrue(isinstance(auc_score, float))
self.assertGreaterEqual(auc_score, 0.0)
self.assertLessEqual(auc_score, 100.0)
2 changes: 1 addition & 1 deletion tests/test_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pandas as pd

from syndat.quality import correlation
from syndat.scores import correlation


class TestCorrelation(unittest.TestCase):
Expand Down
48 changes: 24 additions & 24 deletions tests/test_jsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def test_jsd_zero_int64(self):
'feature1': [6, 7, 8, 9, 10],
'feature2': [15, 16, 17, 18, 19]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertEqual(0, jsd)
distribution = syndat.scores.distribution(real, synthetic)
self.assertEqual(0, distribution)

def test_jsd_zero_int64_float64(self):
synthetic = pd.DataFrame({
Expand All @@ -30,8 +30,8 @@ def test_jsd_zero_int64_float64(self):
'feature1': [6, 7, 8, 9, 10],
'feature2': [0.6, 0.7, 0.8, 0.9, 1.0]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertEqual(jsd, 0)
distribution = syndat.scores.distribution(real, synthetic)
self.assertEqual(distribution, 0)

def test_jsd_perfect_int64(self):
synthetic = pd.DataFrame({
Expand All @@ -43,8 +43,8 @@ def test_jsd_perfect_int64(self):
'feature1': [1, 2, 1, 2, 3],
'feature2': [11, 12, 13, 14, 15]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertEqual(jsd, 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertEqual(distribution, 100)

def test_jsd_perfect_int64_and_float64(self):
synthetic = pd.DataFrame({
Expand All @@ -56,8 +56,8 @@ def test_jsd_perfect_int64_and_float64(self):
'feature1': [1, 2, 1, 2, 3],
'feature2': [0.1, 0.2, 0.3, 0.4, 0.5]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertEqual(jsd, 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertEqual(distribution, 100)

def test_jsd_different_col_types(self):
synthetic = pd.DataFrame({
Expand All @@ -69,7 +69,7 @@ def test_jsd_different_col_types(self):
'feature1': [1.2, 2.1, 1.1, 2.1, 3.1],
'feature2': [1, 2, 3, 4, 5]
})
jsd = syndat.quality.jsd(real, synthetic, score=False)
distribution = syndat.scores.distribution(real, synthetic, score=False)

def test_jsd_negative_int64(self):
synthetic = pd.DataFrame({
Expand All @@ -81,7 +81,7 @@ def test_jsd_negative_int64(self):
'feature1': [-1, 2, 3, 4, 5],
'feature2': [1, 2, 3, 4, 5]
})
jsd = syndat.quality.jsd(real, synthetic)
distribution = syndat.scores.distribution(real, synthetic)

def test_jsd_single_outlier(self):
synthetic = pd.DataFrame({
Expand All @@ -93,8 +93,8 @@ def test_jsd_single_outlier(self):
'feature1': [1, 1, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9],
'feature2': [1, 1, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 100],
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertTrue(jsd < 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertTrue(distribution < 100)

def test_jsd_categorical_equal(self):
synthetic = pd.DataFrame({
Expand All @@ -106,8 +106,8 @@ def test_jsd_categorical_equal(self):
'feature1': ['A', 'B', 'A', 'B', 'C'],
'feature2': ['X', 'Y', 'Y', 'X', 'Z']
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertEqual(jsd, 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertEqual(distribution, 100)

def test_jsd_categorical_different(self):
synthetic = pd.DataFrame({
Expand All @@ -119,8 +119,8 @@ def test_jsd_categorical_different(self):
'feature1': ['A', 'B', 'A', 'B', 'D'],
'feature2': ['X', 'Y', 'Z', 'X', 'W']
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertTrue(jsd < 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertTrue(distribution < 100)

def test_jsd_categorical_mixed(self):
synthetic = pd.DataFrame({
Expand All @@ -132,8 +132,8 @@ def test_jsd_categorical_mixed(self):
'feature1': ['A', 'B', 'C', 'F', 'G'],
'feature2': [1.0, 2.0, 3.0, 6.0, 7.0]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertTrue(jsd < 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertTrue(distribution < 100)

def test_jsd_categorical_with_numerical(self):
synthetic = pd.DataFrame({
Expand All @@ -145,8 +145,8 @@ def test_jsd_categorical_with_numerical(self):
'feature1': ['A', 'B', 'C', 'A', 'D'],
'feature2': [1.0, 2.0, 3.0, 4.0, 6.0]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertTrue(jsd < 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertTrue(distribution < 100)

def test_jsd_categorical_with_nan(self):
synthetic = pd.DataFrame({
Expand All @@ -158,8 +158,8 @@ def test_jsd_categorical_with_nan(self):
'feature1': ['A', 'B', 'C', 'D', None],
'feature2': [1.0, 2.0, None, 4.0, 5.0]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertTrue(jsd < 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertTrue(distribution < 100)

def test_jsd_categorical_all_nan(self):
synthetic = pd.DataFrame({
Expand All @@ -171,5 +171,5 @@ def test_jsd_categorical_all_nan(self):
'feature1': [None, None, None, None, None],
'feature2': [None, None, None, None, None]
})
jsd = syndat.quality.jsd(real, synthetic)
self.assertEqual(jsd, 100)
distribution = syndat.scores.distribution(real, synthetic)
self.assertEqual(distribution, 100)

0 comments on commit 63ad745

Please sign in to comment.