Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
malcolmgreaves committed Mar 9, 2022
1 parent ac4379e commit 00d894b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 26 deletions.
2 changes: 1 addition & 1 deletion core_utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def deserialize(
# didn't deserialize to its expected type
else:
fail = True

if fail:
raise FieldDeserializeFail(
field_name="", expected_type=type_value, actual_value=value
Expand Down
13 changes: 6 additions & 7 deletions tests/test_custom_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from dataclasses import dataclass
from typing import Tuple, NamedTuple, Mapping, Sequence
from typing import Tuple, NamedTuple, Mapping, Sequence, Dict, List

import numpy as np
import torch
Expand Down Expand Up @@ -154,7 +154,7 @@ def check(*, actual, expected):
def test_nested_array_dict_int_keys(custom_serialize, custom_deserialize):
N = 4
M = 3

def check(*, actual, expected):
assert isinstance(actual, type(expected))
assert len(actual) == N
Expand All @@ -164,9 +164,8 @@ def check(*, actual, expected):
assert isinstance(i, int)
_check_array_like(actual=arr, expected=np.ones(i))

m: List[List[Dict[int, np.ndarray]]] = [[
[{i:np.ones(i) for i in range(M)]
for _ in range(N)
]]
m: List[List[Dict[int, np.ndarray]]] = [
[{i: np.ones(i)} for i in range(M)] for _ in range(N)
]

_roundtrip(m, custom_serialize, custom_deserialize, check)
_roundtrip(m, custom_serialize, custom_deserialize, check)
46 changes: 28 additions & 18 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@
from collections import namedtuple
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import NamedTuple, Sequence, Optional, Mapping, Set, Tuple, Union, Any, List, Dict
from typing import (
NamedTuple,
Sequence,
Optional,
Mapping,
Set,
Tuple,
Union,
Any,
List,
Dict,
)

from pytest import raises, fixture
import yaml
Expand Down Expand Up @@ -494,21 +505,20 @@ def test_serialize_none_special_cases_mapping():
assert deserialize(Mapping[str, Optional[int]], s) == m_empty
assert deserialize(Mapping[str, Optional[int]], serialize(m_empty)) == {}


def test_serialize_dict_with_numeric_keys():
d1: Dict[int, List[str]] = {
i: [x for x in 'hello world! how are you today?'] for i in range(10)
}
s1 = serialize(d1)
assert deserialize(Dict[int, List[str]], s1) == d1

j1 = json.dumps(s1)
assert deserialize(Dict[int, List[str]], json.loads(j1)) == d1

d2: Dict[float, List[str]] = {
float(i): xs for i, xs in d1.items()
}
s2 = serialize(d2)
assert deserialize(Dict[int, List[str]], s2) == d2

j2 = json.dumps(s2)
assert deserialize(Dict[int, List[str]], json.loads(j2)) == d2
d1: Dict[int, List[str]] = {
i: [x for x in "hello world! how are you today?"] for i in range(10)
}
s1 = serialize(d1)
assert deserialize(Dict[int, List[str]], s1) == d1

j1 = json.dumps(s1)
assert deserialize(Dict[int, List[str]], json.loads(j1)) == d1

d2: Dict[float, List[str]] = {float(i): xs for i, xs in d1.items()}
s2 = serialize(d2)
assert deserialize(Dict[int, List[str]], s2) == d2

j2 = json.dumps(s2)
assert deserialize(Dict[int, List[str]], json.loads(j2)) == d2

0 comments on commit 00d894b

Please sign in to comment.