Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CHIA-783: Stop auto-subscribing to local stores #18166

Merged
merged 22 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 101 additions & 1 deletion chia/_tests/core/data_layer/test_data_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import copy
import enum
import json
import logging
import os
import random
import sqlite3
Expand All @@ -14,7 +15,7 @@
from dataclasses import dataclass
from enum import IntEnum
from pathlib import Path
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, cast
from typing import Any, AsyncIterator, Dict, List, Optional, Set, Tuple, cast

import anyio
import pytest
Expand Down Expand Up @@ -2239,6 +2240,16 @@ async def test_maximum_full_file_count(
assert filename not in filenames


@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules")
@pytest.mark.anyio
async def test_unsubscribe_unknown(
bare_data_layer_api: DataLayerRpcApi,
seeded_random: random.Random,
) -> None:
with pytest.raises(RuntimeError, match="No subscription found for the given store_id."):
await bare_data_layer_api.unsubscribe(request={"id": bytes32.random(seeded_random).hex(), "retain": False})


@pytest.mark.parametrize("retain", [True, False])
@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules")
@pytest.mark.anyio
Expand Down Expand Up @@ -2266,6 +2277,8 @@ async def test_unsubscribe_removes_files(
store_id = bytes32.from_hexstr(res["id"])
await farm_block_check_singleton(data_layer, full_node_api, ph, store_id, wallet=wallet_rpc_api.service)

# subscribe to ourselves
await data_rpc_api.subscribe(request={"id": store_id.hex()})
update_count = 10
for batch_count in range(update_count):
key = batch_count.to_bytes(2, "big")
Expand Down Expand Up @@ -3712,3 +3725,90 @@ class ModifiedStatus(IntEnum):
await farm_block_with_spend(full_node_api, ph, update_tx_rec1, wallet_rpc_api)
keys = await data_rpc_api.get_keys({"id": store_id.hex()})
assert keys == {"keys": ["0x30303031", "0x30303030"]}


@pytest.mark.parametrize(argnames="auto_subscribe_to_local_stores", argvalues=[True, False])
altendky marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules")
@pytest.mark.anyio
async def test_auto_subscribe_to_local_stores(
self_hostname: str,
one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices,
tmp_path: Path,
monkeypatch: Any,
auto_subscribe_to_local_stores: bool,
) -> None:
wallet_rpc_api, full_node_api, wallet_rpc_port, ph, bt = await init_wallet_and_node(
self_hostname, one_wallet_and_one_simulator_services
)
manage_data_interval = 5
fake_store = bytes32([1] * 32)

async def mock_get_store_ids(self: Any) -> Set[bytes32]:
return {fake_store}

async def mock_dl_track_new(self: Any, request: Dict[str, Any]) -> Dict[str, Any]:
# ignore and just return empty response
return {}

with monkeypatch.context() as m:
m.setattr("chia.data_layer.data_store.DataStore.get_store_ids", mock_get_store_ids)
m.setattr("chia.rpc.wallet_rpc_client.WalletRpcClient.dl_track_new", mock_dl_track_new)

config = bt.config
config["data_layer"]["auto_subscribe_to_local_stores"] = auto_subscribe_to_local_stores
bt.change_config(new_config=config)

async with init_data_layer(
wallet_rpc_port=wallet_rpc_port,
bt=bt,
db_path=tmp_path,
manage_data_interval=manage_data_interval,
maximum_full_file_count=100,
) as data_layer:
data_rpc_api = DataLayerRpcApi(data_layer)

await asyncio.sleep(manage_data_interval)

response = await data_rpc_api.subscriptions(request={})

if auto_subscribe_to_local_stores:
assert fake_store.hex() in response["store_ids"]
else:
assert fake_store.hex() not in response["store_ids"]


@pytest.mark.limit_consensus_modes(reason="does not depend on consensus rules")
@pytest.mark.anyio
async def test_local_store_exception(
self_hostname: str,
one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices,
tmp_path: Path,
monkeypatch: Any,
caplog: pytest.LogCaptureFixture,
) -> None:
wallet_rpc_api, full_node_api, wallet_rpc_port, ph, bt = await init_wallet_and_node(
self_hostname, one_wallet_and_one_simulator_services
)
manage_data_interval = 5
fake_store = bytes32([1] * 32)

async def mock_get_store_ids(self: Any) -> Set[bytes32]:
return {fake_store}

with monkeypatch.context() as m, caplog.at_level(logging.INFO):
m.setattr("chia.data_layer.data_store.DataStore.get_store_ids", mock_get_store_ids)

config = bt.config
config["data_layer"]["auto_subscribe_to_local_stores"] = True
bt.change_config(new_config=config)

async with init_data_layer(
wallet_rpc_port=wallet_rpc_port,
bt=bt,
db_path=tmp_path,
manage_data_interval=manage_data_interval,
maximum_full_file_count=100,
):
await asyncio.sleep(manage_data_interval)

assert f"Can't subscribe to local store {fake_store.hex()}:" in caplog.text
42 changes: 35 additions & 7 deletions chia/data_layer/data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,10 @@ async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> None

async def unsubscribe(self, store_id: bytes32, retain_data: bool) -> None:
async with self.subscription_lock:
subscriptions = await self.data_store.get_subscriptions()
if store_id not in (subscription.store_id for subscription in subscriptions):
raise RuntimeError("No subscription found for the given store_id.")

# Unsubscribe is processed later, after all fetching of data is done, to avoid races.
self.unsubscribe_data_queue.append(UnsubscribeData(store_id, retain_data))

Expand Down Expand Up @@ -863,22 +867,46 @@ async def periodically_manage_data(self) -> None:
await asyncio.sleep(0.1)

while not self._shut_down:
# Add existing subscriptions
async with self.subscription_lock:
subscriptions = await self.data_store.get_subscriptions()

# Subscribe to all local store_ids that we can find on chain.
local_store_ids = await self.data_store.get_store_ids()
# pseudo-subscribe to all unsubscribed owned stores
# Need this to make sure we process updates and generate DAT files
try:
owned_stores = await self.get_owned_stores()
except ValueError:
# Sometimes the DL wallet isn't available, so we can't get the owned stores.
# We'll try again next time.
owned_stores = []
subscription_store_ids = {subscription.store_id for subscription in subscriptions}
for local_id in local_store_ids:
if local_id not in subscription_store_ids:
for record in owned_stores:
store_id = record.launcher_id
if store_id not in subscription_store_ids:
try:
subscription = await self.subscribe(local_id, [])
subscriptions.insert(0, subscription)
# subscription = await self.subscribe(store_id, [])
altendky marked this conversation as resolved.
Show resolved Hide resolved
subscriptions.insert(0, Subscription(store_id=store_id, servers_info=[]))
except Exception as e:
self.log.info(
f"Can't subscribe to locally stored {local_id}: {type(e)} {e} {traceback.format_exc()}"
f"Can't subscribe to owned store {store_id}: {type(e)} {e} {traceback.format_exc()}"
)

# Optionally
# Subscribe to all local non-owned store_ids that we can find on chain.
# This is the prior behavior where all local stores, both owned and not owned, are subscribed to.
if self.config.get("auto_subscribe_to_local_stores", False):
local_store_ids = await self.data_store.get_store_ids()
subscription_store_ids = {subscription.store_id for subscription in subscriptions}
for local_id in local_store_ids:
if local_id not in subscription_store_ids:
try:
subscription = await self.subscribe(local_id, [])
subscriptions.insert(0, subscription)
except Exception as e:
self.log.info(
f"Can't subscribe to local store {local_id}: {type(e)} {e} {traceback.format_exc()}"
)

work_queue: asyncio.Queue[Job[Subscription]] = asyncio.Queue()
async with QueuedAsyncPool.managed(
name="DataLayer subscription update pool",
Expand Down
Loading