diff --git a/src/lib.rs b/src/lib.rs index 65846f1..c8f4a04 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -191,6 +191,23 @@ impl HashTrieMapPy { } } + #[classmethod] + fn fromkeys( + _cls: &PyType, + keys: &PyAny, + val: Option<&PyAny>, + py: Python, + ) -> PyResult { + let mut inner = HashTrieMap::new_sync(); + let none = py.None(); + let value = val.unwrap_or_else(|| none.as_ref(py)); + for each in keys.iter()? { + let key = Key::extract(each?)?.to_owned(); + inner.insert_mut(key, value.into()); + } + Ok(HashTrieMapPy { inner }) + } + fn get(&self, key: Key, default: Option) -> Option { if let Some(value) = self.inner.get(&key) { Some(value.to_owned()) diff --git a/tests/test_hash_trie_map.py b/tests/test_hash_trie_map.py index 9a6b672..9f7d75d 100644 --- a/tests/test_hash_trie_map.py +++ b/tests/test_hash_trie_map.py @@ -346,3 +346,27 @@ def test_get(): assert m1.get("foo") == "bar" assert m1.get("baz") is None assert m1.get("spam", "eggs") == "eggs" + + +def test_fromkeys(): + keys = list(range(10)) + got = HashTrieMap.fromkeys(keys) + expected = HashTrieMap((i, None) for i in keys) + assert got == HashTrieMap(dict.fromkeys(keys)) == expected + + +def test_fromkeys_explicit_value(): + keys = list(range(10)) + expected = HashTrieMap((i, "foo") for i in keys) + got = HashTrieMap.fromkeys(keys, "foo") + expected = HashTrieMap((i, "foo") for i in keys) + assert got == HashTrieMap(dict.fromkeys(keys, "foo")) == expected + + +def test_fromkeys_explicit_value_not_copied(): + keys = list(range(5)) + + got = HashTrieMap.fromkeys(keys, []) + got[3].append(1) + + assert got == HashTrieMap((i, [1]) for i in keys)