From 5106b84f99c544f2168dfc392f5d335d13e17e37 Mon Sep 17 00:00:00 2001 From: kueltzen Date: Mon, 30 Sep 2024 17:28:36 +0200 Subject: [PATCH 1/2] Removed ces that are not implemented in chemenv, added tests that previously would have failed (issue #945). --- matminer/featurizers/site/fingerprint.py | 5 ----- matminer/featurizers/site/tests/test_fingerprint.py | 8 ++++++-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/matminer/featurizers/site/fingerprint.py b/matminer/featurizers/site/fingerprint.py index 9dd3eac57..828b5a1b4 100644 --- a/matminer/featurizers/site/fingerprint.py +++ b/matminer/featurizers/site/fingerprint.py @@ -771,15 +771,11 @@ def from_preset(preset): "PA:10", "SBSA:10", "MI:10", - "S:10", - "H:10", "BS_1:10", "BS_2:10", "TBSA:10", "PCPA:11", "H:11", - "SH:11", - "CO:11", "DI:11", "I:12", "PBP:12", @@ -787,7 +783,6 @@ def from_preset(preset): "C:12", "AC:12", "SC:12", - "S:12", "HP:12", "HA:12", "SH:13", diff --git a/matminer/featurizers/site/tests/test_fingerprint.py b/matminer/featurizers/site/tests/test_fingerprint.py index b8e9ffd03..89681ef5a 100644 --- a/matminer/featurizers/site/tests/test_fingerprint.py +++ b/matminer/featurizers/site/tests/test_fingerprint.py @@ -277,9 +277,11 @@ def test_crystal_nn_fingerprint(self): def test_chemenv_site_fingerprint(self): cefp = ChemEnvSiteFingerprint.from_preset("multi_weights") + implemented_cetypes = set([gg.ce_symbol for gg in cefp.lgf.allcg.get_implemented_geometries()]) + assert set(cefp.cetypes).difference(implemented_cetypes) == set() # Added after issue #945 l = cefp.feature_labels() cevals = cefp.featurize(self.sc, 0) - self.assertEqual(len(cevals), 66) + self.assertEqual(len(cevals), 61) self.assertAlmostEqual(cevals[l.index("O:6")], 1, places=7) self.assertAlmostEqual(cevals[l.index("C:8")], 0, places=7) cevals = cefp.featurize(self.cscl, 0) @@ -288,12 +290,14 @@ def test_chemenv_site_fingerprint(self): cefp = ChemEnvSiteFingerprint.from_preset("simple") l = cefp.feature_labels() cevals = cefp.featurize(self.sc, 0) - self.assertEqual(len(cevals), 66) + self.assertEqual(len(cevals), 61) self.assertAlmostEqual(cevals[l.index("O:6")], 1, places=7) self.assertAlmostEqual(cevals[l.index("C:8")], 0, places=7) cevals = cefp.featurize(self.cscl, 0) self.assertAlmostEqual(cevals[l.index("C:8")], 0.9953721, places=7) self.assertAlmostEqual(cevals[l.index("O:6")], 0, places=7) + cevals = cefp.featurize(self.ni3al, 0) # Added after issue #945 + self.assertAlmostEqual(cevals[l.index("I:12")], 0.3401699, places=7) def test_voronoifingerprint(self): df_sc = pd.DataFrame({"struct": [self.sc], "site": [0]}) From 198bec5078ff9e9a3df832127d6968f880f6dc8d Mon Sep 17 00:00:00 2001 From: kueltzen Date: Mon, 7 Oct 2024 11:21:14 +0200 Subject: [PATCH 2/2] Fix linting errors. --- matminer/featurizers/site/tests/test_fingerprint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matminer/featurizers/site/tests/test_fingerprint.py b/matminer/featurizers/site/tests/test_fingerprint.py index 89681ef5a..e18e3efa1 100644 --- a/matminer/featurizers/site/tests/test_fingerprint.py +++ b/matminer/featurizers/site/tests/test_fingerprint.py @@ -277,7 +277,7 @@ def test_crystal_nn_fingerprint(self): def test_chemenv_site_fingerprint(self): cefp = ChemEnvSiteFingerprint.from_preset("multi_weights") - implemented_cetypes = set([gg.ce_symbol for gg in cefp.lgf.allcg.get_implemented_geometries()]) + implemented_cetypes = {gg.ce_symbol for gg in cefp.lgf.allcg.get_implemented_geometries()} assert set(cefp.cetypes).difference(implemented_cetypes) == set() # Added after issue #945 l = cefp.feature_labels() cevals = cefp.featurize(self.sc, 0)