From 5f19bc5bc79ebb801f1fbcc0499485e69206cd66 Mon Sep 17 00:00:00 2001 From: Niall <89581219+nialldevlin1@users.noreply.github.com> Date: Mon, 8 Jan 2024 18:17:07 +0000 Subject: [PATCH] KendallTauCorrelation datetime support (#158) * cast datetime to int in kt correlation --- src/insight/metrics/metrics.py | 14 ++++++++++++++ tests/test_metrics/test_metrics.py | 2 ++ 2 files changed, 16 insertions(+) diff --git a/src/insight/metrics/metrics.py b/src/insight/metrics/metrics.py index e68f1f57..9f4d1cc8 100644 --- a/src/insight/metrics/metrics.py +++ b/src/insight/metrics/metrics.py @@ -1,4 +1,5 @@ """This module contains various metrics used across synthesized.""" +import datetime as dt import typing as ty import numpy as np @@ -75,6 +76,13 @@ def check_column_types(cls, sr_a: pd.Series, sr_b: pd.Series, check: Check = Col return False return True + @staticmethod + def _check_dtype(sr, func) -> bool: + for val in sr: + if pd.notna(val) and not func(val): + return False + return True + def _compute_metric(self, sr_a: pd.Series, sr_b: pd.Series): """Calculate the metric. @@ -85,6 +93,12 @@ def _compute_metric(self, sr_a: pd.Series, sr_b: pd.Series): Returns: The Kendall Tau coefficient between sr_a and sr_b. """ + + if self._check_dtype(sr_a, lambda x: isinstance(x, dt.datetime)): + sr_a = sr_a.astype("int") + if self._check_dtype(sr_b, lambda x: isinstance(x, dt.datetime)): + sr_b = sr_b.astype("int") + if hasattr(sr_a, "cat") and sr_a.cat.ordered: sr_a = sr_a.cat.codes diff --git a/tests/test_metrics/test_metrics.py b/tests/test_metrics/test_metrics.py index 4bedf2ab..26a0e243 100644 --- a/tests/test_metrics/test_metrics.py +++ b/tests/test_metrics/test_metrics.py @@ -214,6 +214,7 @@ def test_kt_correlation(): sr_f = pd.Series( list("feeddd"), dtype=pd.CategoricalDtype(categories=list("fed"), ordered=True) ) + sr_g = pd.to_datetime(pd.Series(np.random.normal(0, 1, 5), name="g")) kt_corr = KendallTauCorrelation() @@ -221,6 +222,7 @@ def test_kt_correlation(): assert kt_corr(sr_b, sr_c) is not None assert kt_corr(sr_c, sr_d) is None assert kt_corr(sr_e, sr_f) == 1.0 + assert kt_corr(sr_g, sr_g) is not None def test_cramers_v_basic():