Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/batch set commit weights #2485

Open
wants to merge 19 commits into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions bittensor/core/async_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
)
from bittensor.core.extrinsics.async_transfer import transfer_extrinsic
from bittensor.core.extrinsics.async_weights import (
batch_commit_weights_extrinsic,
batch_set_weights_extrinsic,
commit_weights_extrinsic,
set_weights_extrinsic,
)
Expand Down Expand Up @@ -1611,6 +1613,86 @@ async def set_weights(

return success, message

async def batch_set_weights(
self,
wallet: "Wallet",
netuids: list[int],
uidss: list[Union[NDArray[np.int64], "torch.LongTensor", list]],
weightss: list[Union[NDArray[np.float32], "torch.FloatTensor", list]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typos? Also is there a reason we accept torch.LongTensor objects for set_weights but not for commit_weights?

Copy link
Collaborator Author

@camfairchild camfairchild Nov 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typos?

Supposed to be a list of uid lists. Maybe uids_list?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also is there a reason we accept torch.LongTensor objects for set_weights but not for commit_weights?

I have no idea. This is copy-pasted from the previous impl

version_keys: list[int] = [],
wait_for_inclusion: bool = False,
wait_for_finalization: bool = False,
max_retries: int = 5,
):
"""
Batch set weights for multiple subnets.

Args:
wallet (bittensor_wallet.Wallet): The wallet associated with the neuron setting the weights.
netuids (list[int]): The list of subnet uids.
uidss (list[Union[NDArray[np.int64], torch.LongTensor, list]]): The list of neuron UIDs that the weights are being set for.
weightss (list[Union[NDArray[np.float32], torch.FloatTensor, list]]): The corresponding weights to be set for each UID.
version_keys (list[int]): Version keys for compatibility with the network. Default is ``int representation of Bittensor version.``.
wait_for_inclusion (bool): Waits for the transaction to be included in a block. Default is ``False``.
wait_for_finalization (bool): Waits for the transaction to be finalized on the blockchain. Default is ``False``.
max_retries (int): The number of maximum attempts to set weights. Default is ``5``.

Returns:
tuple[bool, str]: ``True`` if the setting of weights is successful, False otherwise. And `msg`, a string value describing the success or potential error.

This function is crucial in shaping the network's collective intelligence, where each neuron's learning and contribution are influenced by the weights it sets towards others【81†source】.
"""
netuids_to_set = []
uidss_to_set = []
weightss_to_set = []
version_keys_to_set = []

if len(version_keys) == 0:
version_keys = [version_as_int] * len(netuids)

for i, netuid in enumerate(netuids):
uid = await self.get_uid_for_hotkey_on_subnet(
wallet.hotkey.ss58_address, netuid
)
retries = 0
success = False
message = "No attempt made. Perhaps it is too soon to set weights!"

if await self.blocks_since_last_update(
netuid, uid
) <= await self.weights_rate_limit(netuid):
logging.info(
f"Skipping subnet #{netuid} as it has not reached the weights rate limit."
)
continue

netuids_to_set.append(netuid)
uidss_to_set.append(uidss[i])
weightss_to_set.append(weightss[i])
version_keys_to_set.append(version_keys[i])

while retries < max_retries:
try:
logging.info(
f"Setting batch of weights for subnets #[blue]{netuids_to_set}[/blue]. Attempt [blue]{retries + 1} of {max_retries}[/blue]."
)
success, message = await batch_set_weights_extrinsic(
subtensor=self,
wallet=wallet,
netuids=netuids_to_set,
uidss=uidss_to_set,
weightss=weightss_to_set,
version_keys=version_keys_to_set,
wait_for_inclusion=wait_for_inclusion,
wait_for_finalization=wait_for_finalization,
)
except Exception as e:
logging.error(f"Error setting batch of weights: {e}")
finally:
retries += 1

return success, message

async def root_set_weights(
self,
wallet: "Wallet",
Expand Down Expand Up @@ -1710,3 +1792,80 @@ async def commit_weights(
retries += 1

return success, message

async def batch_commit_weights(
self,
wallet: "Wallet",
netuids: list[int],
salts: list[list[int]],
uids: list[Union[NDArray[np.int64], list]],
weights: list[Union[NDArray[np.int64], list]],
version_keys: list[int] = [],
camfairchild marked this conversation as resolved.
Show resolved Hide resolved
wait_for_inclusion: bool = False,
wait_for_finalization: bool = False,
max_retries: int = 5,
) -> tuple[bool, str]:
"""
Commits a batch of hashes of weights to the Bittensor blockchain using the provided wallet.
This allows for multiple subnets to be committed to at once in a single extrinsic.

Args:
wallet (bittensor_wallet.Wallet): The wallet associated with the neuron committing the weights.
netuids (list[int]): The list of subnet uids.
salts (list[list[int]]): The list of salts to generate weight hashes.
uids (list[np.ndarray]): The list of NumPy arrays of neuron UIDs for which weights are being committed.
weights (list[np.ndarray]): The list of NumPy arrays of weight values corresponding to each UID.
version_keys (list[int]): The list of version keys for compatibility with the network. Default is ``int representation of Bittensor version.``.
wait_for_inclusion (bool): Waits for the transaction to be included in a block. Default is ``False``.
wait_for_finalization (bool): Waits for the transaction to be finalized on the blockchain. Default is ``False``.
max_retries (int): The number of maximum attempts to commit weights. Default is ``5``.

Returns:
tuple[bool, str]: ``True`` if the weight commitment is successful, False otherwise. And `msg`, a string value describing the success or potential error.

This function allows commitments to be made for multiple subnets at once.
"""
retries = 0
success = False
message = "No attempt made. Perhaps it is too soon to commit weights!"

logging.info(
f"Committing a batch of weights with params: netuids={netuids}, salts={salts}, uids={uids}, weights={weights}, version_keys={version_keys}"
)

if len(version_keys) == 0:
version_keys = [version_as_int] * len(netuids)

# Generate the hash of the weights
commit_hashes = [
generate_weight_hash(
address=wallet.hotkey.ss58_address,
netuid=netuid,
uids=list(uids),
values=list(weights),
Comment on lines +1833 to +1834
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if this is what you want. Calling list on a nparray will create a list of np ints, rather than a list of ints. Usually you'd do this with array.tolist().

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is taken from commit_weights

salt=salt,
version_key=version_key,
)
for netuid, salt, uids, weights, version_key in zip(
netuids, salts, uids, weights, version_keys
)
]

while retries < max_retries:
try:
success, message = await batch_commit_weights_extrinsic(
subtensor=self,
wallet=wallet,
netuids=netuids,
commit_hashes=commit_hashes,
wait_for_inclusion=wait_for_inclusion,
wait_for_finalization=wait_for_finalization,
)
if success:
break
except Exception as e:
logging.error(f"Error batch committing weights: {e}")
finally:
retries += 1

return success, message
Loading
Loading