Skip to content

Commit

Permalink
Add HashTrieMap.fromkeys with dict.fromkeys' signature.
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian committed Dec 28, 2023
1 parent 3e9fe5e commit 4fe2bf8
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,23 @@ impl HashTrieMapPy {
}
}

#[classmethod]
fn fromkeys(
_cls: &PyType,
keys: &PyAny,
val: Option<&PyAny>,
py: Python,
) -> PyResult<HashTrieMapPy> {
let mut inner = HashTrieMap::new_sync();
let none = py.None();
let value = val.unwrap_or_else(|| none.as_ref(py));
for each in keys.iter()? {
let key = Key::extract(each?)?.to_owned();
inner.insert_mut(key, value.into());
}
Ok(HashTrieMapPy { inner })
}

fn get(&self, key: Key, default: Option<PyObject>) -> Option<PyObject> {
if let Some(value) = self.inner.get(&key) {
Some(value.to_owned())
Expand Down
24 changes: 24 additions & 0 deletions tests/test_hash_trie_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,27 @@ def test_get():
assert m1.get("foo") == "bar"
assert m1.get("baz") is None
assert m1.get("spam", "eggs") == "eggs"


def test_fromkeys():
keys = list(range(10))
got = HashTrieMap.fromkeys(keys)
expected = HashTrieMap((i, None) for i in keys)
assert got == HashTrieMap(dict.fromkeys(keys)) == expected


def test_fromkeys_explicit_value():
keys = list(range(10))
expected = HashTrieMap((i, "foo") for i in keys)
got = HashTrieMap.fromkeys(keys, "foo")
expected = HashTrieMap((i, "foo") for i in keys)
assert got == HashTrieMap(dict.fromkeys(keys, "foo")) == expected


def test_fromkeys_explicit_value_not_copied():
keys = list(range(5))

got = HashTrieMap.fromkeys(keys, [])
got[3].append(1)

assert got == HashTrieMap((i, [1]) for i in keys)

0 comments on commit 4fe2bf8

Please sign in to comment.