Skip to content

Commit

Permalink
v2.15.4
Browse files Browse the repository at this point in the history
v2.15.4
  • Loading branch information
bkb2135 authored Dec 29, 2024
2 parents 0355e52 + da461aa commit 2b9a7ce
Show file tree
Hide file tree
Showing 11 changed files with 138 additions and 62 deletions.
104 changes: 71 additions & 33 deletions neurons/validator.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
# ruff: noqa: E402
from shared import settings

settings.shared_settings = settings.SharedSettings.load(mode="validator")
shared_settings = settings.shared_settings

import asyncio
import multiprocessing as mp
import sys
import time

import loguru
import torch
import wandb

# ruff: noqa: E402
from shared import settings

shared_settings = settings.shared_settings
settings.shared_settings = settings.SharedSettings.load(mode="validator")


from prompting.api.api import start_scoring_api
from prompting.llms.model_manager import model_scheduler
from prompting.llms.utils import GPUInfo
from prompting.miner_availability.miner_availability import availability_checking_loop
from prompting.rewards.scoring import task_scorer
from prompting.tasks.task_creation import task_loop
from prompting.tasks.task_sending import task_sender
from prompting.weight_setting.weight_setter import weight_setter
from shared.profiling import profiler

# Add a handler to write logs to a file
loguru.logger.add("logfile.log", rotation="1000 MB", retention="10 days", level="DEBUG")
Expand All @@ -32,8 +27,34 @@

def create_loop_process(task_queue, scoring_queue, reward_events):
async def spawn_loops(task_queue, scoring_queue, reward_events):
# ruff: noqa: E402
wandb.setup()
from shared import settings

settings.shared_settings = settings.SharedSettings.load(mode="validator")

from prompting.llms.model_manager import model_scheduler
from prompting.miner_availability.miner_availability import availability_checking_loop
from prompting.rewards.scoring import task_scorer
from prompting.tasks.task_creation import task_loop
from prompting.tasks.task_sending import task_sender
from prompting.weight_setting.weight_setter import weight_setter
from shared.profiling import profiler

logger.info("Starting Profiler...")
asyncio.create_task(profiler.print_stats(), name="Profiler"),

# -------- Duplicate of create_task_loop ----------
logger.info("Starting AvailabilityCheckingLoop...")
asyncio.create_task(availability_checking_loop.start())

logger.info("Starting TaskSender...")
asyncio.create_task(task_sender.start(task_queue, scoring_queue))

logger.info("Starting TaskLoop...")
asyncio.create_task(task_loop.start(task_queue, scoring_queue))
# -------------------------------------------------

logger.info("Starting ModelScheduler...")
asyncio.create_task(model_scheduler.start(scoring_queue), name="ModelScheduler"),
logger.info("Starting TaskScorer...")
Expand Down Expand Up @@ -62,6 +83,8 @@ async def spawn_loops(task_queue, scoring_queue, reward_events):

def start_api():
async def start():
from prompting.api.api import start_scoring_api # noqa: F401

await start_scoring_api()
while True:
await asyncio.sleep(10)
Expand All @@ -70,21 +93,21 @@ async def start():
asyncio.run(start())


def create_task_loop(task_queue, scoring_queue):
async def start(task_queue, scoring_queue):
logger.info("Starting AvailabilityCheckingLoop...")
asyncio.create_task(availability_checking_loop.start())
# def create_task_loop(task_queue, scoring_queue):
# async def start(task_queue, scoring_queue):
# logger.info("Starting AvailabilityCheckingLoop...")
# asyncio.create_task(availability_checking_loop.start())

logger.info("Starting TaskSender...")
asyncio.create_task(task_sender.start(task_queue, scoring_queue))
# logger.info("Starting TaskSender...")
# asyncio.create_task(task_sender.start(task_queue, scoring_queue))

