Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add deps and position field in VarData #4518

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ export function {{tag_name}} () {
{{ hook }}
{% endfor %}

{% for hook, data in component._get_all_hooks().items() if not data.position or data.position == positions.PRE_TRIGGER %}
{{ hook }}
{% endfor %}

{% for hook in memo_trigger_hooks %}
{{ hook }}
{% endfor %}

{% for hook in component._get_all_hooks() %}
{% for hook,data in component._get_all_hooks().items() if data.position and data.position == positions.POST_TRIGGER %}
Lendemor marked this conversation as resolved.
Show resolved Hide resolved
{{ hook }}
{% endfor %}

Expand Down
54 changes: 43 additions & 11 deletions reflex/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)

import reflex.state
from reflex import constants
from reflex.base import Base
from reflex.compiler.templates import STATEFUL_COMPONENT
from reflex.components.core.breakpoints import Breakpoints
Expand Down Expand Up @@ -1369,7 +1370,9 @@ def _get_hooks_imports(self) -> ParsedImportDict:
if user_hooks_data is not None:
other_imports.append(user_hooks_data.imports)
other_imports.extend(
hook_imports for hook_imports in self._get_added_hooks().values()
hook_vardata.imports
for hook_vardata in self._get_added_hooks().values()
if hook_vardata is not None
)

return imports.merge_imports(_imports, *other_imports)
Expand Down Expand Up @@ -1523,7 +1526,7 @@ def _get_hooks_internal(self) -> dict[str, None]:
**self._get_special_hooks(),
}

def _get_added_hooks(self) -> dict[str, ImportDict]:
def _get_added_hooks(self) -> dict[str, VarData | None]:
"""Get the hooks added via `add_hooks` method.
Returns:
Expand All @@ -1532,17 +1535,15 @@ def _get_added_hooks(self) -> dict[str, ImportDict]:
code = {}

def extract_var_hooks(hook: Var):
_imports = {}
var_data = VarData.merge(hook._get_all_var_data())
if var_data is not None:
for sub_hook in var_data.hooks:
code[sub_hook] = {}
if var_data.imports:
_imports = var_data.imports
code[sub_hook] = None
Lendemor marked this conversation as resolved.
Show resolved Hide resolved

if str(hook) in code:
code[str(hook)] = imports.merge_imports(code[str(hook)], _imports)
code[str(hook)] = VarData.merge(var_data, code[str(hook)])
else:
code[str(hook)] = _imports
code[str(hook)] = var_data

# Add the hook code from add_hooks for each parent class (this is reversed to preserve
# the order of the hooks in the final output)
Expand All @@ -1551,7 +1552,7 @@ def extract_var_hooks(hook: Var):
if isinstance(hook, Var):
extract_var_hooks(hook)
else:
code[hook] = {}
code[hook] = None

return code

Expand Down Expand Up @@ -1593,8 +1594,8 @@ def _get_all_hooks(self) -> dict[str, None]:
if hooks is not None:
code[hooks] = None

for hook in self._get_added_hooks():
code[hook] = None
for hook, var_data in self._get_added_hooks().items():
code[hook] = var_data

# Add the hook code for the children.
for child in self.children:
Expand Down Expand Up @@ -2168,6 +2169,7 @@ def _render_stateful_code(
tag_name=tag_name,
memo_trigger_hooks=memo_trigger_hooks,
component=component,
positions=constants.Hooks.HookPosition,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is kind of hacky. constants should be added to the global consts namespace in reflex/compiler/template.py

)

@staticmethod
Expand Down Expand Up @@ -2196,6 +2198,31 @@ def _get_hook_deps(hook: str) -> list[str]:
]
return [var_name]

@staticmethod
def _get_deps_from_event_trigger(event: EventChain | EventSpec | Var) -> set[str]:
"""Get the dependencies accessed by event triggers.
Args:
event: The event trigger to extract deps from.
Returns:
The dependencies accessed by the event triggers.
"""
events: list = [event]
deps = set()

if isinstance(event, EventChain):
events.extend(event.events)

for ev in events:
if isinstance(ev, EventSpec):
for arg in ev.args:
for a in arg:
var_datas = VarData.merge(a._get_all_var_data())
if var_datas and var_datas.deps is not None:
deps |= {str(dep) for dep in var_datas.deps}
return deps

@classmethod
def _get_memoized_event_triggers(
cls,
Expand Down Expand Up @@ -2232,6 +2259,11 @@ def _get_memoized_event_triggers(

# Calculate Var dependencies accessed by the handler for useCallback dep array.
var_deps = ["addEvents", "Event"]

# Get deps from event trigger var data.
var_deps.extend(cls._get_deps_from_event_trigger(event))

# Get deps from hooks.
for arg in event_args:
var_data = arg._get_all_var_data()
if var_data is None:
Expand Down
18 changes: 10 additions & 8 deletions reflex/components/core/clipboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from reflex.components.base.fragment import Fragment
from reflex.components.tags.tag import Tag
from reflex.constants.compiler import Hooks
from reflex.event import EventChain, EventHandler, passthrough_event_spec
from reflex.utils.format import format_prop, wrap
from reflex.utils.imports import ImportVar
from reflex.vars import get_unique_variable_name
from reflex.vars.base import Var
from reflex.vars.base import Var, VarData


class Clipboard(Fragment):
Expand Down Expand Up @@ -72,7 +73,7 @@ def add_imports(self) -> dict[str, ImportVar]:
),
}

def add_hooks(self) -> list[str]:
def add_hooks(self) -> list[str | Var[str]]:
"""Add hook to register paste event listener.

