Add code & update readme + workflows
tiadams committed Dec 21, 2023
1 parent 2d8f2b2 commit 91259f6
Showing 10 changed files with 346 additions and 1 deletion.
40 changes: 40 additions & 0 deletions .github/workflows/python-package.yaml
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python package

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.9", "3.10", "3.11"]

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install flake8 pytest
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        pytest
43 changes: 43 additions & 0 deletions .github/workflows/python-publish.yaml
@@ -0,0 +1,43 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Update package version
      run: |
        VERSION=${{ github.event.release.tag_name }}
        sed -i "s/version='.*'/version='${VERSION}'/" setup.py
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
44 changes: 43 additions & 1 deletion README.md
@@ -1,2 +1,44 @@
-# syndat
+# Syndat
Synthetic data quality evaluation & visualization

# Installation

Install via pip:

```bash
pip install syndat
```

# Usage

## Quality metrics

Compute data quality metrics by comparing real and synthetic data in terms of their separation complexity
(discriminator AUC), distribution similarity (Jensen-Shannon distance), or pairwise feature correlations
(correlation norm score):

```python
import pandas as pd
from syndat import quality

real = pd.read_csv("real.csv")
synthetic = pd.read_csv("synthetic.csv")

jsd = quality.get_jsd(real, synthetic)
auc = quality.get_auc(real, synthetic)
norm = quality.get_norm_score(real, synthetic)
```

## Visualization

Visualize real vs. synthetic data distributions and summary statistics for each feature:

```python
import pandas as pd
from syndat import visualization

real = pd.read_csv("real.csv")
synthetic = pd.read_csv("synthetic.csv")

visualization.plot_distributions(real, synthetic, store_destination="results/plots")
```
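
## Outlier detection

Flag synthetic records that fall outside the real data distribution (a minimal sketch based on `quality.get_outliers` from this commit; the CSV file name is a placeholder):

```python
import pandas as pd
from syndat import quality

synthetic = pd.read_csv("synthetic.csv")

# row indices of synthetic records flagged by an isolation forest (the default mode)
outlier_indices = quality.get_outliers(synthetic)

# anomaly scores instead of indices; higher values mean more anomalous records
scores = quality.get_outliers(synthetic, anomaly_score=True)
```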

8 changes: 8 additions & 0 deletions requirements.txt
@@ -0,0 +1,8 @@
pandas~=2.1.4
numpy~=1.26.2
scipy~=1.11.4
scikit-learn~=1.3.2
matplotlib~=3.8.2
plotly~=5.18.0
seaborn~=0.13.0
setuptools==69.0.2
12 changes: 12 additions & 0 deletions setup.py
@@ -0,0 +1,12 @@
from setuptools import setup

setup(
    name='syndat',
    version='0.0.1',
    packages=['syndat'],
    url='https://github.com/SCAI-BIO/syndat',
    license='CC BY-NC-ND 4.0',
    author='Tim Adams',
    author_email='[email protected]',
    description='A library for evaluation & visualization of synthetic data'
)
Empty file added syndat/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions syndat/domain.py
@@ -0,0 +1,35 @@
from enum import Enum


class ColumnType:
    """Schema description of a single data column."""

    def __init__(self, name, datatype, options, minval, maxval):
        self.name = name
        self.datatype = datatype
        self.options = options
        self.minval = minval
        self.maxval = maxval


class ColumnConstraint:
    """Value constraint (numeric range or category) for a single data column."""

    def __init__(self, name, minval, maxval, category):
        self.name = name
        self.minval = minval
        self.maxval = maxval
        self.category = category


class OutlierPredictionMode(Enum):
    isolation_forest = "isolation_forest"
    local_outlier_factor = "local_outlier_factor"


class NaNHandlingStrategy(Enum):
    accept_imbalance = "accept_imbalance"
    sample_random = "sample_random"
    sample_closest = "sample_closest"
    encode_nan = "encode_nan"


class AggregationMethod(Enum):
    AVERAGE = "average"
    MEDIAN = "median"
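
The enums above parameterize the quality functions added in this commit (see `syndat/quality.py` below); `ColumnType`, `ColumnConstraint`, and `NaNHandlingStrategy` are not yet referenced elsewhere in this changeset. A minimal usage sketch, with `real.csv` and `synthetic.csv` as placeholder inputs:

```python
import pandas as pd
from syndat import quality
from syndat.domain import AggregationMethod, OutlierPredictionMode

real = pd.read_csv("real.csv")
synthetic = pd.read_csv("synthetic.csv")

# median instead of average aggregation of the per-column JSD values
jsd_median = quality.get_jsd(real, synthetic, aggregation_method=AggregationMethod.MEDIAN)

# per-column JSD values as a dict, without aggregation
jsd_per_column = quality.get_jsd(real, synthetic, aggregate_results=False)

# local outlier factor instead of the default isolation forest
outliers = quality.get_outliers(synthetic, mode=OutlierPredictionMode.local_outlier_factor)
```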
82 changes: 82 additions & 0 deletions syndat/quality.py
@@ -0,0 +1,82 @@
import numpy as np
import pandas as pd
import scipy.spatial.distance

from sklearn import ensemble, neighbors
from sklearn.model_selection import cross_val_score

from syndat.domain import OutlierPredictionMode, AggregationMethod


def get_auc(real: pd.DataFrame, synthetic: pd.DataFrame, n_folds=10):
    """Cross-validated ROC AUC of a random forest trained to discriminate real from synthetic records."""
    x = pd.concat([real, synthetic])
    y = np.concatenate((np.zeros(real.shape[0]), np.ones(synthetic.shape[0])), axis=None)
    rfc = ensemble.RandomForestClassifier()
    return np.average(cross_val_score(rfc, x, y, cv=n_folds, scoring='roc_auc'))


def get_jsd(real: pd.DataFrame, synthetic: pd.DataFrame, aggregate_results: bool = True,
            aggregation_method: AggregationMethod = AggregationMethod.AVERAGE):
    """Jensen-Shannon distance between the real and synthetic distribution of each column."""
    jsd_dict = {}
    for col in real:
        # delete empty cells
        real_wo_missing = real[col].dropna()
        synthetic_wo_missing = synthetic[col].dropna()
        # binning
        if np.sum(real_wo_missing.values) % 1 == 0 and np.sum(synthetic_wo_missing.values) % 1 == 0:
            # categorical column: bin by integer value
            real_binned = np.bincount(real_wo_missing.astype(int))
            virtual_binned = np.bincount(synthetic_wo_missing.astype(int))
        else:
            # continuous column: get optimal bin edges from the real data
            bin_edges = np.histogram_bin_edges(real_wo_missing, bins='auto')
            real_binned = np.bincount(np.digitize(real_wo_missing, bin_edges))
            virtual_binned = np.bincount(np.digitize(synthetic_wo_missing, bin_edges))
        # one array might be shorter here than the other, e.g. if real patients contain the categorical
        # encoding 0-3, but virtual patients only contain 0-2
        # in this case -> fill missing bins with zeros
        if len(real_binned) != len(virtual_binned):
            padding_size = np.abs(len(real_binned) - len(virtual_binned))
            if len(real_binned) > len(virtual_binned):
                virtual_binned = np.pad(virtual_binned, (0, padding_size))
            else:
                real_binned = np.pad(real_binned, (0, padding_size))
        # compute jsd
        jsd = scipy.spatial.distance.jensenshannon(real_binned, virtual_binned)
        jsd_dict[col] = jsd
    if aggregate_results and aggregation_method == AggregationMethod.AVERAGE:
        return np.mean(np.array(list(jsd_dict.values())))
    elif aggregate_results and aggregation_method == AggregationMethod.MEDIAN:
        return np.median(np.array(list(jsd_dict.values())))
    else:
        return jsd_dict


def get_norm_score(real: pd.DataFrame, synthetic: pd.DataFrame):
    """Frobenius norm of the difference of the correlation matrices, relative to the real data's norm."""
    corr_real = real.corr()
    corr_synthetic = synthetic.corr()
    norm_diff = np.linalg.norm(corr_real - corr_synthetic)
    norm_real = np.linalg.norm(corr_real)
    norm_quotient = norm_diff / norm_real
    return norm_quotient


def get_outliers(synthetic: pd.DataFrame, mode: OutlierPredictionMode = OutlierPredictionMode.isolation_forest,
                 anomaly_score: bool = False):
    """Detect outliers in the synthetic data, returned as row indices or as anomaly scores."""
    if mode == OutlierPredictionMode.isolation_forest:
        model = ensemble.IsolationForest(random_state=42)
    else:
        # novelty=True is required for LocalOutlierFactor to expose score_samples()
        model = neighbors.LocalOutlierFactor(n_neighbors=2, novelty=anomaly_score)
    return outlier_predictions(model, anomaly_score, x=synthetic)


def outlier_predictions(model, anomaly_score, x):
    if anomaly_score:
        model.fit(x)
        # score_samples assigns higher values to inliers; negate so that higher means more anomalous
        return model.score_samples(X=x) * -1
    else:
        predictions = model.fit_predict(X=x)
        outliers_idx = np.array(np.where(predictions == -1))[0]
        return outliers_idx
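
A quick smoke test of these metrics on toy data (illustrative only; the column names and distributions are made up):

```python
import numpy as np
import pandas as pd
from syndat import quality

rng = np.random.default_rng(42)
real = pd.DataFrame({"age": rng.normal(50, 10, 200).round(),
                     "weight": rng.normal(70, 15, 200).round()})
synthetic = pd.DataFrame({"age": rng.normal(52, 12, 200).round(),
                          "weight": rng.normal(68, 14, 200).round()})

print(quality.get_auc(real, synthetic))         # ~0.5 = discriminator cannot separate real from synthetic
print(quality.get_jsd(real, synthetic))         # 0 = identical distributions; larger = more dissimilar
print(quality.get_norm_score(real, synthetic))  # relative difference of the correlation matrices
```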
83 changes: 83 additions & 0 deletions syndat/visualization.py
@@ -0,0 +1,83 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from pandas.plotting import table

from sklearn.manifold import TSNE
from syndat.quality import get_outliers


def get_tsne_plot_data(real: pd.DataFrame, synthetic: pd.DataFrame):
    x = pd.concat([real, synthetic])
    # perplexity must be smaller than the number of samples
    perplexity = min(30, x.shape[0] - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    tsne_result = tsne.fit_transform(x)
    border = real.shape[0]
    x_real = tsne_result[:border, 0]
    y_real = tsne_result[:border, 1]
    x_virtual = tsne_result[border:, 0]
    y_virtual = tsne_result[border:, 1]
    return x_real, y_real, x_virtual, y_virtual


def show_outlier_plot(real: pd.DataFrame, synthetic: pd.DataFrame):
    x_real, y_real, x_virtual, y_virtual = get_tsne_plot_data(real, synthetic)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_real, y=y_real, mode="markers", name='real'))
    fig.add_trace(go.Scatter(x=x_virtual, y=y_virtual, mode="markers", name='synthetic'))
    # Add outlier markings for virtual patients
    outliers = get_outliers(synthetic)
    for outlier in outliers:
        x0 = x_virtual[outlier] - 0.5
        y0 = y_virtual[outlier] - 0.5
        x1 = x_virtual[outlier] + 0.5
        y1 = y_virtual[outlier] + 0.5
        fig.add_shape(type="circle", xref="x", yref="y", x0=x0, y0=y0, x1=x1, y1=y1, line_color="LightSeaGreen")
    # display
    fig.show()


def plot_distributions(real: pd.DataFrame, synthetic: pd.DataFrame, store_destination: str):
    # use a non-interactive backend so figures can be stored without a display
    matplotlib.use('Agg')
    for column_name in real.columns:
        real_col = real[column_name].to_numpy()
        virtual_col = synthetic[column_name].to_numpy()
        plt.figure()
        plt.title(column_name)
        patient_types = np.concatenate([np.zeros(real_col.size), np.ones(virtual_col.size)])
        df = pd.DataFrame(data={"type": np.where(patient_types == 0, "real", "synthetic"),
                                "value": np.concatenate([real_col, virtual_col])})
        if real_col.dtype == object:
            # string-encoded categorical column
            ax = sns.countplot(data=df, x="value", hue="type", order=df['value'].value_counts().index)
        elif np.sum(real_col) % 1 == 0 and np.max(real_col) < 10:
            # integer-encoded categorical column with few levels
            ax = sns.countplot(data=df, x="value", hue="type")
        else:
            # continuous column
            df = pd.DataFrame(data={"real": real_col, "synthetic": virtual_col})
            ax = sns.violinplot(data=df)
        # remove x-tick labels as they are redundant with the table headers
        ax.set_xticks([])
        table(ax, df.describe().round(2), loc='bottom', colLoc='center', bbox=[0, -0.55, 1, 0.5],
              colWidths=[.5, .5])
        fig = ax.get_figure()
        fig.savefig(store_destination + "/" + column_name + '.png', bbox_inches="tight")
        plt.close(fig)


def create_correlation_plots(real_patients, virtual_patients, store_destination):
    names = ["dec_rp", "dec_vp"]
    for idx, patient_type in enumerate([real_patients, virtual_patients]):
        plt.figure()
        plt.title("Correlation")
        ax = sns.heatmap(patient_type.corr())
        fig = ax.get_figure()
        fig.savefig(store_destination + "/" + names[idx] + '.png', bbox_inches="tight")
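
A minimal usage sketch for the plotting helpers above (the CSV file names and output directory are placeholders; the directory is created before saving):

```python
import os

import pandas as pd

from syndat import visualization

real = pd.read_csv("real.csv")
synthetic = pd.read_csv("synthetic.csv")
os.makedirs("results/plots", exist_ok=True)

# per-feature distribution plots with summary-statistic tables
visualization.plot_distributions(real, synthetic, store_destination="results/plots")

# correlation heatmaps for the real and the synthetic data
visualization.create_correlation_plots(real, synthetic, store_destination="results/plots")

# interactive t-SNE scatter plot with outlier markings (opens in a browser)
visualization.show_outlier_plot(real, synthetic)
```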




Empty file added tests/__init__.py
Empty file.
