Skip to content

Commit

Permalink
KS Distance for continuous values only (#172)
Browse files (browse the repository at this point in the history)
* KS distance fixes: accept only continuous column pairs (categorical pairs are no longer supported)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_metrics_usage.py: expect NaN KS distance for categorical columns

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Mark <[email protected]>
  • Loading branch information
3 people authored Feb 14, 2024
1 parent 0959a3b commit 88dac15
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 20 deletions.
2 changes: 0 additions & 2 deletions src/insight/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,6 @@ def check_column_types(
) -> bool:
if check.continuous(sr_a) and check.continuous(sr_b):
return True
if check.categorical(sr_a) and check.categorical(sr_b):
return True
return False

def _compute_metric(self, sr_a: pd.Series, sr_b: pd.Series) -> float:
Expand Down
47 changes: 30 additions & 17 deletions tests/test_metrics/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,37 +292,50 @@ def infer_dtype(self, sr: pd.Series) -> pd.Series:

def test_kolmogorov_smirnov_distance(group1):
# Test with identical distributions
assert kolmogorov_smirnov_distance(pd.Series([1, 2, 3]), pd.Series([1, 2, 3])) == 0
assert kolmogorov_smirnov_distance(group1, group1) == 0
cat_a = pd.Series(["a", "b", "c", "b"], dtype="string")
cat_b = pd.Series(["a", "b", "c", "a"], dtype="string")
cont_a = pd.Series([1, 2, 3, 3], dtype="int64")
cont_b = pd.Series([1, 1, 2, 3], dtype="int64")

# Test with distributions that are completely different
assert kolmogorov_smirnov_distance(pd.Series([1, 1, 1]), pd.Series([2, 2, 2])) == 1
class SimpleCheck(ColumnCheck):
def infer_dtype(self, sr: pd.Series) -> pd.Series:
return sr

def continuous(self, sr: pd.Series) -> bool:
return sr.dtype.kind in ("i", "f")

def categorical(self, sr: pd.Series) -> bool:
return sr.dtype == "string"

ksd = KolmogorovSmirnovDistance(check=SimpleCheck())

# Test with distributions that are slightly different
assert 0 < kolmogorov_smirnov_distance(pd.Series([1, 2, 3]), pd.Series([1, 2, 4])) < 1
assert ksd(cat_a, cat_b) is None
assert ksd(cat_a, cont_b) is None
assert ksd(cont_a, cat_b) is None
assert ksd(cont_a, cont_a) == 0

assert ksd(cont_a, cont_b) == 0.25

# Test with distributions that are completely different
assert ksd(pd.Series([1, 1, 1]), pd.Series([2, 2, 2])) == 1

# Test with random distributions
np.random.seed(0)
group2 = pd.Series(np.random.normal(0, 1, 1000))
group3 = pd.Series(np.random.normal(0.5, 1, 1000))
assert 0 < kolmogorov_smirnov_distance(group2, group3) < 1
assert 0 < ksd(group2, group3) < 1

# Test with distributions of different lengths
assert 0 < kolmogorov_smirnov_distance(pd.Series([1, 2, 3]), pd.Series([1, 2, 3, 4])) < 1

# Test with categorical data
cat1 = pd.Series(["a", "b", "c", "a"])
cat2 = pd.Series(["b", "c", "d"])
assert 0 < kolmogorov_smirnov_distance(cat1, cat2) < 1
assert 0 < ksd(pd.Series([1, 2, 3]), pd.Series([1, 2, 3, 4])) < 1

# Edge cases
# Test with one or both series empty
assert kolmogorov_smirnov_distance(pd.Series([]), pd.Series([1, 2, 3])) == 1
assert kolmogorov_smirnov_distance(pd.Series([1, 2, 3]), pd.Series([])) == 1
assert kolmogorov_smirnov_distance(pd.Series([]), pd.Series([])) == 1
assert ksd(pd.Series([]), pd.Series([1, 2, 3])) == 1
assert ksd(pd.Series([1, 2, 3]), pd.Series([])) == 1
assert ksd(pd.Series([]), pd.Series([])) == 1

# Test with series containing NaN values
assert 0 <= kolmogorov_smirnov_distance(pd.Series([1, np.nan, 3]), pd.Series([1, 2, 3])) <= 1
assert 0 <= ksd(pd.Series([1, np.nan, 3]), pd.Series([1, 2, 3])) <= 1


def test_js_divergence(group1, group2, group3):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_metrics/test_metrics_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_two_column_map_with_ksd(data):

assert col_map.name == expected_column_name
assert set(ksd_map_df.columns.to_list()) == set([expected_column_name])
assert all(not np.isnan(ksd_map_df[expected_column_name][cat]) for cat in categorical_cols)
assert all(np.isnan(ksd_map_df[expected_column_name][cat]) for cat in categorical_cols)
assert all(not np.isnan(ksd_map_df[expected_column_name][cont]) for cont in continuous_cols)


Expand Down

0 comments on commit 88dac15

Please sign in to comment.