Skip to content

Commit

Permalink
[CHIA-1032] Add a config slot in action scopes and use it for wallets (
Browse files Browse the repository at this point in the history
…#18365)

* Add a config slot in action scopes and use it for wallets

* Simplify tx_endpoint a bit

* pylint
  • Loading branch information
Quexington authored Jul 29, 2024
1 parent 1c1e030 commit 2c2b0f4
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 28 deletions.
27 changes: 17 additions & 10 deletions chia/_tests/util/test_action_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def from_bytes(cls, blob: bytes) -> TestSideEffects:
return cls(blob)


@final
@dataclass
class TestConfig:
test_foo: str = "test_foo"


async def default_async_callback(interface: StateInterface[TestSideEffects]) -> None:
return None # pragma: no cover

Expand All @@ -36,13 +42,14 @@ def test_set_callback() -> None:


@pytest.fixture(name="action_scope")
async def action_scope_fixture() -> AsyncIterator[ActionScope[TestSideEffects]]:
async with ActionScope.new_scope(TestSideEffects) as scope:
async def action_scope_fixture() -> AsyncIterator[ActionScope[TestSideEffects, TestConfig]]:
async with ActionScope.new_scope(TestSideEffects, TestConfig()) as scope:
assert scope.config == TestConfig(test_foo="test_foo")
yield scope


@pytest.mark.anyio
async def test_new_action_scope(action_scope: ActionScope[TestSideEffects]) -> None:
async def test_new_action_scope(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None:
"""
Assert we can immediately check out some initial state
"""
Expand All @@ -51,7 +58,7 @@ async def test_new_action_scope(action_scope: ActionScope[TestSideEffects]) -> N


@pytest.mark.anyio
async def test_scope_persistence(action_scope: ActionScope[TestSideEffects]) -> None:
async def test_scope_persistence(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None:
async with action_scope.use() as interface:
interface.side_effects.buf = b"baz"

Expand All @@ -60,7 +67,7 @@ async def test_scope_persistence(action_scope: ActionScope[TestSideEffects]) ->


@pytest.mark.anyio
async def test_transactionality(action_scope: ActionScope[TestSideEffects]) -> None:
async def test_transactionality(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None:
async with action_scope.use() as interface:
interface.side_effects.buf = b"baz"

Expand All @@ -75,7 +82,7 @@ async def test_transactionality(action_scope: ActionScope[TestSideEffects]) -> N

@pytest.mark.anyio
async def test_callbacks() -> None:
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with ActionScope.new_scope(TestSideEffects, TestConfig()) as action_scope:
async with action_scope.use() as interface:

async def callback(interface: StateInterface[TestSideEffects]) -> None:
Expand All @@ -89,7 +96,7 @@ async def callback(interface: StateInterface[TestSideEffects]) -> None:
@pytest.mark.anyio
async def test_callback_in_callback_error() -> None:
with pytest.raises(RuntimeError, match="Callback"):
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with ActionScope.new_scope(TestSideEffects, TestConfig()) as action_scope:
async with action_scope.use() as interface:

async def callback(interface: StateInterface[TestSideEffects]) -> None:
Expand All @@ -101,7 +108,7 @@ async def callback(interface: StateInterface[TestSideEffects]) -> None:
@pytest.mark.anyio
async def test_no_callbacks_if_error() -> None:
with pytest.raises(Exception, match="This should prevent the callbacks from being called"):
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with ActionScope.new_scope(TestSideEffects, TestConfig()) as action_scope:
async with action_scope.use() as interface:

async def callback(interface: StateInterface[TestSideEffects]) -> None:
Expand All @@ -113,7 +120,7 @@ async def callback(interface: StateInterface[TestSideEffects]) -> None:
raise RuntimeError("This should prevent the callbacks from being called")

with pytest.raises(Exception, match="This should prevent the callbacks from being called"):
async with ActionScope.new_scope(TestSideEffects) as action_scope:
async with ActionScope.new_scope(TestSideEffects, TestConfig()) as action_scope:
async with action_scope.use() as interface:

async def callback2(interface: StateInterface[TestSideEffects]) -> None:
Expand All @@ -126,7 +133,7 @@ async def callback2(interface: StateInterface[TestSideEffects]) -> None:

# TODO: add suport, change this test to test it and add a test for nested transactionality
@pytest.mark.anyio
async def test_nested_use_banned(action_scope: ActionScope[TestSideEffects]) -> None:
async def test_nested_use_banned(action_scope: ActionScope[TestSideEffects, TestConfig]) -> None:
async with action_scope.use():
with pytest.raises(RuntimeError, match="cannot currently support nested transactions"):
async with action_scope.use():
Expand Down
3 changes: 0 additions & 3 deletions chia/rpc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,6 @@ async def inner(request) -> aiohttp.web.Response:
def tx_endpoint(
push: bool = False,
merge_spends: bool = True,
# The purpose of this is in case endpoints need to raise based on certain non default values
requires_default_information: bool = False,
) -> Callable[[RpcEndpoint], RpcEndpoint]:
def _inner(func: RpcEndpoint) -> RpcEndpoint:
async def rpc_endpoint(self, request: Dict[str, Any], *args, **kwargs) -> Dict[str, Any]:
Expand Down Expand Up @@ -162,7 +160,6 @@ async def rpc_endpoint(self, request: Dict[str, Any], *args, **kwargs) -> Dict[s
request,
*args,
action_scope,
*([push] if requires_default_information else []),
tx_config=tx_config,
extra_conditions=extra_conditions,
**kwargs,
Expand Down
15 changes: 6 additions & 9 deletions chia/rpc/wallet_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,12 +694,11 @@ async def get_wallets(self, request: Dict[str, Any]) -> EndpointResult:
response["fingerprint"] = self.service.logged_in_fingerprint
return response

@tx_endpoint(push=True, requires_default_information=True)
@tx_endpoint(push=True)
async def create_new_wallet(
self,
request: Dict[str, Any],
action_scope: WalletActionScope,
push: bool = True,
tx_config: TXConfig = DEFAULT_TX_CONFIG,
extra_conditions: Tuple[Condition, ...] = tuple(),
) -> EndpointResult:
Expand All @@ -715,7 +714,7 @@ async def create_new_wallet(
name = request.get("name", None)
if request["mode"] == "new":
if request.get("test", False):
if not push:
if not action_scope.config.push:
raise ValueError("Test CAT minting must be pushed automatically") # pragma: no cover
async with self.service.wallet_state_manager.lock:
cat_wallet = await CATWallet.create_new_cat_wallet(
Expand Down Expand Up @@ -1825,16 +1824,15 @@ async def cat_asset_id_to_name(self, request: Dict[str, Any]) -> EndpointResult:
else:
return {"wallet_id": wallet.id(), "name": (wallet.get_name())}

@tx_endpoint(push=False, requires_default_information=True)
@tx_endpoint(push=False)
async def create_offer_for_ids(
self,
request: Dict[str, Any],
action_scope: WalletActionScope,
push: bool = False,
tx_config: TXConfig = DEFAULT_TX_CONFIG,
extra_conditions: Tuple[Condition, ...] = tuple(),
) -> EndpointResult:
if push:
if action_scope.config.push:
raise ValueError("Cannot push an incomplete spend") # pragma: no cover

offer: Dict[str, int] = request["offer"]
Expand Down Expand Up @@ -3617,16 +3615,15 @@ async def nft_calculate_royalties(self, request: Dict[str, Any]) -> EndpointResu
{asset["asset"]: uint64(asset["amount"]) for asset in request.get("fungible_assets", [])},
)

@tx_endpoint(push=False, requires_default_information=True)
@tx_endpoint(push=False)
async def nft_mint_bulk(
self,
request: Dict[str, Any],
action_scope: WalletActionScope,
push: bool = False,
tx_config: TXConfig = DEFAULT_TX_CONFIG,
extra_conditions: Tuple[Condition, ...] = tuple(),
) -> EndpointResult:
if push:
if action_scope.config.push:
raise ValueError("Automatic pushing of nft minting transactions not yet available") # pragma: no cover
if await self.service.wallet_state_manager.synced() is False:
raise ValueError("Wallet needs to be fully synced.")
Expand Down
15 changes: 12 additions & 3 deletions chia/util/action_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,12 @@ def from_bytes(cls: Type[_T_SideEffects], blob: bytes) -> _T_SideEffects: ...


_T_SideEffects = TypeVar("_T_SideEffects", bound=SideEffects)
_T_Config = TypeVar("_T_Config")


@final
@dataclass
class ActionScope(Generic[_T_SideEffects]):
class ActionScope(Generic[_T_SideEffects, _T_Config]):
"""
The idea of an "action" is to map a single client input to many potentially distributed functions and side
effects. The action holds on to a temporary state that the many callers modify at will but only one at a time.
Expand All @@ -100,6 +101,7 @@ class ActionScope(Generic[_T_SideEffects]):

_resource_manager: ResourceManager
_side_effects_format: Type[_T_SideEffects]
_config: _T_Config # An object not intended to be mutated during the lifetime of the scope
_callback: Optional[Callable[[StateInterface[_T_SideEffects]], Awaitable[None]]] = None
_final_side_effects: Optional[_T_SideEffects] = field(init=False, default=None)

Expand All @@ -113,15 +115,22 @@ def side_effects(self) -> _T_SideEffects:

return self._final_side_effects

@property
def config(self) -> _T_Config:
return self._config

@classmethod
@contextlib.asynccontextmanager
async def new_scope(
cls,
side_effects_format: Type[_T_SideEffects],
# I want a default here in case a use case doesn't want to take advantage of the config but no default seems to
# satisfy the type hint _T_Config so we'll just ignore this.
config: _T_Config = object(), # type: ignore[assignment]
resource_manager_backend: Type[ResourceManager] = SQLiteResourceManager,
) -> AsyncIterator[ActionScope[_T_SideEffects]]:
) -> AsyncIterator[ActionScope[_T_SideEffects, _T_Config]]:
async with resource_manager_backend.managed(side_effects_format()) as resource_manager:
self = cls(_resource_manager=resource_manager, _side_effects_format=side_effects_format)
self = cls(_resource_manager=resource_manager, _side_effects_format=side_effects_format, _config=config)

yield self

Expand Down
18 changes: 15 additions & 3 deletions chia/wallet/wallet_action_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import contextlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, AsyncIterator, List, Optional, cast
from typing import TYPE_CHECKING, AsyncIterator, List, Optional, cast, final

from chia.types.spend_bundle import SpendBundle
from chia.util.action_scope import ActionScope
Expand Down Expand Up @@ -65,7 +65,17 @@ def from_bytes(cls, blob: bytes) -> WalletSideEffects:
return instance


WalletActionScope = ActionScope[WalletSideEffects]
@final
@dataclass(frozen=True)
class WalletActionConfig:
push: bool
merge_spends: bool
sign: Optional[bool]
additional_signing_responses: List[SigningResponse]
extra_spends: List[SpendBundle]


WalletActionScope = ActionScope[WalletSideEffects, WalletActionConfig]


@contextlib.asynccontextmanager
Expand All @@ -77,7 +87,9 @@ async def new_wallet_action_scope(
additional_signing_responses: List[SigningResponse] = [],
extra_spends: List[SpendBundle] = [],
) -> AsyncIterator[WalletActionScope]:
async with ActionScope.new_scope(WalletSideEffects) as self:
async with ActionScope.new_scope(
WalletSideEffects, WalletActionConfig(push, merge_spends, sign, additional_signing_responses, extra_spends)
) as self:
self = cast(WalletActionScope, self)
async with self.use() as interface:
interface.side_effects.signing_responses = additional_signing_responses.copy()
Expand Down

0 comments on commit 2c2b0f4

Please sign in to comment.