logger.info("Starting TaskLoop...")
asyncio.create_task(task_loop.start(task_queue, scoring_queue))
while True:
await asyncio.sleep(10)
logger.debug("Running task loop...")
# logger.info("Starting TaskLoop...")
# asyncio.create_task(task_loop.start(task_queue, scoring_queue))
# while True:
# await asyncio.sleep(10)
# logger.debug("Running task loop...")

asyncio.run(start(task_queue, scoring_queue))
# asyncio.run(start(task_queue, scoring_queue))


async def main():
Expand All @@ -109,23 +132,38 @@ async def main():
loop_process = mp.Process(
target=create_loop_process, args=(task_queue, scoring_queue, reward_events), name="LoopProcess"
)
task_loop_process = mp.Process(
target=create_task_loop, args=(task_queue, scoring_queue), name="TaskLoopProcess"
)
# task_loop_process = mp.Process(
# target=create_task_loop, args=(task_queue, scoring_queue), name="TaskLoopProcess"
# )
loop_process.start()
task_loop_process.start()
# task_loop_process.start()
processes.append(loop_process)
processes.append(task_loop_process)
# processes.append(task_loop_process)
GPUInfo.log_gpu_info()

step = 0
while True:
await asyncio.sleep(10)
logger.debug("Running...")
await asyncio.sleep(30)
if (
shared_settings.SUBTENSOR.get_current_block()
- shared_settings.METAGRAPH.last_update[shared_settings.UID]
> 500
and step > 120
):
logger.warning(
f"UPDATES HAVE STALED FOR {shared_settings.SUBTENSOR.get_current_block() - shared_settings.METAGRAPH.last_update[shared_settings.UID]} BLOCKS AND {step} STEPS"
)
logger.warning(
f"STALED: {shared_settings.SUBTENSOR.get_current_block()}, {shared_settings.METAGRAPH.block}"
)
sys.exit(1)
step += 1

except Exception as e:
logger.error(f"Main loop error: {e}")
raise
finally:
wandb.teardown()
# Clean up processes
for process in processes:
if process.is_alive():
Expand Down
12 changes: 1 addition & 11 deletions prompting/rewards/scoring.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import threading
from dataclasses import dataclass

from loguru import logger
from pydantic import ConfigDict

from prompting.llms.model_manager import model_manager, model_scheduler
from prompting.rewards.scoring_config import ScoringConfig
from prompting.tasks.base_task import BaseTextTask
from prompting.tasks.task_registry import TaskRegistry
from shared.base import DatasetEntry
Expand All @@ -14,16 +14,6 @@
from shared.loop_runner import AsyncLoopRunner


@dataclass
class ScoringConfig:
task: BaseTextTask
response: DendriteResponseEvent
dataset_entry: DatasetEntry
block: int
step: int
task_id: str


