Skip to content

Commit

Permalink
Support multi-line evaluation with nest assignment.
Browse files Browse the repository at this point in the history
Include more unit tests of `NestedFrame.eval`.
  • Loading branch information
gitosaurus committed Oct 15, 2024
1 parent 6866a8f commit 6496cac
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 8 deletions.
53 changes: 45 additions & 8 deletions src/nested_pandas/nestedframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,52 @@ def _constructor_expanddim(self) -> Self: # type: ignore[name-defined] # noqa:
__pandas_priority__ = 3500


class NestResolver:
class NestResolver(dict):
"""
Used by NestedFrame.eval to resolve the names of nested columns when
Used by NestedFrame.eval to resolve the names of nests at the top level.
While the resolver is normally a dictionary, with values that are fixed
upon entering evaluation, this object needs to be dynamic so that it can
support multi-line expressions, where new nests may be created during
evaluation.
"""

def __init__(self, outer: NestedFrame):
self._outer = outer
super().__init__()

def __contains__(self, item):
if not isinstance(item, str):
return False
top_nest = item if "." not in item else item.split(".")[0].strip()
return top_nest in self._outer.nested_columns

def __len__(self):
return len(self._outer.nested_columns)

def __getitem__(self, item):
if not isinstance(item, str):
raise KeyError(f"Unknown nest {item}")
top_nest = item if "." not in item else item.split(".")[0].strip()
if not super().__contains__(top_nest):
if top_nest not in self._outer.nested_columns:
raise KeyError(f"Unknown nest {top_nest}")
super().__setitem__(top_nest, NestedFieldResolver(top_nest, self._outer))
return super().__getitem__(top_nest)

def __setitem__(self, key, value):
# Called to update the resolver with intermediate values.
# The important point is to intercept the call so that the evaluator
# does not create any new resolvers on the fly. Storing the value
# is not important, since that will have been done already in
# the NestedFrame.
pass


class NestedFieldResolver:
"""
Used by NestedFrame.eval to resolve the names of fields in nested columns when
encountered in expressions, interpreting __getattr__ in terms of a
specific nest context.
specific nest.
"""

def __init__(self, nest_name: str, outer: NestedFrame):
Expand Down Expand Up @@ -407,8 +448,7 @@ def eval(self, expr: str, *, inplace: bool = False, **kwargs) -> Any | None:
--------
https://pandas.pydata.org/docs/reference/api/pandas.eval.html
"""
nested_resolvers = self._get_nested_column_resolvers()
kwargs["resolvers"] = tuple(kwargs.get("resolvers", ())) + (nested_resolvers,)
kwargs["resolvers"] = tuple(kwargs.get("resolvers", ())) + (NestResolver(self),)
kwargs["inplace"] = inplace
kwargs["parser"] = "nested-pandas"
return super().eval(expr, **kwargs)
Expand Down Expand Up @@ -492,9 +532,6 @@ def query(self, expr: str, *, inplace: bool = False, **kwargs) -> NestedFrame |
else:
return result

def _get_nested_column_resolvers(self):
return {name: NestResolver(name, self) for name in self.nested_columns}

def _resolve_dropna_target(self, on_nested, subset):
"""resolves the target layer for a given set of dropna kwargs"""

Expand Down
72 changes: 72 additions & 0 deletions tests/nested_pandas/nestedframe/test_nestedframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,3 +819,75 @@ def test_mixed_eval_funcs():

# Across the nest: each base column element applies to each of its indexes
assert (nf.eval("a + packed.c") == nf["a"] + nf["packed.c"]).all()


def test_eval_assignment():
"""
Test eval strings that perform assignment, within base columns, nested columns,
and across base and nested.
"""
nf = NestedFrame(
data={"a": [1, 2, 3], "b": [2, 4, 6]},
index=pd.Index([0, 1, 2], name="idx"),
)
to_pack = pd.DataFrame(
data={
"time": [1, 2, 3, 1, 2, 4, 2, 1, 4],
"c": [0, 2, 4, 10, 4, 3, 1, 4, 1],
"d": [5, 4, 7, 5, 3, 1, 9, 3, 4],
},
index=pd.Index([0, 0, 0, 1, 1, 1, 2, 2, 2], name="idx"),
)
nf = nf.add_nested(to_pack, "packed")
# Assigning to new base columns from old base columns
nf_b = nf.eval("c = a + 1")
assert len(nf_b.columns) == len(nf.columns) + 1
assert (nf_b["c"] == nf["a"] + 1).all()

# Assigning to new nested columns from old nested columns
nf_nc = nf.eval("packed.e = packed.c + 1")
assert len(nf_nc.packed.nest.fields) == len(nf["packed"].nest.fields) + 1
assert (nf_nc["packed.e"] == nf["packed.c"] + 1).all()

# Verify that overwriting a nested column works
nf_nc_2 = nf_nc.eval("packed.e = packed.c * 2")
assert len(nf_nc_2.packed.nest.fields) == len(nf_nc["packed"].nest.fields)
assert (nf_nc_2["packed.e"] == nf["packed.c"] * 2).all()

# Assigning to new nested columns from a combo of base and nested
nf_nx = nf.eval("packed.f = a + packed.c")
assert len(nf_nx.packed.nest.fields) == len(nf["packed"].nest.fields) + 1
assert (nf_nx["packed.f"] == nf["a"] + nf["packed.c"]).all()
assert (nf_nx["packed.f"] == pd.Series([1, 3, 5, 12, 6, 5, 4, 7, 4], index=to_pack.index)).all()

# Assigning to new base columns from nested columns. This can't be done because
# it would attempt to create base column values that were "between indexes", or as
# Pandas puts, duplicate index labels.
with pytest.raises(ValueError):
nf.eval("g = packed.c * 2")

# Create new nests via eval()
nf_n2 = nf.eval("p2.c2 = packed.c * 2")
assert len(nf_n2.p2.nest.fields) == 1
assert (nf_n2["p2.c2"] == nf["packed.c"] * 2).all()
assert (nf_n2["p2.c2"] == pd.Series([0, 4, 8, 20, 8, 6, 2, 8, 2], index=to_pack.index)).all()
assert len(nf_n2.columns) == len(nf.columns) + 1 # new packed column
assert len(nf_n2.p2.nest.fields) == 1

# Assigning to new columns across two different nests
nf_n3 = nf_n2.eval("p2.d = p2.c2 + packed.d * 2 + b")
assert len(nf_n3.p2.nest.fields) == 2
assert (nf_n3["p2.d"] == nf_n2["p2.c2"] + nf["packed.d"] * 2 + nf["b"]).all()

# Now test multiline and inplace=True
nf.eval(
"""
c = a + b
p2.e = packed.d * 2 + c
p2.f = p2.e + b
""",
inplace=True,
)
assert len(nf.p2.nest.fields) == 2
assert (nf["p2.e"] == nf["packed.d"] * 2 + nf.c).all()
assert (nf["p2.f"] == nf["p2.e"] + nf.b).all()

0 comments on commit 6496cac

Please sign in to comment.