Skip to content

Commit

Permalink
ChemEnvSiteFingerprint.from_preset() removal of not-implemented CEs (#…
Browse files Browse the repository at this point in the history
…948)

* Removed ces that are not implemented in chemenv, added tests that previously would have failed (issue #945).

* Fix linting errors.
  • Loading branch information
kaueltzen authored Oct 11, 2024
1 parent ee5747d commit dcbaf06
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
5 changes: 0 additions & 5 deletions matminer/featurizers/site/fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,23 +771,18 @@ def from_preset(preset):
"PA:10",
"SBSA:10",
"MI:10",
"S:10",
"H:10",
"BS_1:10",
"BS_2:10",
"TBSA:10",
"PCPA:11",
"H:11",
"SH:11",
"CO:11",
"DI:11",
"I:12",
"PBP:12",
"TT:12",
"C:12",
"AC:12",
"SC:12",
"S:12",
"HP:12",
"HA:12",
"SH:13",
Expand Down
8 changes: 6 additions & 2 deletions matminer/featurizers/site/tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,11 @@ def test_crystal_nn_fingerprint(self):

def test_chemenv_site_fingerprint(self):
cefp = ChemEnvSiteFingerprint.from_preset("multi_weights")
implemented_cetypes = {gg.ce_symbol for gg in cefp.lgf.allcg.get_implemented_geometries()}
assert set(cefp.cetypes).difference(implemented_cetypes) == set() # Added after issue #945
l = cefp.feature_labels()
cevals = cefp.featurize(self.sc, 0)
self.assertEqual(len(cevals), 66)
self.assertEqual(len(cevals), 61)
self.assertAlmostEqual(cevals[l.index("O:6")], 1, places=7)
self.assertAlmostEqual(cevals[l.index("C:8")], 0, places=7)
cevals = cefp.featurize(self.cscl, 0)
Expand All @@ -288,12 +290,14 @@ def test_chemenv_site_fingerprint(self):
cefp = ChemEnvSiteFingerprint.from_preset("simple")
l = cefp.feature_labels()
cevals = cefp.featurize(self.sc, 0)
self.assertEqual(len(cevals), 66)
self.assertEqual(len(cevals), 61)
self.assertAlmostEqual(cevals[l.index("O:6")], 1, places=7)
self.assertAlmostEqual(cevals[l.index("C:8")], 0, places=7)
cevals = cefp.featurize(self.cscl, 0)
self.assertAlmostEqual(cevals[l.index("C:8")], 0.9953721, places=7)
self.assertAlmostEqual(cevals[l.index("O:6")], 0, places=7)
cevals = cefp.featurize(self.ni3al, 0) # Added after issue #945
self.assertAlmostEqual(cevals[l.index("I:12")], 0.3401699, places=7)

def test_voronoifingerprint(self):
df_sc = pd.DataFrame({"struct": [self.sc], "site": [0]})
Expand Down

0 comments on commit dcbaf06

Please sign in to comment.