From 0839ea6c89ab7e4eccce5d65a933df48f9b86f47 Mon Sep 17 00:00:00 2001 From: Tom Barbette Date: Wed, 19 Jun 2024 13:20:26 +0200 Subject: [PATCH] Re-implemented the code to do correlation matrix --- npf/statistics.py | 20 ++++++++++++++++++++ setup.py | 5 +++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/npf/statistics.py b/npf/statistics.py index 9a747ae4..ca903910 100644 --- a/npf/statistics.py +++ b/npf/statistics.py @@ -107,6 +107,26 @@ def printline(n): print(" %s : %.02f " % (vs, tot / n)) print("") + print('') + + ys = np.ndarray(shape = (len(X), len(dataset))) + + for i,d in enumerate(dataset): + ys[:,i] = d[2] + import pandas as pd + df = pd.DataFrame(np.concatenate((X,ys),axis=1),columns=list(vars_values.keys()) + [d[0] if d[0] else "y" for d in dataset]) + print("Correlation matrix:") + corr = df.corr() + corr.style.background_gradient(cmap='viridis') + print(corr) + corr + import seaborn as sn + import matplotlib.pyplot as plt + sn.heatmap(corr, annot=True) + f = npf.build_filename(test, build, filename if not filename is True else None, {}, 'pdf', result_type, show_serie=False, suffix="correlation") + plt.savefig(f) + print(f"Graph of correlation matrix saved to {f}") + @classmethod def buildDataset(cls, all_results: Dataset, test: Test) -> List[tuple]: #map of every diff --git a/setup.py b/setup.py index 61a5556a..0e150444 100644 --- a/setup.py +++ b/setup.py @@ -30,8 +30,9 @@ 'importlib_metadata', 'npf-web-extension >= 0.6.4', 'jinja2', - 'spellwise' - ] + 'spellwise', + 'seaborn' + ] setuptools.setup( name="npf",