class TaskScorer(AsyncLoopRunner):
"""The scoring manager maintains a queue of tasks & responses to score and then runs a scoring loop in a background thread.
This scoring loop will score the responses once the LLM needed is loaded in the model_manager and log the rewards.
Expand Down
15 changes: 15 additions & 0 deletions prompting/rewards/scoring_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass

from prompting.tasks.base_task import BaseTextTask
from shared.base import DatasetEntry
from shared.dendrite import DendriteResponseEvent


@dataclass
class ScoringConfig:
    """One queued unit of scoring work: a task, the miners' responses to it,
    and the bookkeeping context needed to score and log the result later.

    Moved into its own module so task_sending can import it without pulling
    in the full scoring machinery (avoids a circular import with scoring.py).
    """

    task: BaseTextTask  # The task whose responses are to be scored.
    response: DendriteResponseEvent  # Aggregated miner responses for this task.
    dataset_entry: DatasetEntry  # Dataset entry the task was generated from.
    block: int  # NOTE(review): presumably the chain block when the task was sent — confirm against task_sending.
    step: int  # NOTE(review): presumably the validator step counter at send time — confirm against task_sending.
    task_id: str  # Unique identifier linking the responses back to the task.
5 changes: 3 additions & 2 deletions prompting/tasks/task_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from prompting.miner_availability.miner_availability import miner_availabilities
from prompting.tasks.task_registry import TaskRegistry
from shared.logging import ErrorLoggingEvent, ValidatorLoggingEvent

# from shared.logging import ErrorLoggingEvent, ValidatorLoggingEvent
from shared.loop_runner import AsyncLoopRunner
from shared.settings import shared_settings

Expand All @@ -26,7 +27,7 @@ async def start(self, task_queue, scoring_queue):
self.scoring_queue = scoring_queue
await super().start()

async def run_step(self) -> ValidatorLoggingEvent | ErrorLoggingEvent | None:
async def run_step(self):
if len(self.task_queue) > shared_settings.TASK_QUEUE_LENGTH_THRESHOLD:
logger.debug("Task queue is full. Skipping task generation.")
return None
Expand Down
5 changes: 4 additions & 1 deletion prompting/tasks/task_sending.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from prompting.miner_availability.miner_availability import miner_availabilities

# from prompting.rewards.scoring import task_scorer
from prompting.rewards.scoring import ScoringConfig
from prompting.rewards.scoring_config import ScoringConfig
from prompting.tasks.base_task import BaseTextTask
from prompting.tasks.inference import InferenceTask
from shared.dendrite import DendriteResponseEvent, SynapseStreamResult
Expand Down Expand Up @@ -66,6 +66,9 @@ async def collect_responses(task: BaseTextTask) -> DendriteResponseEvent | None:
response_event = DendriteResponseEvent(
stream_results=stream_results,
uids=uids,
axons=[
shared_settings.METAGRAPH.axons[x].ip + ":" + str(shared_settings.METAGRAPH.axons[x].port) for x in uids
],
timeout=(
shared_settings.INFERENCE_TIMEOUT if isinstance(task, InferenceTask) else shared_settings.NEURON_TIMEOUT
),
Expand Down
8 changes: 3 additions & 5 deletions prompting/weight_setting/weight_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class WeightSetter(AsyncLoopRunner):
"""The weight setter looks at RewardEvents in the reward_events queue and sets the weights of the miners accordingly."""

sync: bool = True
interval: int = 60 * 22 # set rewards every 20 minutes
interval: int = 60 * 25 # set rewards every 25 minutes
reward_events: list[list[WeightedRewardEvent]] | None = None
subtensor: bt.Subtensor | None = None
metagraph: bt.Metagraph | None = None
Expand Down Expand Up @@ -240,10 +240,8 @@ async def run_step(self):
set_weights(
final_rewards, step=self.step, subtensor=shared_settings.SUBTENSOR, metagraph=shared_settings.METAGRAPH
)
self.reward_events = [] # empty reward events queue
logger.debug(f"Pre-Refresh Metagraph Block: {shared_settings.METAGRAPH.block}")
shared_settings.refresh_metagraph()
logger.debug(f"Post-Refresh Metagraph Block: {shared_settings.METAGRAPH.block}")
# TODO: empty rewards queue only on weight setting success
self.reward_events[:] = [] # empty reward events queue
await asyncio.sleep(0.01)
return final_rewards

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "prompting"
version = "2.15.3"
version = "2.15.4"
description = "Subnetwork 1 runs on Bittensor and is maintained by Macrocosmos. It's an effort to create decentralised AI"
authors = ["Kalei Brady, Dmytro Bobrenko, Felix Quinque, Steffen Cruz, Richard Wardle"]
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions shared/dendrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def model_dump(self):

class DendriteResponseEvent(BaseModel):
uids: np.ndarray | list[float]
axons: list[str]
timeout: float
stream_results: list[SynapseStreamResult]
completions: list[str] = []
Expand Down
2 changes: 1 addition & 1 deletion shared/epistula.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def query_miners(uids, body: dict[str, Any]):
exceptions = [resp for resp in responses if isinstance(resp, Exception)]
if exceptions:
for exc in exceptions:
logger.error(f"Error in make_openai_query: {exc}")
logger.debug(f"Error in make_openai_query: {exc}")

# 'responses' is a list of SynapseStreamResult objects
results = []
Expand Down
31 changes: 31 additions & 0 deletions shared/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,34 @@ def serialize_exception_to_string(e):
return serialized_str
else:
return e


def cached_property_with_expiration(expiration_seconds=1200):
    """Decorator factory: a read-only property cached per instance with a TTL.

    The wrapped zero-argument method is evaluated on first access and its
    result is stored on the instance (under ``_cached_<name>``) together with
    a timestamp. Subsequent accesses return the stored value until
    ``expiration_seconds`` have elapsed, after which the value is recomputed
    and re-cached.

    Args:
        expiration_seconds: How long (in seconds) a cached value stays valid.

    Returns:
        A decorator turning a method into an expiring cached property.

    Note:
        Not thread-safe: concurrent first accesses may each invoke the
        wrapped method; the last writer wins.
    """

    def decorator(func):
        attr_name = f"_cached_{func.__name__}"

        def wrapper(self):
            # time.monotonic() cannot jump backwards/forwards with wall-clock
            # adjustments (NTP, DST), so the elapsed-time check is reliable;
            # the timestamps are only ever compared against the same clock.
            now = time.monotonic()

            # Return the cached value if present and still fresh.
            if hasattr(self, attr_name):
                cached_value, timestamp = getattr(self, attr_name)
                if now - timestamp < expiration_seconds:
                    return cached_value

            # Cache miss or expired entry: recompute and store with a fresh timestamp.
            value = func(self)
            setattr(self, attr_name, (value, now))
            return value

        # Carry the wrapped method's docstring onto the property for help()/IDEs.
        return property(wrapper, doc=func.__doc__)

    return decorator
15 changes: 7 additions & 8 deletions shared/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings

from shared.misc import cached_property_with_expiration

logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
Expand Down Expand Up @@ -51,7 +53,7 @@ class SharedSettings(BaseSettings):
NEURON_GPUS: int = Field(1, env="NEURON_GPUS")

# Logging.
LOGGING_DONT_SAVE_EVENTS: bool = Field(False, env="LOGGING_DONT_SAVE_EVENTS")
LOGGING_DONT_SAVE_EVENTS: bool = Field(True, env="LOGGING_DONT_SAVE_EVENTS")
LOG_WEIGHTS: bool = Field(False, env="LOG_WEIGHTS")

# Neuron parameters.
Expand Down Expand Up @@ -243,17 +245,14 @@ def SUBTENSOR(self) -> bt.subtensor:
logger.info(f"Instantiating subtensor with network: {subtensor_network}")
return bt.subtensor(network=subtensor_network)

@cached_property
@cached_property_with_expiration(expiration_seconds=1200)
def METAGRAPH(self) -> bt.metagraph:
    """Metagraph for ``self.NETUID``, cached per instance and automatically
    recomputed once the 1200 s (20 min) expiration elapses — this replaces
    the previous manual ``refresh_metagraph()`` mechanism."""
    logger.info(f"Instantiating metagraph with NETUID: {self.NETUID}")
    return self.SUBTENSOR.metagraph(netuid=self.NETUID)

def refresh_metagraph(self) -> bt.metagraph:
logger.debug("Refreshing metagraph")
if "METAGRAPH" in self.__dict__:
del self.__dict__["METAGRAPH"]
logger.debug("Deleting cached metagraph")
return self.METAGRAPH
@cached_property
def UID(self) -> int:
    """Index of this validator's hotkey in the metagraph hotkey list.

    NOTE(review): cached once for the process lifetime, while METAGRAPH now
    expires and refreshes every 20 minutes — if the hotkey's position in the
    refreshed metagraph ever changed, this cached UID would go stale.
    Confirm that the index is stable across refreshes.
    """
    return self.METAGRAPH.hotkeys.index(self.WALLET.hotkey.ss58_address)

@cached_property
def DENDRITE(self) -> bt.dendrite:
Expand Down

0 comments on commit 2b9a7ce

Please sign in to comment.