Skip to content

Commit

Permalink
[CHIA-1024] Fix wallet observer mode log in while non-observer keys a…
Browse files Browse the repository at this point in the history
…re present (#18361)

Fix wallet observer mode log in while non-observer keys are present
  • Loading branch information
Quexington authored Jul 29, 2024
1 parent 54ff8ba commit 6a5d84d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 10 deletions.
31 changes: 31 additions & 0 deletions chia/_tests/wallet/test_wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,3 +720,34 @@ async def validate_received_state_from_peer(*args: Any) -> bool:

with pytest.raises(PeerRequestException):
await wallet_node.get_coin_state([], wallet_node.get_full_node_peer())


@pytest.mark.anyio
@pytest.mark.standard_block_tools
async def test_start_with_multiple_key_types(
simulator_and_wallet: OldSimulatorsAndWallets, self_hostname: str, default_400_blocks: List[FullBlock]
) -> None:
[full_node_api], [(wallet_node, wallet_server)], bt = simulator_and_wallet

async def restart_with_fingerprint(fingerprint: Optional[int]) -> None:
wallet_node._close()
await wallet_node._await_closed(shutting_down=False)
await wallet_node._start_with_fingerprint(fingerprint=fingerprint)

initial_sk = wallet_node.wallet_state_manager.private_key

pk: G1Element = await wallet_node.keychain_proxy.add_key(
"c00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
None,
private=False,
)
fingerprint_pk: int = pk.get_fingerprint()

await restart_with_fingerprint(fingerprint_pk)
assert wallet_node.wallet_state_manager.private_key is None
assert wallet_node.wallet_state_manager.root_pubkey == G1Element()

await wallet_node.keychain_proxy.delete_key_by_fingerprint(fingerprint_pk)

await restart_with_fingerprint(fingerprint_pk)
assert wallet_node.wallet_state_manager.private_key == initial_sk
7 changes: 6 additions & 1 deletion chia/daemon/keychain_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,12 @@ async def get_key_for_fingerprint(
break
else:
raise KeychainKeyNotFound(fingerprint)
key = selected_key.private_key if private else selected_key.public_key
if private and selected_key.secrets is not None:
key = selected_key.private_key
elif not private:
key = selected_key.public_key
else:
return None
else:
response, success = await self.get_response_for_request(
"get_key_for_fingerprint", {"fingerprint": fingerprint, "private": private}
Expand Down
24 changes: 15 additions & 9 deletions chia/wallet/wallet_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,17 @@ async def get_key_for_fingerprint(

return key

async def get_key(self, fingerprint: Optional[int], private: bool = True) -> Optional[Union[PrivateKey, G1Element]]:
async def get_key(
self, fingerprint: Optional[int], private: bool = True, find_a_default: bool = True
) -> Optional[Union[PrivateKey, G1Element]]:
"""
Attempt to get the private key for the given fingerprint. If the fingerprint is None,
get_key_for_fingerprint() will return the first private key. Similarly, if a key isn't
returned for the provided fingerprint, the first key will be returned.
"""
key: Optional[Union[PrivateKey, G1Element]] = await self.get_key_for_fingerprint(fingerprint, private=private)

if key is None and fingerprint is not None:
if key is None and fingerprint is not None and find_a_default:
key = await self.get_key_for_fingerprint(None, private=private)
if key is not None:
if isinstance(key, PrivateKey):
Expand Down Expand Up @@ -413,15 +415,19 @@ async def _start_with_fingerprint(
multiprocessing_context = multiprocessing.get_context(method=multiprocessing_start_method)
self._weight_proof_handler = WalletWeightProofHandler(self.constants, multiprocessing_context)
self.synced_peers = set()
private_key = await self.get_key(fingerprint, private=True)
public_key = None
private_key = await self.get_key(fingerprint, private=True, find_a_default=False)
if private_key is None:
public_key = await self.get_key(fingerprint, private=False)
else:
assert isinstance(private_key, PrivateKey)
public_key = private_key.get_g1()
public_key = await self.get_key(fingerprint, private=False, find_a_default=False)

if public_key is None:
self.log_out()
return False
private_key = await self.get_key(None, private=True, find_a_default=True)
if private_key is not None:
assert isinstance(private_key, PrivateKey)
public_key = private_key.get_g1()
else:
self.log_out()
return False
assert isinstance(public_key, G1Element)
# override with private key fetched in case it's different from what was passed
if fingerprint is None:
Expand Down

0 comments on commit 6a5d84d

Please sign in to comment.