Returns:
Expand All @@ -83,13 +84,14 @@ def add_hooks(self) -> list[str]:
return []
if isinstance(on_paste, EventChain):
on_paste = wrap(str(format_prop(on_paste)).strip("{}"), "(")
hook_expr = f"usePasteHandler({self.targets!s}, {self.on_paste_event_actions!s}, {on_paste!s})"

return [
"usePasteHandler(%s, %s, %s)"
% (
str(self.targets),
str(self.on_paste_event_actions),
on_paste,
)
Var(
hook_expr,
_var_type="str",
_var_data=VarData(position=Hooks.HookPosition.POST_TRIGGER),
),
]


Expand Down
2 changes: 1 addition & 1 deletion reflex/components/core/clipboard.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ class Clipboard(Fragment):
...

def add_imports(self) -> dict[str, ImportVar]: ...
def add_hooks(self) -> list[str]: ...
def add_hooks(self) -> list[str | Var[str]]: ...

clipboard = Clipboard.create
7 changes: 5 additions & 2 deletions reflex/components/datadisplay/dataeditor.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,11 @@ def add_hooks(self) -> list[str]:
editor_id = get_unique_variable_name()

# Define the name of the getData callback associated with this component and assign to get_cell_content.
data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore
if self.get_cell_content is not None:
data_callback = self.get_cell_content._js_expr
else:
data_callback = f"getData_{editor_id}"
self.get_cell_content = Var(_js_expr=data_callback) # type: ignore

code = [f"function {data_callback}([col, row])" "{"]

Expand Down
6 changes: 6 additions & 0 deletions reflex/constants/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ class Hooks(SimpleNamespace):
}
})"""

class HookPosition(enum.Enum):
"""The position of the hook in the component."""

PRE_TRIGGER = "pre_trigger"
POST_TRIGGER = "post_trigger"


class MemoizationDisposition(enum.Enum):
"""The conditions under which a component should be memoized."""
Expand Down
50 changes: 46 additions & 4 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@

from reflex import constants
from reflex.base import Base
from reflex.utils import console, imports, serializers, types
from reflex.constants.compiler import Hooks
from reflex.utils import console, exceptions, imports, serializers, types
from reflex.utils.exceptions import (
VarAttributeError,
VarDependencyError,
Expand Down Expand Up @@ -115,12 +116,20 @@ class VarData:
# Hooks that need to be present in the component to render this var
hooks: Tuple[str, ...] = dataclasses.field(default_factory=tuple)

# Dependencies of the var
deps: Tuple[Var, ...] = dataclasses.field(default_factory=tuple)

# Position of the hook in the component
position: Hooks.HookPosition | None = None

def __init__(
self,
state: str = "",
field_name: str = "",
imports: ImportDict | ParsedImportDict | None = None,
hooks: dict[str, None] | None = None,
deps: list[Var] | None = None,
position: Hooks.HookPosition | None = None,
):
"""Initialize the var data.

Expand All @@ -129,6 +138,8 @@ def __init__(
field_name: The name of the field in the state.
imports: Imports needed to render this var.
hooks: Hooks that need to be present in the component to render this var.
deps: Dependencies of the var for useCallback.
position: Position of the hook in the component.
"""
immutable_imports: ImmutableParsedImportDict = tuple(
sorted(
Expand All @@ -139,6 +150,8 @@ def __init__(
object.__setattr__(self, "field_name", field_name)
object.__setattr__(self, "imports", immutable_imports)
object.__setattr__(self, "hooks", tuple(hooks or {}))
object.__setattr__(self, "deps", tuple(deps or []))
object.__setattr__(self, "position", position or None)

def old_school_imports(self) -> ImportDict:
"""Return the imports as a mutable dict.
Expand All @@ -154,6 +167,9 @@ def merge(*all: VarData | None) -> VarData | None:
Args:
*all: The var data objects to merge.

Raises:
ReflexError: If trying to merge VarData with different positions.

Returns:
The merged var data object.

Expand Down Expand Up @@ -184,12 +200,32 @@ def merge(*all: VarData | None) -> VarData | None:
*(var_data.imports for var_data in all_var_datas)
)

if state or _imports or hooks or field_name:
deps = [dep for var_data in all_var_datas for dep in var_data.deps]

positions = list(
{
var_data.position
for var_data in all_var_datas
if var_data.position is not None
}
)
if positions:
if len(positions) > 1:
Lendemor marked this conversation as resolved.
Show resolved Hide resolved
raise exceptions.ReflexError(
f"Cannot merge var data with different positions: {positions}"
)
position = positions[0]
else:
position = None

if state or _imports or hooks or field_name or deps or position:
return VarData(
state=state,
field_name=field_name,
imports=_imports,
hooks=hooks,
deps=deps,
position=position,
)

return None
Expand All @@ -200,7 +236,14 @@ def __bool__(self) -> bool:
Returns:
True if any field is set to a non-default value.
"""
return bool(self.state or self.imports or self.hooks or self.field_name)
return bool(
self.state
or self.imports
or self.hooks
or self.field_name
or self.deps
or self.position
)

@classmethod
def from_state(cls, state: Type[BaseState] | str, field_name: str = "") -> VarData:
Expand Down Expand Up @@ -480,7 +523,6 @@ def _replace(
raise TypeError(
"The _var_full_name_needs_state_prefix argument is not supported for Var."
)

value_with_replaced = dataclasses.replace(
self,
_var_type=_var_type or self._var_type,
Expand Down
Loading