Skip to content

Commit

Permalink
Add test checking input config is unchanged
Browse files Browse the repository at this point in the history
  • Loading branch information
jesszzzz committed Jan 6, 2025
1 parent 9dedc40 commit a29f35e
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,10 @@ def test_class_instantiate(
recursive: bool,
) -> Any:
passthrough["_recursive_"] = recursive
original_config_str = str(config)
obj = instantiate_func(config, **passthrough)
assert partial_equal(obj, expected)
assert str(config) == original_config_str


def test_partial_with_missing(instantiate_func: Any) -> Any:
Expand All @@ -431,10 +433,12 @@ def test_partial_with_missing(instantiate_func: Any) -> Any:
"b": 20,
"c": 30,
}
original_config_str = str(config)
partial_obj = instantiate_func(config)
assert partial_equal(partial_obj, partial(AClass, b=20, c=30))
obj = partial_obj(a=10)
assert partial_equal(obj, AClass(a=10, b=20, c=30))
assert str(config) == original_config_str


def test_instantiate_with_missing(instantiate_func: Any) -> Any:
Expand Down Expand Up @@ -468,6 +472,7 @@ def test_none_cases(
ListConfig(None),
],
}
original_config_str = str(cfg)
ret = instantiate_func(cfg)
assert ret.kwargs["none_dict"] is None
assert ret.kwargs["none_list"] is None
Expand All @@ -477,6 +482,7 @@ def test_none_cases(
assert ret.kwargs["list"][0] == 10
assert ret.kwargs["list"][1] is None
assert ret.kwargs["list"][2] is None
assert str(cfg) == original_config_str


@mark.parametrize(
Expand Down Expand Up @@ -537,6 +543,20 @@ def test_none_cases(
6,
id="interpolation_from_recursive",
),
param(
{
"my_id": 5,
"node": {
"b": "${foo_b}",
},
"foo_b": {
"unique_id": "${my_id}",
},
},
{},
OmegaConf.create({"b": {"unique_id": 5}}),
id="interpolation_from_parent_with_interpolation",
),
],
)
def test_interpolation_accessing_parent(
Expand All @@ -547,12 +567,14 @@ def test_interpolation_accessing_parent(
) -> Any:
cfg_copy = OmegaConf.create(input_conf)
input_conf = OmegaConf.create(input_conf)
original_config_str = str(input_conf)
obj = instantiate_func(input_conf.node, **passthrough)
if isinstance(expected, partial):
assert partial_equal(obj, expected)
else:
assert obj == expected
assert input_conf == cfg_copy
assert str(input_conf) == original_config_str


@mark.parametrize(
Expand Down

0 comments on commit a29f35e

Please sign in to comment.