From 253c09fb5af10df61b31e43f3ebb8875f89acee4 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Tue, 16 Apr 2024 19:05:10 +0000 Subject: [PATCH 01/21] adds aiohttp to branch --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7db7c1a2..f09823b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,5 @@ langchain==0.1.5 langchainhub==0.1.14 python-dotenv wikipedia_sections -vllm \ No newline at end of file +vllm +aiohttp \ No newline at end of file From c1640be80ae3d439b7044209bdc5870d5d4be3d8 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Tue, 16 Apr 2024 21:48:16 +0000 Subject: [PATCH 02/21] initial push on batch wiki dataset --- prompting/tools/datasets/base.py | 38 ++++- prompting/tools/datasets/batch_wiki.py | 110 +++++++++++++ prompting/tools/datasets/context.py | 6 + prompting/tools/datasets/wiki.py | 145 ++-------------- prompting/utils/__init__.py | 2 + prompting/utils/async_wiki.py | 220 +++++++++++++++++++++++++ prompting/utils/wiki.py | 126 ++++++++++++++ 7 files changed, 512 insertions(+), 135 deletions(-) create mode 100644 prompting/tools/datasets/batch_wiki.py create mode 100644 prompting/utils/async_wiki.py create mode 100644 prompting/utils/wiki.py diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index ab4b6d1c..37c2dcfd 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -22,7 +22,7 @@ import bittensor as bt from ..selector import Selector -from .context import Context +from .context import Context, BatchContext from prompting.utils.exceptions import MaxRetryError @@ -82,3 +82,39 @@ def next( "next_kwargs": kwargs, } return Context(**info) + + +class BatchDataset(ABC): + """Base class for datasets.""" + + max_tries: int = 10 + + @abstractmethod + async def random_batch(self, name): + ... + + async def next( + self, method: str = "random", selector: Selector = Selector(), **kwargs + ) -> BatchContext: + t0 = time.time() + + for tries in range(1, self.max_tries + 1): + if method == "random": + results = await self.random_batch(selector=selector, **kwargs) + stats = { + "creator": self.__class__.__name__, + "fetch_time": time.time() - t0, + "num_tries": tries, + "fetch_method": method, + "next_kwargs": kwargs, + } + + return BatchContext(results=results, stats=stats) + else: + raise ValueError(f"Unknown dataset get method {method!r}") + + # If no valid info is found after max_tries + raise MaxRetryError( + f"Could not find any samples which meet {self.__class__.__name__} requirements after {self.max_tries} tries." + ) + \ No newline at end of file diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py new file mode 100644 index 00000000..af5a50af --- /dev/null +++ b/prompting/tools/datasets/batch_wiki.py @@ -0,0 +1,110 @@ +import re +import sys +import random +import datetime +import bittensor as bt +import wikipedia as wiki +from prompting.utils import async_wiki as async_wiki_utils +from typing import Dict, Union, List, Tuple +from functools import lru_cache +from .base import Dataset, BatchDataset +from ..selector import Selector +from prompting.tools.datasets import Context + + + + +@lru_cache(maxsize=1000) +def _get_random_titles(pages=10, seed=42) -> List: + """Cached wikipedia random page. Approximately deterministic random titles. This is useful for testing. + NOTE: the actually cached result will change each session, but the result will be the same within a session. 
+ """ + return wiki.random(pages=pages) + + + + +class BatchWikiDataset(BatchDataset): + """Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections.""" + + EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links") + EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1") + + def __init__( + self, + batch_size: int = 16, + min_length_words: int = 50, + max_links: int = 10, + ): + """ + Args: + min_length_words (int, optional): Minimum section length. Defaults to 50. + max_links (int, optional): _description_. Defaults to 10. + """ + self.batch_size = batch_size + self.min_length_words = min_length_words + self.max_links = max_links + + async def get_multiple_pages( + self, + titles: List[str], + selector: Selector = None, + include: List = None, + exclude: List = None, + **kwargs, + ) -> Dict: + """Get a specified Wikipedia page and extract a section based on the selector. + + Args: + name (_type_): _description_ + pageid (_type_, optional): _description_. Defaults to None. + auto_suggest (bool, optional): _description_. Defaults to True. + redirect (bool, optional): _description_. Defaults to True. + selector (Selector, optional): _description_. Defaults to None. + include (List, optional): _description_. Defaults to None. + exclude (List, optional): _description_. Defaults to None. + + Returns: + Dict: _description_ + """ + pages = await async_wiki_utils.fetch_pages(titles, **kwargs) + + # Only return a sections with a minimum number of words + exclude = (exclude or []) + list(self.EXCLUDE_HEADERS) + + # TODO: FIX THE RETURN FOR PROCESS PAGES TO BE A MANAGABLE TYPE + sections = await async_wiki_utils.process_pages( + pages, + valid_header=lambda x: x not in exclude and (not include or x in include), + valid_content=lambda x: len(x.split()) >= self.min_length_words, + ) + + + key = header, section_title = selector(list(sections.keys())) + content = "\n".join(sections[key]) + section_length = len(content.split()) + return { + "title": name, # title of wiki article + "topic": header or section_title, # title of wiki section + "subtopic": section_title, + "content": content, + "internal_links": list(filter(lambda x: x not in exclude, page.sections)), + "external_links": most_relevant_links(page, num_links=self.max_links), + "tags": filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES), + "source": "Wikipedia", + "extra": { + "url": page.url, + "page_length": len(page.content.split()), + "section_length": section_length, + }, + } + + async def random_batch(self, seed=None, selector: Selector = None, **kwargs) -> List[Context]: + """Get random batch of wikipedia pages.""" + random_titles = ( + wiki.random(pages=self.batch_size) + if seed is None + else _get_random_titles(pages=self.batch_size, seed=seed) + ) + + return await self.get_multiple_pages(random_titles, selector=selector) \ No newline at end of file diff --git a/prompting/tools/datasets/context.py b/prompting/tools/datasets/context.py index a9918c72..20962ce0 100644 --- a/prompting/tools/datasets/context.py +++ b/prompting/tools/datasets/context.py @@ -15,3 +15,9 @@ class Context: tags: List[str] = None extra: dict = None # additional non-essential information stats: dict = None # retrieval stats such as fetch time, number of tries, etc. 
+ + +@dataclass +class BatchContext: + results: List[Context] + stats: dict = None \ No newline at end of file diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index dd516fd6..69c6ffd8 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -17,138 +17,15 @@ # DEALINGS IN THE SOFTWARE. import re -import sys import random import datetime -import bittensor as bt import wikipedia as wiki -from typing import Dict, Union, List, Tuple - -from functools import lru_cache +from typing import Dict, List +from prompting.utils import wiki as wiki_utils from .base import Dataset from ..selector import Selector -# speed up page loading -@lru_cache(maxsize=1000) -def _get_page( - title, pageid=None, auto_suggest=False, redirect=True, seed=None -) -> wiki.WikipediaPage: - """Cached Wikipedia page loading.""" - try: - page = wiki.page( - title=title, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect - ) - # create sections manually if not found - if not page.sections: - page._sections = [ - line.strip("= ") - for line in page.content.splitlines() - if re.search(r"=+\s+.*\s+=+", line) - ] - return page - - except wiki.DisambiguationError as e: - bt.logging.debug(f"{e.__class__.__name__} loading page {title!r}: {e}") - # exc info contains a tuple of (requested_title: str, possible_matches: List[str]) - pages = sys.exc_info()[1].args[1] - if not type(pages) == list: - return None - title = random.Random(seed).choice(pages) - return _get_page(title, auto_suggest=auto_suggest, redirect=redirect) - - except wiki.PageError as e: - bt.logging.warning(f"{e.__class__.__name__} loading page {title!r}: {e}") - if not auto_suggest: - return _get_page(title, auto_suggest=True, redirect=redirect) - return None - - -@lru_cache(maxsize=1000) -def _get_random_titles(pages=10, seed=42) -> List: - """Cached wikipedia random page. Approximately deterministic random titles. This is useful for testing. - NOTE: the actually cached result will change each session, but the result will be the same within a session. - """ - return wiki.random(pages=pages) - - -@lru_cache(maxsize=1000) -def _wiki_search(name, results) -> List: - """Cached Wikipedia search.""" - return wiki.search(name, results=results) - - -def process_page( - page, valid_header: callable = None, valid_content: callable = None -) -> Dict: - """Process a Wikipedia page and return a dictionary of sections with their content. - - Args: - page: wikipedia.WikipediaPage - valid_header: callable to determine if a section header is valid - valid_content: callable to determine if a section content is valid - Returns: - dict: dictionary of sections and their content. 
Note that keys are tuples (header, section_title) - """ - header = "" - sections = {} - - for section_title in page.sections: - content = page.section(section_title) - if not content: - header = section_title - continue - - # Filter out sections that don't match the headers and/or are not valid - if (valid_header and not valid_header(header)) or ( - valid_content and not valid_content(content) - ): - continue - - key = (header, section_title) - sections[key] = content.splitlines() - - if not sections: - bt.logging.debug(f"No valid sections found in page {page.title!r} ({page.url})") - - return sections - - -def most_relevant_links(page, num_links=10, num_summary_words=50, return_scores=False): - """Return the most relevant links to a Wikipedia page based on the intersection over union (IOU) of the link and the page summary.""" - link_scores = {} - summary_words = set(page.summary.split()[:num_summary_words]) - for link in page.links: - link_words = set(link.split()) - iou = len(summary_words.intersection(link_words)) / len( - summary_words.union(link_words) - ) - link_scores[link] = iou / len(link.split()) - - sorted_links = sorted(link_scores.items(), key=lambda x: x[1], reverse=True) - if return_scores: - return sorted_links[:num_links] - - return [link for link, _ in sorted_links[:num_links]] - - -def filter_categories(categories, exclude=None, include=None): - """Filter categories based on a list of categories to exclude and/or include.""" - if exclude: - categories = [ - cat - for cat in categories - if not re.search("|".join(exclude), cat, re.IGNORECASE) - ] - if include: - categories = [ - cat - for cat in categories - if re.search("|".join(include), cat, re.IGNORECASE) - ] - return categories - - class WikiDataset(Dataset): """Wikipedia dataset. 
Uses the wikipedia python api to fetch articles and sections.""" @@ -191,13 +68,13 @@ def get( Dict: _description_ """ - page = _get_page(title=name, **kwargs) + page = wiki_utils._get_page(title=name, **kwargs) if page is None: return None # Only return a sections with a minimum number of words exclude = (exclude or []) + list(self.EXCLUDE_HEADERS) - sections = process_page( + sections = wiki_utils.process_page( page, valid_header=lambda x: x not in exclude and (not include or x in include), valid_content=lambda x: len(x.split()) >= self.min_length_words, @@ -214,8 +91,8 @@ def get( "subtopic": section_title, "content": content, "internal_links": list(filter(lambda x: x not in exclude, page.sections)), - "external_links": most_relevant_links(page, num_links=self.max_links), - "tags": filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES), + "external_links": wiki_utils.most_relevant_links(page, num_links=self.max_links), + "tags": wiki_utils.filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES), "source": "Wikipedia", "extra": { "url": page.url, @@ -225,7 +102,7 @@ def get( } def search(self, name, results=3, selector: Selector = None) -> Dict: - titles = _wiki_search(name, results=results) + titles = wiki_utils._wiki_search(name, results=results) title = selector(titles) return self.get(title, selector=selector) @@ -233,7 +110,7 @@ def random(self, pages=10, seed=None, selector: Selector = None, **kwargs) -> Di titles = ( wiki.random(pages=pages) if seed is None - else _get_random_titles(pages=pages, seed=seed) + else wiki_utils._get_random_titles(pages=pages, seed=seed) ) title = selector(titles) return self.get(title, selector=selector) @@ -296,7 +173,7 @@ def get( ), f"Month should be one of {self.MONTHS}, but got {date[0]!r}" assert date[1].isdigit(), f"Day should be a number, but got {date[1]!r}" - page = _get_page( + page = wiki_utils._get_page( title=name, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect ) if page is None: @@ -304,7 +181,7 @@ def get( # Only return a sections which contain event-like format # e.g. "1999 - Some event happened" - sections = process_page( + sections = wiki_utils.process_page( page, valid_header=lambda x: x in self.INCLUDE_HEADERS, valid_content=lambda x: any( @@ -326,7 +203,7 @@ def get( "content": "-".join(event).strip(". "), "internal_links": list(sections.keys()), "external_links": links, - "tags": filter_categories( + "tags": wiki_utils.filter_categories( page.categories, exclude=WikiDataset.EXCLUDE_CATEGORIES ), "source": "Wikipedia", diff --git a/prompting/utils/__init__.py b/prompting/utils/__init__.py index 359321ed..e3619ecb 100644 --- a/prompting/utils/__init__.py +++ b/prompting/utils/__init__.py @@ -2,3 +2,5 @@ from . import misc from . import uids from . import logging +from . import async_wiki +from . 
import wiki \ No newline at end of file diff --git a/prompting/utils/async_wiki.py b/prompting/utils/async_wiki.py new file mode 100644 index 00000000..ff0c20e6 --- /dev/null +++ b/prompting/utils/async_wiki.py @@ -0,0 +1,220 @@ +import aiohttp +import asyncio +import bittensor as bt +import sys +import random +from datetime import datetime, timedelta +from wikipedia import USER_AGENT, RATE_LIMIT, RATE_LIMIT_MIN_WAIT, API_URL, WikipediaPage, DisambiguationError, PageError +from wikipedia.exceptions import HTTPTimeoutError, WikipediaException, PageError +from functools import lru_cache +from typing import Dict, List + + + +##################### Wraps and overwrite wikipedia features to make it async ##################### + +async def _async_wiki_request(params): + global RATE_LIMIT_LAST_CALL + global USER_AGENT + + params['format'] = 'json' + if 'action' not in params: + params['action'] = 'query' + + headers = {'User-Agent': USER_AGENT} + + if RATE_LIMIT and RATE_LIMIT_LAST_CALL and \ + RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT > datetime.now(): + + wait_time = (RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT) - datetime.now() + await asyncio.sleep(wait_time.total_seconds()) + + async with aiohttp.ClientSession() as session: + async with session.get(API_URL, params=params, headers=headers) as response: + if RATE_LIMIT: + RATE_LIMIT_LAST_CALL = datetime.now() + return await response.json() + + +class AsyncWikipediaPage(WikipediaPage): + @property + async def sections(self): + ''' + Overwrites the `sections` property to be async. + ''' + if not getattr(self, '_sections', False): + query_params = { + 'action': 'parse', + 'prop': 'sections', + } + if not getattr(self, 'title', None) is None: + query_params["page"] = self.title + + request = await _async_wiki_request(query_params) + self._sections = [section['line'] for section in request['parse']['sections']] + + return self._sections + + async def section(self, section_title: str): + ''' + Get the plain text content of a section from `self.sections`. + Returns None if `section_title` isn't found, otherwise returns a whitespace stripped string. + + This is a convenience method that wraps self.content. + + .. warning:: Calling `section` on a section that has subheadings will NOT return + the full text of all of the subsections. It only gets the text between + `section_title` and the next subheading, which is often empty. + ''' + + section = u"== {} ==".format(section_title) + content = await self.content + + try: + index = content.index(section) + len(section) + except ValueError: + return None + + try: + next_index = content.index("==", index) + except ValueError: + next_index = len(content) + + return content[index:next_index].lstrip("=").strip() + + @property + async def content(self): + ''' + Overwrites the `content` property that is called by the `section` property. + This change enables the `content` property to be called independently in async. 
+ ''' + + if not getattr(self, '_content', False): + query_params = { + 'prop': 'extracts|revisions', + 'explaintext': '', + 'rvprop': 'ids' + } + if not getattr(self, 'title', None) is None: + query_params['titles'] = self.title + else: + query_params['pageids'] = self.pageid + request = await _async_wiki_request(query_params) + self._content = request['query']['pages'][self.pageid]['extract'] + self._revision_id = request['query']['pages'][self.pageid]['revisions'][0]['revid'] + self._parent_id = request['query']['pages'][self.pageid]['revisions'][0]['parentid'] + + return self._content + + +##################### Utility functions ##################### + +async def process_page( + page: AsyncWikipediaPage, valid_header: callable = None, valid_content: callable = None +) -> Dict: + """Process a Wikipedia page and return a dictionary of sections with their content. + + Args: + page: wikipedia.WikipediaPage + valid_header: callable to determine if a section header is valid + valid_content: callable to determine if a section content is valid + Returns: + dict: dictionary of sections and their content. Note that keys are tuples (header, section_title) + """ + header = "" + sections = {} + + for section_title in await page.sections: + content = await page.section(section_title=section_title) + if not content: + header = section_title + continue + + # Filter out sections that don't match the headers and/or are not valid + if (valid_header and not valid_header(header)) or ( + valid_content and not valid_content(content) + ): + continue + + key = (header, section_title) + sections[key] = content.splitlines() + + if not sections: + bt.logging.debug(f"No valid sections found in page {page.title!r} ({page.url})") + + return sections + +async def process_pages( + pages: List[AsyncWikipediaPage], valid_header: callable = None, valid_content: callable = None +): + tasks = [process_page(page, valid_header, valid_content) for page in pages] + sections = await asyncio.gather(*tasks) + return sections + + +async def search(query, results=10, suggestion=False): + ''' Overwrites wikipedia base functions to use aiohttp and make it async ''' + search_params = { + 'list': 'search', + 'srprop': '', + 'srlimit': results, + 'limit': results, + 'srsearch': query + } + if suggestion: + search_params['srinfo'] = 'suggestion' + + raw_results = await _async_wiki_request(search_params) + + if 'error' in raw_results: + if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): + raise HTTPTimeoutError(query) + else: + raise WikipediaException(raw_results['error']['info']) + + search_results = (d['title'] for d in raw_results['query']['search']) + + if suggestion: + if raw_results['query'].get('searchinfo'): + return list(search_results), raw_results['query']['searchinfo']['suggestion'] + else: + return list(search_results), None + + return list(search_results) + +@lru_cache(maxsize=1000) +async def get_async_page(title=None, pageid=None, auto_suggest=True, redirect=True, preload=False, seed=None): + try: + if title is not None: + if auto_suggest: + results, suggestion = await search(title, results=1, suggestion=True) + try: + title = suggestion or results[0] + except IndexError: + raise PageError(title) + # Assuming WikipediaPage is a class that needs to be defined or imported + return AsyncWikipediaPage(title, redirect=redirect, preload=preload) + elif pageid is not None: + return AsyncWikipediaPage(pageid=pageid, preload=preload) + else: + raise ValueError("Either a title or a pageid must be 
specified") + except DisambiguationError as e: + bt.logging.debug(f"{e.__class__.__name__} loading page {title!r}: {e}") + # exc info contains a tuple of (requested_title: str, possible_matches: List[str]) + pages = sys.exc_info()[1].args[1] + if not type(pages) == list: + return None + title = random.Random(seed).choice(pages) + return await get_async_page(title, auto_suggest=auto_suggest, redirect=redirect) + + except PageError as e: + bt.logging.warning(f"{e.__class__.__name__} loading page {title!r}: {e}") + if not auto_suggest: + return await get_async_page(title, auto_suggest=True, redirect=redirect) + return None + + +async def fetch_pages(titles, auto_suggest=False, redirect=True, preload=False): + tasks = [get_async_page(title=title, auto_suggest=auto_suggest, redirect=redirect, preload=preload) for title in titles] + pages = await asyncio.gather(*tasks) + return pages diff --git a/prompting/utils/wiki.py b/prompting/utils/wiki.py new file mode 100644 index 00000000..233a2bd7 --- /dev/null +++ b/prompting/utils/wiki.py @@ -0,0 +1,126 @@ +import re +import sys +import random +import bittensor as bt +import wikipedia as wiki +from typing import Dict, List +from functools import lru_cache + +# speed up page loading +@lru_cache(maxsize=1000) +def _get_page( + title, pageid=None, auto_suggest=False, redirect=True, seed=None +) -> wiki.WikipediaPage: + """Cached Wikipedia page loading.""" + try: + page = wiki.page( + title=title, pageid=pageid, auto_suggest=auto_suggest, redirect=redirect + ) + # create sections manually if not found + if not page.sections: + page._sections = [ + line.strip("= ") + for line in page.content.splitlines() + if re.search(r"=+\s+.*\s+=+", line) + ] + return page + + except wiki.DisambiguationError as e: + bt.logging.debug(f"{e.__class__.__name__} loading page {title!r}: {e}") + # exc info contains a tuple of (requested_title: str, possible_matches: List[str]) + pages = sys.exc_info()[1].args[1] + if not type(pages) == list: + return None + title = random.Random(seed).choice(pages) + return _get_page(title, auto_suggest=auto_suggest, redirect=redirect) + + except wiki.PageError as e: + bt.logging.warning(f"{e.__class__.__name__} loading page {title!r}: {e}") + if not auto_suggest: + return _get_page(title, auto_suggest=True, redirect=redirect) + return None + + +@lru_cache(maxsize=1000) +def _get_random_titles(pages=10, seed=42) -> List: + """Cached wikipedia random page. Approximately deterministic random titles. This is useful for testing. + NOTE: the actually cached result will change each session, but the result will be the same within a session. + """ + return wiki.random(pages=pages) + + +@lru_cache(maxsize=1000) +def _wiki_search(name, results) -> List: + """Cached Wikipedia search.""" + return wiki.search(name, results=results) + + +def process_page( + page, valid_header: callable = None, valid_content: callable = None +) -> Dict: + """Process a Wikipedia page and return a dictionary of sections with their content. + + Args: + page: wikipedia.WikipediaPage + valid_header: callable to determine if a section header is valid + valid_content: callable to determine if a section content is valid + Returns: + dict: dictionary of sections and their content. 
Note that keys are tuples (header, section_title) + """ + header = "" + sections = {} + + for section_title in page.sections: + content = page.section(section_title) + if not content: + header = section_title + continue + + # Filter out sections that don't match the headers and/or are not valid + if (valid_header and not valid_header(header)) or ( + valid_content and not valid_content(content) + ): + continue + + key = (header, section_title) + sections[key] = content.splitlines() + + if not sections: + bt.logging.debug(f"No valid sections found in page {page.title!r} ({page.url})") + + return sections + + +def most_relevant_links(page, num_links=10, num_summary_words=50, return_scores=False): + """Return the most relevant links to a Wikipedia page based on the intersection over union (IOU) of the link and the page summary.""" + link_scores = {} + summary_words = set(page.summary.split()[:num_summary_words]) + for link in page.links: + link_words = set(link.split()) + iou = len(summary_words.intersection(link_words)) / len( + summary_words.union(link_words) + ) + link_scores[link] = iou / len(link.split()) + + sorted_links = sorted(link_scores.items(), key=lambda x: x[1], reverse=True) + if return_scores: + return sorted_links[:num_links] + + return [link for link, _ in sorted_links[:num_links]] + + +def filter_categories(categories, exclude=None, include=None): + """Filter categories based on a list of categories to exclude and/or include.""" + if exclude: + categories = [ + cat + for cat in categories + if not re.search("|".join(exclude), cat, re.IGNORECASE) + ] + if include: + categories = [ + cat + for cat in categories + if re.search("|".join(include), cat, re.IGNORECASE) + ] + return categories From 92936c9acb7677f930c825765056526a4be2fec7 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Thu, 18 Apr 2024 21:23:45 +0000 Subject: [PATCH 03/21] adds async wiki wrapper --- prompting/tools/datasets/__init__.py | 3 +- prompting/tools/datasets/base.py | 1 - prompting/tools/datasets/batch_wiki.py | 94 +---------- prompting/utils/__init__.py | 3 +- prompting/utils/async_wiki.py | 220 ------------------------- prompting/utils/async_wiki_utils.py | 145 ++++++++++++++++ prompting/utils/async_wiki_wrapper.py | 206 +++++++++++++++++++++++ 7 files changed, 363 insertions(+), 309 deletions(-) delete mode 100644 prompting/utils/async_wiki.py create mode 100644 prompting/utils/async_wiki_utils.py create mode 100644 prompting/utils/async_wiki_wrapper.py diff --git a/prompting/tools/datasets/__init__.py b/prompting/tools/datasets/__init__.py index 66c9e4de..7fa92044 100644 --- a/prompting/tools/datasets/__init__.py +++ b/prompting/tools/datasets/__init__.py @@ -1,6 +1,7 @@ -from .context import Context +from .context import Context, BatchContext from .base import Dataset from .code import HFCodingDataset, StackOverflowDataset from .math import MathDataset from .mock import MockDataset from .wiki import WikiDataset, WikiDateDataset +from .batch_wiki import BatchWikiDataset diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index 37c2dcfd..f77e9b0d 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -86,7 +86,6 @@ def next( class BatchDataset(ABC): """Base class for datasets.""" - max_tries: int = 10 @abstractmethod diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py index af5a50af..658860c8 100644 --- a/prompting/tools/datasets/batch_wiki.py +++ b/prompting/tools/datasets/batch_wiki.py @@ -1,27 +1,8 
@@ -import re -import sys -import random -import datetime -import bittensor as bt -import wikipedia as wiki -from prompting.utils import async_wiki as async_wiki_utils -from typing import Dict, Union, List, Tuple -from functools import lru_cache -from .base import Dataset, BatchDataset -from ..selector import Selector +from prompting.utils import async_wiki_utils as async_wiki_utils +from typing import List +from .base import BatchDataset from prompting.tools.datasets import Context - - - - -@lru_cache(maxsize=1000) -def _get_random_titles(pages=10, seed=42) -> List: - """Cached wikipedia random page. Approximately deterministic random titles. This is useful for testing. - NOTE: the actually cached result will change each session, but the result will be the same within a session. - """ - return wiki.random(pages=pages) - - +from prompting.utils.custom_async_wiki import get_batch_random_sections class BatchWikiDataset(BatchDataset): @@ -45,66 +26,7 @@ def __init__( self.min_length_words = min_length_words self.max_links = max_links - async def get_multiple_pages( - self, - titles: List[str], - selector: Selector = None, - include: List = None, - exclude: List = None, - **kwargs, - ) -> Dict: - """Get a specified Wikipedia page and extract a section based on the selector. - - Args: - name (_type_): _description_ - pageid (_type_, optional): _description_. Defaults to None. - auto_suggest (bool, optional): _description_. Defaults to True. - redirect (bool, optional): _description_. Defaults to True. - selector (Selector, optional): _description_. Defaults to None. - include (List, optional): _description_. Defaults to None. - exclude (List, optional): _description_. Defaults to None. - - Returns: - Dict: _description_ - """ - pages = await async_wiki_utils.fetch_pages(titles, **kwargs) - - # Only return a sections with a minimum number of words - exclude = (exclude or []) + list(self.EXCLUDE_HEADERS) - - # TODO: FIX THE RETURN FOR PROCESS PAGES TO BE A MANAGABLE TYPE - sections = await async_wiki_utils.process_pages( - pages, - valid_header=lambda x: x not in exclude and (not include or x in include), - valid_content=lambda x: len(x.split()) >= self.min_length_words, - ) - - - key = header, section_title = selector(list(sections.keys())) - content = "\n".join(sections[key]) - section_length = len(content.split()) - return { - "title": name, # title of wiki article - "topic": header or section_title, # title of wiki section - "subtopic": section_title, - "content": content, - "internal_links": list(filter(lambda x: x not in exclude, page.sections)), - "external_links": most_relevant_links(page, num_links=self.max_links), - "tags": filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES), - "source": "Wikipedia", - "extra": { - "url": page.url, - "page_length": len(page.content.split()), - "section_length": section_length, - }, - } - - async def random_batch(self, seed=None, selector: Selector = None, **kwargs) -> List[Context]: - """Get random batch of wikipedia pages.""" - random_titles = ( - wiki.random(pages=self.batch_size) - if seed is None - else _get_random_titles(pages=self.batch_size, seed=seed) - ) - - return await self.get_multiple_pages(random_titles, selector=selector) \ No newline at end of file + + async def random_batch(self) -> List[Context]: + contexts = await get_batch_random_sections() + return contexts diff --git a/prompting/utils/__init__.py b/prompting/utils/__init__.py index e3619ecb..387d0167 100644 --- a/prompting/utils/__init__.py +++ 
b/prompting/utils/__init__.py @@ -2,5 +2,6 @@ from . import misc from . import uids from . import logging -from . import async_wiki +from . import async_wiki_utils +from . import async_wiki_wrapper from . import wiki \ No newline at end of file diff --git a/prompting/utils/async_wiki.py b/prompting/utils/async_wiki.py deleted file mode 100644 index ff0c20e6..00000000 --- a/prompting/utils/async_wiki.py +++ /dev/null @@ -1,220 +0,0 @@ -import aiohttp -import asyncio -import bittensor as bt -import sys -import random -from datetime import datetime, timedelta -from wikipedia import USER_AGENT, RATE_LIMIT, RATE_LIMIT_MIN_WAIT, API_URL, WikipediaPage, DisambiguationError, PageError -from wikipedia.exceptions import HTTPTimeoutError, WikipediaException, PageError -from functools import lru_cache -from typing import Dict, List - - - -##################### Wraps and overwrite wikipedia features to make it async ##################### - -async def _async_wiki_request(params): - global RATE_LIMIT_LAST_CALL - global USER_AGENT - - params['format'] = 'json' - if 'action' not in params: - params['action'] = 'query' - - headers = {'User-Agent': USER_AGENT} - - if RATE_LIMIT and RATE_LIMIT_LAST_CALL and \ - RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT > datetime.now(): - - wait_time = (RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT) - datetime.now() - await asyncio.sleep(wait_time.total_seconds()) - - async with aiohttp.ClientSession() as session: - async with session.get(API_URL, params=params, headers=headers) as response: - if RATE_LIMIT: - RATE_LIMIT_LAST_CALL = datetime.now() - return await response.json() - - -class AsyncWikipediaPage(WikipediaPage): - @property - async def sections(self): - ''' - Overwrites the `sections` property to be async. - ''' - if not getattr(self, '_sections', False): - query_params = { - 'action': 'parse', - 'prop': 'sections', - } - if not getattr(self, 'title', None) is None: - query_params["page"] = self.title - - request = await _async_wiki_request(query_params) - self._sections = [section['line'] for section in request['parse']['sections']] - - return self._sections - - async def section(self, section_title: str): - ''' - Get the plain text content of a section from `self.sections`. - Returns None if `section_title` isn't found, otherwise returns a whitespace stripped string. - - This is a convenience method that wraps self.content. - - .. warning:: Calling `section` on a section that has subheadings will NOT return - the full text of all of the subsections. It only gets the text between - `section_title` and the next subheading, which is often empty. - ''' - - section = u"== {} ==".format(section_title) - content = await self.content - - try: - index = content.index(section) + len(section) - except ValueError: - return None - - try: - next_index = content.index("==", index) - except ValueError: - next_index = len(content) - - return content[index:next_index].lstrip("=").strip() - - @property - async def content(self): - ''' - Overwrites the `content` property that is called by the `section` property. - This change enables the `content` property to be called independently in async. 
- ''' - - if not getattr(self, '_content', False): - query_params = { - 'prop': 'extracts|revisions', - 'explaintext': '', - 'rvprop': 'ids' - } - if not getattr(self, 'title', None) is None: - query_params['titles'] = self.title - else: - query_params['pageids'] = self.pageid - request = await _async_wiki_request(query_params) - self._content = request['query']['pages'][self.pageid]['extract'] - self._revision_id = request['query']['pages'][self.pageid]['revisions'][0]['revid'] - self._parent_id = request['query']['pages'][self.pageid]['revisions'][0]['parentid'] - - return self._content - - -##################### Utility functions ##################### - -async def process_page( - page: AsyncWikipediaPage, valid_header: callable = None, valid_content: callable = None -) -> Dict: - """Process a Wikipedia page and return a dictionary of sections with their content. - - Args: - page: wikipedia.WikipediaPage - valid_header: callable to determine if a section header is valid - valid_content: callable to determine if a section content is valid - Returns: - dict: dictionary of sections and their content. Note that keys are tuples (header, section_title) - """ - header = "" - sections = {} - - for section_title in await page.sections: - content = await page.section(section_title=section_title) - if not content: - header = section_title - continue - - # Filter out sections that don't match the headers and/or are not valid - if (valid_header and not valid_header(header)) or ( - valid_content and not valid_content(content) - ): - continue - - key = (header, section_title) - sections[key] = content.splitlines() - - if not sections: - bt.logging.debug(f"No valid sections found in page {page.title!r} ({page.url})") - - return sections - -async def process_pages( - pages: List[AsyncWikipediaPage], valid_header: callable = None, valid_content: callable = None -): - tasks = [process_page(page, valid_header, valid_content) for page in pages] - sections = await asyncio.gather(*tasks) - return sections - - -async def search(query, results=10, suggestion=False): - ''' Overwrites wikipedia base functions to use aiohttp and make it async ''' - search_params = { - 'list': 'search', - 'srprop': '', - 'srlimit': results, - 'limit': results, - 'srsearch': query - } - if suggestion: - search_params['srinfo'] = 'suggestion' - - raw_results = await _async_wiki_request(search_params) - - if 'error' in raw_results: - if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): - raise HTTPTimeoutError(query) - else: - raise WikipediaException(raw_results['error']['info']) - - search_results = (d['title'] for d in raw_results['query']['search']) - - if suggestion: - if raw_results['query'].get('searchinfo'): - return list(search_results), raw_results['query']['searchinfo']['suggestion'] - else: - return list(search_results), None - - return list(search_results) - -@lru_cache(maxsize=1000) -async def get_async_page(title=None, pageid=None, auto_suggest=True, redirect=True, preload=False, seed=None): - try: - if title is not None: - if auto_suggest: - results, suggestion = await search(title, results=1, suggestion=True) - try: - title = suggestion or results[0] - except IndexError: - raise PageError(title) - # Assuming WikipediaPage is a class that needs to be defined or imported - return AsyncWikipediaPage(title, redirect=redirect, preload=preload) - elif pageid is not None: - return AsyncWikipediaPage(pageid=pageid, preload=preload) - else: - raise ValueError("Either a title or a pageid must be 
specified") - except DisambiguationError as e: - bt.logging.debug(f"{e.__class__.__name__} loading page {title!r}: {e}") - # exc info contains a tuple of (requested_title: str, possible_matches: List[str]) - pages = sys.exc_info()[1].args[1] - if not type(pages) == list: - return None - title = random.Random(seed).choice(pages) - return await get_async_page(title, auto_suggest=auto_suggest, redirect=redirect) - - except PageError as e: - bt.logging.warning(f"{e.__class__.__name__} loading page {title!r}: {e}") - if not auto_suggest: - return await get_async_page(title, auto_suggest=True, redirect=redirect) - return None - - -async def fetch_pages(titles, auto_suggest=False, redirect=True, preload=False): - tasks = [get_async_page(title=title, auto_suggest=auto_suggest, redirect=redirect, preload=preload) for title in titles] - pages = await asyncio.gather(*tasks) - return pages diff --git a/prompting/utils/async_wiki_utils.py b/prompting/utils/async_wiki_utils.py new file mode 100644 index 00000000..cd59e5a7 --- /dev/null +++ b/prompting/utils/async_wiki_utils.py @@ -0,0 +1,145 @@ +import asyncio +import bittensor as bt +import sys +import random +from wikipedia import DisambiguationError, PageError +from wikipedia.exceptions import HTTPTimeoutError, WikipediaException, PageError +from functools import lru_cache +from typing import List +from prompting.utils.async_wiki_wrapper import AsyncWikipediaPage, _async_wiki_request +from dataclasses import dataclass + +@dataclass +class ProcessedSection: + header: str + section_title: str + content: List[str] + + def get_str_content(self): + return "\n".join(self.content) + +@dataclass +class ProcessedPage: + title: str + url: str + page: AsyncWikipediaPage + sections: List[ProcessedSection] + + +async def process_page( + page: AsyncWikipediaPage, valid_header: callable = None, valid_content: callable = None +) -> ProcessedPage: + """Process a Wikipedia page and return a dictionary of sections with their content. + + Args: + page: wikipedia.WikipediaPage + valid_header: callable to determine if a section header is valid + valid_content: callable to determine if a section content is valid + Returns: + dict: dictionary of sections and their content. 
Note that keys are tuples (header, section_title) + """ + header = "" + page_sections = [] + + # Get all section titles first + section_titles = await page.sections + + # Concurrently get the content of all sections + contents = await asyncio.gather( + *(page.section(section_title=title) for title in section_titles) + ) + + for section_title, content in zip(section_titles, contents): + if not content: + header = section_title + continue + + # Filter out sections that don't match the headers and/or are not valid + if (valid_header and not valid_header(header)) or ( + valid_content and not valid_content(content) + ): + continue + + section = ProcessedSection(header=header, section_title=section_title, content=content.splitlines()) + page_sections.append(section) + + if not page_sections: + bt.logging.debug(f"No valid sections found in page {page.title!r} ({page.url})") + + return ProcessedPage(title=page.title, sections=page_sections, url=page.url, page=page) + +async def process_pages( + pages: List[AsyncWikipediaPage], valid_header: callable = None, valid_content: callable = None +): + tasks = [process_page(page, valid_header, valid_content) for page in pages] + sections = await asyncio.gather(*tasks) + return sections + + +async def search(query, results=10, suggestion=False): + ''' Overwrites wikipedia base functions to use aiohttp and make it async ''' + search_params = { + 'list': 'search', + 'srprop': '', + 'srlimit': results, + 'limit': results, + 'srsearch': query + } + if suggestion: + search_params['srinfo'] = 'suggestion' + + raw_results = await _async_wiki_request(search_params) + + if 'error' in raw_results: + if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): + raise HTTPTimeoutError(query) + else: + raise WikipediaException(raw_results['error']['info']) + + search_results = (d['title'] for d in raw_results['query']['search']) + + if suggestion: + if raw_results['query'].get('searchinfo'): + return list(search_results), raw_results['query']['searchinfo']['suggestion'] + else: + return list(search_results), None + + return list(search_results) + +@lru_cache(maxsize=1000) +async def get_async_page(title=None, pageid=None, auto_suggest=True, redirect=True, preload=False, seed=None): + try: + if title is not None: + if auto_suggest: + results, suggestion = await search(title, results=1, suggestion=True) + try: + title = suggestion or results[0] + except IndexError: + raise PageError(title) + # Assuming WikipediaPage is a class that needs to be defined or imported + wiki_page = AsyncWikipediaPage(title=title, redirect=redirect, preload=preload) + return wiki_page + elif pageid is not None: + wiki_page = AsyncWikipediaPage(pageid=pageid, preload=preload) + else: + raise ValueError("Either a title or a pageid must be specified") + except DisambiguationError as e: + bt.logging.debug(f"{e.__class__.__name__} loading page {title!r}: {e}") + # exc info contains a tuple of (requested_title: str, possible_matches: List[str]) + pages = sys.exc_info()[1].args[1] + if not type(pages) == list: + return None + title = random.Random(seed).choice(pages) + return await get_async_page(title, auto_suggest=auto_suggest, redirect=redirect) + + except PageError as e: + bt.logging.warning(f"{e.__class__.__name__} loading page {title!r}: {e}") + if not auto_suggest: + return await get_async_page(title, auto_suggest=True, redirect=redirect) + return None + + +async def fetch_pages(titles, auto_suggest=False, redirect=True, preload=False): + tasks = 
[get_async_page(title=title, auto_suggest=auto_suggest, redirect=redirect, preload=preload) for title in titles] + pages = await asyncio.gather(*tasks) + return pages diff --git a/prompting/utils/async_wiki_wrapper.py b/prompting/utils/async_wiki_wrapper.py new file mode 100644 index 00000000..6b6648cd --- /dev/null +++ b/prompting/utils/async_wiki_wrapper.py @@ -0,0 +1,206 @@ +import bittensor as bt +import aiohttp +import asyncio +from datetime import datetime +from wikipedia import USER_AGENT, RATE_LIMIT, RATE_LIMIT_MIN_WAIT, API_URL, ODD_ERROR_MESSAGE, WikipediaPage, DisambiguationError, PageError +from wikipedia.exceptions import PageError, RedirectError +from bs4 import BeautifulSoup + +async def _async_wiki_request(params): + global RATE_LIMIT_LAST_CALL + global USER_AGENT + + params['format'] = 'json' + if 'action' not in params: + params['action'] = 'query' + + headers = {'User-Agent': USER_AGENT} + + if RATE_LIMIT and RATE_LIMIT_LAST_CALL and \ + RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT > datetime.now(): + + wait_time = (RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT) - datetime.now() + await asyncio.sleep(wait_time.total_seconds()) + + # bt.logging.info('*' * 15) + # bt.logging.info('Querying Wikipedia API with params: {}'.format(params)) + async with aiohttp.ClientSession() as session: + async with session.get(API_URL, params=params, headers=headers) as response: + if RATE_LIMIT: + RATE_LIMIT_LAST_CALL = datetime.now() + return await response.json() + + +class AsyncWikipediaPage(WikipediaPage): + def __init__(self, title=None, pageid=None, redirect=True, preload=False, original_title=''): + if title is not None: + self.title = title + self.original_title = original_title or title + elif pageid is not None: + self.pageid = pageid + else: + raise ValueError("Either a title or a pageid must be specified") + + self.redirect = redirect + self.preload = preload + + loop = asyncio.get_event_loop() + loop.run_until_complete(self.load()) + + if preload: + for prop in ('content', 'summary', 'images', 'references', 'links', 'sections'): + getattr(self, prop) + + async def load(self): + ''' + Load basic information from Wikipedia. + Confirm that page exists and is not a disambiguation/redirect. + + Does not need to be called manually, should be called automatically during __init__. 
+ ''' + query_params = { + 'prop': 'info|pageprops', + 'inprop': 'url', + 'ppprop': 'disambiguation', + 'redirects': '', + } + if not getattr(self, 'pageid', None): + query_params['titles'] = self.title + else: + query_params['pageids'] = self.pageid + + request = await _async_wiki_request(query_params) + + query = request['query'] + pageid = list(query['pages'].keys())[0] + page = query['pages'][pageid] + + # missing is present if the page is missing + if 'missing' in page: + if hasattr(self, 'title'): + raise PageError(self.title) + else: + raise PageError(pageid=self.pageid) + + # same thing for redirect, except it shows up in query instead of page for + # whatever silly reason + elif 'redirects' in query: + if self.redirect: + redirects = query['redirects'][0] + + if 'normalized' in query: + normalized = query['normalized'][0] + assert normalized['from'] == self.title, ODD_ERROR_MESSAGE + + from_title = normalized['to'] + + else: + from_title = self.title + + assert redirects['from'] == from_title, ODD_ERROR_MESSAGE + + # change the title and reload the whole object + self.title = redirects['to'] + await self.load(redirects['to'], redirect=self.redirect, preload=self.preload) + + else: + raise RedirectError(getattr(self, 'title', page['title'])) + + # since we only asked for disambiguation in ppprop, + # if a pageprop is returned, + # then the page must be a disambiguation page + elif 'pageprops' in page: + query_params = { + 'prop': 'revisions', + 'rvprop': 'content', + 'rvparse': '', + 'rvlimit': 1 + } + if hasattr(self, 'pageid'): + query_params['pageids'] = self.pageid + else: + query_params['titles'] = self.title + request = await _async_wiki_request(query_params) + html = request['query']['pages'][pageid]['revisions'][0]['*'] + + lis = BeautifulSoup(html, 'html.parser').find_all('li') + filtered_lis = [li for li in lis if not 'tocsection' in ''.join(li.get('class', []))] + may_refer_to = [li.a.get_text() for li in filtered_lis if li.a] + + raise DisambiguationError(getattr(self, 'title', page['title']), may_refer_to) + + else: + self.pageid = pageid + self.title = page['title'] + self.url = page['fullurl'] + + + @property + async def sections(self): + ''' + Overwrites the `sections` property to be async. + ''' + if not getattr(self, '_sections', False): + query_params = { + 'action': 'parse', + 'prop': 'sections', + } + if not getattr(self, 'title', None) is None: + query_params["page"] = self.title + + request = await _async_wiki_request(query_params) + self._sections = [section['line'] for section in request['parse']['sections']] + + return self._sections + + async def section(self, section_title: str): + ''' + Get the plain text content of a section from `self.sections`. + Returns None if `section_title` isn't found, otherwise returns a whitespace stripped string. + + This is a convenience method that wraps self.content. + + .. warning:: Calling `section` on a section that has subheadings will NOT return + the full text of all of the subsections. It only gets the text between + `section_title` and the next subheading, which is often empty. 
+ ''' + + section = u"== {} ==".format(section_title) + content = await self.content + + try: + index = content.index(section) + len(section) + except ValueError: + return None + + try: + next_index = content.index("==", index) + except ValueError: + next_index = len(content) + + return content[index:next_index].lstrip("=").strip() + + @property + async def content(self): + ''' + Overwrites the `content` property that is called by the `section` property. + This change enables the `content` property to be called independently in async. + ''' + + if not getattr(self, '_content', False): + query_params = { + 'prop': 'extracts|revisions', + 'explaintext': '', + 'rvprop': 'ids' + } + if not getattr(self, 'title', None) is None: + query_params['titles'] = self.title + else: + query_params['pageids'] = self.pageid + request = await _async_wiki_request(query_params) + self._content = request['query']['pages'][self.pageid]['extract'] + self._revision_id = request['query']['pages'][self.pageid]['revisions'][0]['revid'] + self._parent_id = request['query']['pages'][self.pageid]['revisions'][0]['parentid'] + + return self._content + \ No newline at end of file From c7e0d841ddab5c50aeb474bf7cd3b253703cae09 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Thu, 18 Apr 2024 21:25:32 +0000 Subject: [PATCH 04/21] drops async wrapper by a custom implementation --- prompting/utils/async_wiki_utils.py | 261 ++++++++++++-------------- prompting/utils/async_wiki_wrapper.py | 206 -------------------- 2 files changed, 124 insertions(+), 343 deletions(-) delete mode 100644 prompting/utils/async_wiki_wrapper.py diff --git a/prompting/utils/async_wiki_utils.py b/prompting/utils/async_wiki_utils.py index cd59e5a7..d508ce36 100644 --- a/prompting/utils/async_wiki_utils.py +++ b/prompting/utils/async_wiki_utils.py @@ -1,145 +1,132 @@ +import aiohttp import asyncio -import bittensor as bt -import sys import random -from wikipedia import DisambiguationError, PageError -from wikipedia.exceptions import HTTPTimeoutError, WikipediaException, PageError -from functools import lru_cache -from typing import List -from prompting.utils.async_wiki_wrapper import AsyncWikipediaPage, _async_wiki_request -from dataclasses import dataclass - -@dataclass -class ProcessedSection: - header: str - section_title: str - content: List[str] - - def get_str_content(self): - return "\n".join(self.content) - -@dataclass -class ProcessedPage: - title: str - url: str - page: AsyncWikipediaPage - sections: List[ProcessedSection] - - -async def process_page( - page: AsyncWikipediaPage, valid_header: callable = None, valid_content: callable = None -) -> ProcessedPage: - """Process a Wikipedia page and return a dictionary of sections with their content. - - Args: - page: wikipedia.WikipediaPage - valid_header: callable to determine if a section header is valid - valid_content: callable to determine if a section content is valid - Returns: - dict: dictionary of sections and their content. 
Note that keys are tuples (header, section_title) - """ - header = "" - page_sections = [] - - # Get all section titles first - section_titles = await page.sections - - # Concurrently get the content of all sections - contents = await asyncio.gather( - *(page.section(section_title=title) for title in section_titles) - ) - - for section_title, content in zip(section_titles, contents): - if not content: - header = section_title - continue - - # Filter out sections that don't match the headers and/or are not valid - if (valid_header and not valid_header(header)) or ( - valid_content and not valid_content(content) - ): - continue +import bittensor as bt +from dataclasses import dataclass, field +from typing import List, Dict +from tqdm.asyncio import tqdm - section = ProcessedSection(header=header, section_title=section_title, content=content.splitlines()) - page_sections.append(section) +EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links") +EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1") - if not page_sections: - bt.logging.debug(f"No valid sections found in page {page.title!r} ({page.url})") - - return ProcessedPage(title=page.title, sections=page_sections, url=page.url, page=page) +class SectionNotFoundException(Exception): + """Exception raised when no valid section is found.""" + pass -async def process_pages( - pages: List[AsyncWikipediaPage], valid_header: callable = None, valid_content: callable = None -): - tasks = [process_page(page, valid_header, valid_content) for page in pages] - sections = await asyncio.gather(*tasks) - return sections +class MaxRetriesReachedException(Exception): + """Exception raised when maximum retry attempts are reached.""" + pass -async def search(query, results=10, suggestion=False): - ''' Overwrites wikipedia base functions to use aiohttp and make it async ''' - search_params = { - 'list': 'search', - 'srprop': '', - 'srlimit': results, - 'limit': results, - 'srsearch': query +@dataclass +class Context: + title: str + topic: str + subtopic: str + content: str + internal_links: List[str] + external_links: List[str] + source: str + tags: List[str] = field(default_factory=list) + extra: Dict[str, any] = field(default_factory=dict) + stats: Dict[str, any] = field(default_factory=dict) + +async def fetch_content(session: aiohttp.ClientSession, pageid: str) -> str: + url = "https://en.wikipedia.org/w/api.php" + params = { + "action": "query", + "format": "json", + "prop": "extracts", + "explaintext": "", + "pageids": pageid } - if suggestion: - search_params['srinfo'] = 'suggestion' - - raw_results = await _async_wiki_request(search_params) - - if 'error' in raw_results: - if raw_results['error']['info'] in ('HTTP request timed out.', 'Pool queue is full'): - raise HTTPTimeoutError(query) - else: - raise WikipediaException(raw_results['error']['info']) - - search_results = (d['title'] for d in raw_results['query']['search']) - - if suggestion: - if raw_results['query'].get('searchinfo'): - return list(search_results), raw_results['query']['searchinfo']['suggestion'] - else: - return list(search_results), None - - return list(search_results) - -@lru_cache(maxsize=1000) -async def get_async_page(title=None, pageid=None, auto_suggest=True, redirect=True, preload=False, seed=None): - try: - if title is not None: - if auto_suggest: - results, suggestion = await search(title, results=1, suggestion=True) - try: - title = suggestion or results[0] - except IndexError: - raise PageError(title) - # Assuming WikipediaPage is a class 
that needs to be defined or imported - wiki_page = AsyncWikipediaPage(title=title, redirect=redirect, preload=preload) - return wiki_page - elif pageid is not None: - wiki_page = AsyncWikipediaPage(pageid=pageid, preload=preload) - else: - raise ValueError("Either a title or a pageid must be specified") - except DisambiguationError as e: - bt.logging.debug(f"{e.__class__.__name__} loading page {title!r}: {e}") - # exc info contains a tuple of (requested_title: str, possible_matches: List[str]) - pages = sys.exc_info()[1].args[1] - if not type(pages) == list: - return None - title = random.Random(seed).choice(pages) - return await get_async_page(title, auto_suggest=auto_suggest, redirect=redirect) - - except PageError as e: - bt.logging.warning(f"{e.__class__.__name__} loading page {title!r}: {e}") - if not auto_suggest: - return await get_async_page(title, auto_suggest=True, redirect=redirect) - return None - - -async def fetch_pages(titles, auto_suggest=False, redirect=True, preload=False): - tasks = [get_async_page(title=title, auto_suggest=auto_suggest, redirect=redirect, preload=preload) for title in titles] - pages = await asyncio.gather(*tasks) - return pages + async with session.get(url, params=params) as response: + data = await response.json() + content = data['query']['pages'][str(pageid)]['extract'] + return content + +async def fetch_random_page(session: aiohttp.ClientSession) -> str: + url = "https://en.wikipedia.org/w/api.php" + params = { + "action": "query", + "format": "json", + "list": "random", + "rnnamespace": "0", + "rnlimit": "1" + } + async with session.get(url, params=params) as response: + data = await response.json() + return data['query']['random'][0]['id'] + +async def fetch_page_details(session: aiohttp.ClientSession, pageid: str) -> Dict[str, any]: + url = "https://en.wikipedia.org/w/api.php" + params = { + "action": "parse", + "format": "json", + "pageid": pageid, + "prop": "sections|links|categories|externallinks", + "disabletoc": "1", + "disableeditsection": "1" + } + async with session.get(url, params=params) as response: + data = await response.json() + return data['parse'] + + +async def fetch_random_section_context(session: aiohttp.ClientSession, progress: tqdm) -> Context: + max_attempts = 10 + for attempt in range(1, max_attempts + 1): + try: + bt.logging.info("Fetching random section context...") + pageid = await fetch_random_page(session) + page_details = await fetch_page_details(session, pageid) + content = await fetch_content(session, pageid) + + # Filter sections here + filtered_sections = [section for section in page_details['sections'] if section['line'] not in EXCLUDE_HEADERS] + + if not filtered_sections: + bt.logging.error("No valid sections found.") + raise SectionNotFoundException("No valid sections found.") + + selected_section = random.choice(filtered_sections) + + internal_links = [link['*'] for link in page_details['links'] if link['ns'] == 0] + external_links = page_details.get('externallinks', []) + tags = [category['*'] for category in page_details['categories'] if not any(excl in category['*'].lower() for excl in EXCLUDE_CATEGORIES)] + + context = Context( + title=page_details['title'], + topic=selected_section.get('line', 'No Topic'), + subtopic=selected_section['line'], + content=content, + internal_links=internal_links, + external_links=external_links, + tags=tags, + source="Wikipedia", + extra={} + ) + progress.update(1) + return context + + except SectionNotFoundException as e: + bt.logging.warning(f"Attempt {attempt} failed: 
{e}") + if attempt == max_attempts: + bt.logging.error("Maximum retry attempts reached, failing...") + raise MaxRetriesReachedException(f"Maximum retry attempts reached: {max_attempts}") + + +async def get_batch_random_sections(batch_size: int = 16) -> List[Context]: + async with aiohttp.ClientSession() as session: + tasks: List[asyncio.Task] = [] + progress = tqdm(total=batch_size, desc=f"Fetching {batch_size} random wikipedia sections", unit="section") # Total is the number of tasks + + # Creates a list of tasks to be executed concurrently + for _ in range(batch_size): + task = asyncio.create_task(fetch_random_section_context(session, progress)) + tasks.append(task) + + results = await asyncio.gather(*tasks) + progress.close() # Ensure the progress bar closes after all tasks complete + + return results \ No newline at end of file diff --git a/prompting/utils/async_wiki_wrapper.py b/prompting/utils/async_wiki_wrapper.py deleted file mode 100644 index 6b6648cd..00000000 --- a/prompting/utils/async_wiki_wrapper.py +++ /dev/null @@ -1,206 +0,0 @@ -import bittensor as bt -import aiohttp -import asyncio -from datetime import datetime -from wikipedia import USER_AGENT, RATE_LIMIT, RATE_LIMIT_MIN_WAIT, API_URL, ODD_ERROR_MESSAGE, WikipediaPage, DisambiguationError, PageError -from wikipedia.exceptions import PageError, RedirectError -from bs4 import BeautifulSoup - -async def _async_wiki_request(params): - global RATE_LIMIT_LAST_CALL - global USER_AGENT - - params['format'] = 'json' - if 'action' not in params: - params['action'] = 'query' - - headers = {'User-Agent': USER_AGENT} - - if RATE_LIMIT and RATE_LIMIT_LAST_CALL and \ - RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT > datetime.now(): - - wait_time = (RATE_LIMIT_LAST_CALL + RATE_LIMIT_MIN_WAIT) - datetime.now() - await asyncio.sleep(wait_time.total_seconds()) - - # bt.logging.info('*' * 15) - # bt.logging.info('Querying Wikipedia API with params: {}'.format(params)) - async with aiohttp.ClientSession() as session: - async with session.get(API_URL, params=params, headers=headers) as response: - if RATE_LIMIT: - RATE_LIMIT_LAST_CALL = datetime.now() - return await response.json() - - -class AsyncWikipediaPage(WikipediaPage): - def __init__(self, title=None, pageid=None, redirect=True, preload=False, original_title=''): - if title is not None: - self.title = title - self.original_title = original_title or title - elif pageid is not None: - self.pageid = pageid - else: - raise ValueError("Either a title or a pageid must be specified") - - self.redirect = redirect - self.preload = preload - - loop = asyncio.get_event_loop() - loop.run_until_complete(self.load()) - - if preload: - for prop in ('content', 'summary', 'images', 'references', 'links', 'sections'): - getattr(self, prop) - - async def load(self): - ''' - Load basic information from Wikipedia. - Confirm that page exists and is not a disambiguation/redirect. - - Does not need to be called manually, should be called automatically during __init__. 
- ''' - query_params = { - 'prop': 'info|pageprops', - 'inprop': 'url', - 'ppprop': 'disambiguation', - 'redirects': '', - } - if not getattr(self, 'pageid', None): - query_params['titles'] = self.title - else: - query_params['pageids'] = self.pageid - - request = await _async_wiki_request(query_params) - - query = request['query'] - pageid = list(query['pages'].keys())[0] - page = query['pages'][pageid] - - # missing is present if the page is missing - if 'missing' in page: - if hasattr(self, 'title'): - raise PageError(self.title) - else: - raise PageError(pageid=self.pageid) - - # same thing for redirect, except it shows up in query instead of page for - # whatever silly reason - elif 'redirects' in query: - if self.redirect: - redirects = query['redirects'][0] - - if 'normalized' in query: - normalized = query['normalized'][0] - assert normalized['from'] == self.title, ODD_ERROR_MESSAGE - - from_title = normalized['to'] - - else: - from_title = self.title - - assert redirects['from'] == from_title, ODD_ERROR_MESSAGE - - # change the title and reload the whole object - self.title = redirects['to'] - await self.load(redirects['to'], redirect=self.redirect, preload=self.preload) - - else: - raise RedirectError(getattr(self, 'title', page['title'])) - - # since we only asked for disambiguation in ppprop, - # if a pageprop is returned, - # then the page must be a disambiguation page - elif 'pageprops' in page: - query_params = { - 'prop': 'revisions', - 'rvprop': 'content', - 'rvparse': '', - 'rvlimit': 1 - } - if hasattr(self, 'pageid'): - query_params['pageids'] = self.pageid - else: - query_params['titles'] = self.title - request = await _async_wiki_request(query_params) - html = request['query']['pages'][pageid]['revisions'][0]['*'] - - lis = BeautifulSoup(html, 'html.parser').find_all('li') - filtered_lis = [li for li in lis if not 'tocsection' in ''.join(li.get('class', []))] - may_refer_to = [li.a.get_text() for li in filtered_lis if li.a] - - raise DisambiguationError(getattr(self, 'title', page['title']), may_refer_to) - - else: - self.pageid = pageid - self.title = page['title'] - self.url = page['fullurl'] - - - @property - async def sections(self): - ''' - Overwrites the `sections` property to be async. - ''' - if not getattr(self, '_sections', False): - query_params = { - 'action': 'parse', - 'prop': 'sections', - } - if not getattr(self, 'title', None) is None: - query_params["page"] = self.title - - request = await _async_wiki_request(query_params) - self._sections = [section['line'] for section in request['parse']['sections']] - - return self._sections - - async def section(self, section_title: str): - ''' - Get the plain text content of a section from `self.sections`. - Returns None if `section_title` isn't found, otherwise returns a whitespace stripped string. - - This is a convenience method that wraps self.content. - - .. warning:: Calling `section` on a section that has subheadings will NOT return - the full text of all of the subsections. It only gets the text between - `section_title` and the next subheading, which is often empty. 
- ''' - - section = u"== {} ==".format(section_title) - content = await self.content - - try: - index = content.index(section) + len(section) - except ValueError: - return None - - try: - next_index = content.index("==", index) - except ValueError: - next_index = len(content) - - return content[index:next_index].lstrip("=").strip() - - @property - async def content(self): - ''' - Overwrites the `content` property that is called by the `section` property. - This change enables the `content` property to be called independently in async. - ''' - - if not getattr(self, '_content', False): - query_params = { - 'prop': 'extracts|revisions', - 'explaintext': '', - 'rvprop': 'ids' - } - if not getattr(self, 'title', None) is None: - query_params['titles'] = self.title - else: - query_params['pageids'] = self.pageid - request = await _async_wiki_request(query_params) - self._content = request['query']['pages'][self.pageid]['extract'] - self._revision_id = request['query']['pages'][self.pageid]['revisions'][0]['revid'] - self._parent_id = request['query']['pages'][self.pageid]['revisions'][0]['parentid'] - - return self._content - \ No newline at end of file From 8df86ee1b4a097733c770fd0ef7ef99dac15b066 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Thu, 18 Apr 2024 21:29:18 +0000 Subject: [PATCH 05/21] fix imports --- prompting/tools/datasets/base.py | 2 +- prompting/tools/datasets/batch_wiki.py | 2 +- prompting/utils/__init__.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index f77e9b0d..755c720d 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -99,7 +99,7 @@ async def next( for tries in range(1, self.max_tries + 1): if method == "random": - results = await self.random_batch(selector=selector, **kwargs) + results = await self.random_batch() stats = { "creator": self.__class__.__name__, "fetch_time": time.time() - t0, diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py index 658860c8..e2e5536e 100644 --- a/prompting/tools/datasets/batch_wiki.py +++ b/prompting/tools/datasets/batch_wiki.py @@ -2,7 +2,7 @@ from typing import List from .base import BatchDataset from prompting.tools.datasets import Context -from prompting.utils.custom_async_wiki import get_batch_random_sections +from prompting.utils.async_wiki_utils import get_batch_random_sections class BatchWikiDataset(BatchDataset): diff --git a/prompting/utils/__init__.py b/prompting/utils/__init__.py index 387d0167..39a17e70 100644 --- a/prompting/utils/__init__.py +++ b/prompting/utils/__init__.py @@ -3,5 +3,4 @@ from . import uids from . import logging from . import async_wiki_utils -from . import async_wiki_wrapper from . 
import wiki \ No newline at end of file From aca5965537d9989cb8d344f7b165f3d8c4308637 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Fri, 19 Apr 2024 19:43:28 +0000 Subject: [PATCH 06/21] adds unit tests --- prompting/tools/datasets/__init__.py | 2 +- prompting/tools/datasets/batch_wiki.py | 13 ++----------- prompting/utils/config.py | 7 +++++++ tests/fixtures/dataset.py | 4 ++++ tests/test_dataset.py | 27 +++++++++++++++++++++++--- 5 files changed, 38 insertions(+), 15 deletions(-) diff --git a/prompting/tools/datasets/__init__.py b/prompting/tools/datasets/__init__.py index 7fa92044..a07ed81e 100644 --- a/prompting/tools/datasets/__init__.py +++ b/prompting/tools/datasets/__init__.py @@ -1,5 +1,5 @@ from .context import Context, BatchContext -from .base import Dataset +from .base import Dataset, BatchDataset from .code import HFCodingDataset, StackOverflowDataset from .math import MathDataset from .mock import MockDataset diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py index e2e5536e..ba228212 100644 --- a/prompting/tools/datasets/batch_wiki.py +++ b/prompting/tools/datasets/batch_wiki.py @@ -7,15 +7,9 @@ class BatchWikiDataset(BatchDataset): """Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections.""" - - EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links") - EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1") - def __init__( self, batch_size: int = 16, - min_length_words: int = 50, - max_links: int = 10, ): """ Args: @@ -23,10 +17,7 @@ def __init__( max_links (int, optional): _description_. Defaults to 10. """ self.batch_size = batch_size - self.min_length_words = min_length_words - self.max_links = max_links - - + async def random_batch(self) -> List[Context]: - contexts = await get_batch_random_sections() + contexts = await get_batch_random_sections(self.batch_size) return contexts diff --git a/prompting/utils/config.py b/prompting/utils/config.py index ae27921d..daabbb11 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -388,6 +388,13 @@ def add_validator_args(cls, parser): help="Max time to wait for a forward call to complete in seconds.", default=120, ) + + parser.add_argument( + "--neuron.batch_size", + type=int, + help="Max time to wait for a forward call to complete in seconds.", + default=16, + ) def config(cls): diff --git a/tests/fixtures/dataset.py b/tests/fixtures/dataset.py index 7d57f814..253b5496 100644 --- a/tests/fixtures/dataset.py +++ b/tests/fixtures/dataset.py @@ -4,6 +4,7 @@ WikiDataset, WikiDateDataset, MathDataset, + BatchWikiDataset ) DATASETS = [ @@ -14,6 +15,9 @@ MathDataset, ] +BATCH_DATASETS = [ + BatchWikiDataset, +] MOCK_CONTEXT = MockDataset().next() WIKI_CONTEXT = WikiDataset().next(name="Emilio Alvarez (bishop)", method="get") diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 907607f3..e2f97e8d 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,10 +1,8 @@ import pytest - -from .fixtures.dataset import DATASETS, CONTEXTS, CONTEXT_FIELDS +from .fixtures.dataset import DATASETS, CONTEXTS, CONTEXT_FIELDS, BATCH_DATASETS from prompting.tools.datasets import Dataset from prompting.tools import Context - @pytest.mark.parametrize("dataset", DATASETS) def test_create_dataset(dataset: Dataset): ds = dataset() @@ -65,3 +63,26 @@ def test_context_field_is_not_null(dataset: Dataset, field: str): ) def test_context_stats_field_contains_expected_keys(dataset: Dataset, field: str): assert field in 
CONTEXTS[dataset].stats + + +## Batch dataset tests +@pytest.mark.asyncio +@pytest.mark.parametrize("dataset", BATCH_DATASETS) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +async def test_batch_size_parameter(dataset, batch_size): + batch_context = await dataset(batch_size=batch_size).next() + results = batch_context.results + # Check if results match expected batch size + assert len(results) == batch_size + + +@pytest.mark.asyncio +@pytest.mark.parametrize("dataset", BATCH_DATASETS) +async def test_random_batch_retrieval(dataset): + # Fetch batches + batch1_results = (await dataset(batch_size=2).next()).results + batch2_results = (await dataset(batch_size=2).next()).results + + # Check that batches have different elements + assert batch1_results != batch2_results + From 58ed3b210bb2ddd467a488fb30e977e767f5e1ff Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Fri, 19 Apr 2024 19:44:52 +0000 Subject: [PATCH 07/21] runs black --- prompting/base/neuron.py | 4 +- prompting/forward.py | 15 +++-- prompting/tools/datasets/base.py | 12 ++-- prompting/tools/datasets/batch_wiki.py | 3 +- prompting/tools/datasets/context.py | 2 +- prompting/tools/datasets/math.py | 2 +- prompting/tools/datasets/wiki.py | 8 ++- prompting/utils/__init__.py | 2 +- prompting/utils/async_wiki_utils.py | 85 +++++++++++++++++--------- prompting/utils/config.py | 6 +- prompting/utils/misc.py | 6 +- prompting/utils/wiki.py | 1 + tests/fixtures/dataset.py | 4 +- tests/test_dataset.py | 10 +-- tests/test_forward.py | 12 ++-- 15 files changed, 106 insertions(+), 66 deletions(-) diff --git a/prompting/base/neuron.py b/prompting/base/neuron.py index 80fe4226..91b00c63 100644 --- a/prompting/base/neuron.py +++ b/prompting/base/neuron.py @@ -81,7 +81,9 @@ def __init__(self, config=None): if self.config.mock: self.wallet = bt.MockWallet(config=self.config) self.subtensor = MockSubtensor(self.config.netuid, wallet=self.wallet) - self.metagraph = MockMetagraph(netuid=self.config.netuid, subtensor=self.subtensor) + self.metagraph = MockMetagraph( + netuid=self.config.netuid, subtensor=self.subtensor + ) else: self.wallet = bt.wallet(config=self.config) self.subtensor = bt.subtensor(config=self.config) diff --git a/prompting/forward.py b/prompting/forward.py index 6a9b6ed3..a0ee2f19 100644 --- a/prompting/forward.py +++ b/prompting/forward.py @@ -33,17 +33,22 @@ from prompting.utils.misc import async_log, serialize_exception_to_string from dataclasses import dataclass + @async_log -async def generate_reference(agent): +async def generate_reference(agent): loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, agent.task.generate_reference, agent.llm_pipeline) - return result + result = await loop.run_in_executor( + None, agent.task.generate_reference, agent.llm_pipeline + ) + return result + @async_log async def execute_dendrite_call(dendrite_call): responses = await dendrite_call return responses + @dataclass class StreamResult: synapse: StreamPromptingSynapse = None @@ -217,8 +222,8 @@ async def run_step( log_stream_results(stream_results) - all_synapses_results = [stream_result.synapse for stream_result in stream_results] - + all_synapses_results = [stream_result.synapse for stream_result in stream_results] + # Encapsulate the responses in a response event (dataclass) response_event = DendriteResponseEvent( responses=all_synapses_results, uids=uids, timeout=timeout diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index 755c720d..3c50ed3e 100644 --- 
a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -86,18 +86,19 @@ def next( class BatchDataset(ABC): """Base class for datasets.""" - max_tries: int = 10 + + max_tries: int = 10 @abstractmethod async def random_batch(self, name): - ... + ... async def next( self, method: str = "random", selector: Selector = Selector(), **kwargs - ) -> BatchContext: + ) -> BatchContext: t0 = time.time() - for tries in range(1, self.max_tries + 1): + for tries in range(1, self.max_tries + 1): if method == "random": results = await self.random_batch() stats = { @@ -107,7 +108,7 @@ async def next( "fetch_method": method, "next_kwargs": kwargs, } - + return BatchContext(results=results, stats=stats) else: raise ValueError(f"Unknown dataset get method {method!r}") @@ -116,4 +117,3 @@ async def next( raise MaxRetryError( f"Could not find any samples which meet {self.__class__.__name__} requirements after {self.max_tries} tries." ) - \ No newline at end of file diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py index ba228212..907cd821 100644 --- a/prompting/tools/datasets/batch_wiki.py +++ b/prompting/tools/datasets/batch_wiki.py @@ -7,6 +7,7 @@ class BatchWikiDataset(BatchDataset): """Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections.""" + def __init__( self, batch_size: int = 16, @@ -18,6 +19,6 @@ def __init__( """ self.batch_size = batch_size - async def random_batch(self) -> List[Context]: + async def random_batch(self) -> List[Context]: contexts = await get_batch_random_sections(self.batch_size) return contexts diff --git a/prompting/tools/datasets/context.py b/prompting/tools/datasets/context.py index 20962ce0..22915d31 100644 --- a/prompting/tools/datasets/context.py +++ b/prompting/tools/datasets/context.py @@ -20,4 +20,4 @@ class Context: @dataclass class BatchContext: results: List[Context] - stats: dict = None \ No newline at end of file + stats: dict = None diff --git a/prompting/tools/datasets/math.py b/prompting/tools/datasets/math.py index af3096d4..494399d2 100644 --- a/prompting/tools/datasets/math.py +++ b/prompting/tools/datasets/math.py @@ -57,7 +57,7 @@ def get( """ bt.logging.info(f"Getting math problem {name!r}") info = mathgenerator.generate_context(name, **kwargs) - if info["reward_type"] != "float" or info["topic"] == "computer_science": + if info["reward_type"] != "float" or info["topic"] == "computer_science": return None math_words = [ diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index 69c6ffd8..6ff9c975 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -91,8 +91,12 @@ def get( "subtopic": section_title, "content": content, "internal_links": list(filter(lambda x: x not in exclude, page.sections)), - "external_links": wiki_utils.most_relevant_links(page, num_links=self.max_links), - "tags": wiki_utils.filter_categories(page.categories, exclude=self.EXCLUDE_CATEGORIES), + "external_links": wiki_utils.most_relevant_links( + page, num_links=self.max_links + ), + "tags": wiki_utils.filter_categories( + page.categories, exclude=self.EXCLUDE_CATEGORIES + ), "source": "Wikipedia", "extra": { "url": page.url, diff --git a/prompting/utils/__init__.py b/prompting/utils/__init__.py index 39a17e70..2dbd39e9 100644 --- a/prompting/utils/__init__.py +++ b/prompting/utils/__init__.py @@ -3,4 +3,4 @@ from . import uids from . import logging from . import async_wiki_utils -from . import wiki \ No newline at end of file +from . 
import wiki diff --git a/prompting/utils/async_wiki_utils.py b/prompting/utils/async_wiki_utils.py index d508ce36..6b0ebbfe 100644 --- a/prompting/utils/async_wiki_utils.py +++ b/prompting/utils/async_wiki_utils.py @@ -9,12 +9,16 @@ EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links") EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1") + class SectionNotFoundException(Exception): """Exception raised when no valid section is found.""" + pass + class MaxRetriesReachedException(Exception): """Exception raised when maximum retry attempts are reached.""" + pass @@ -31,6 +35,7 @@ class Context: extra: Dict[str, any] = field(default_factory=dict) stats: Dict[str, any] = field(default_factory=dict) + async def fetch_content(session: aiohttp.ClientSession, pageid: str) -> str: url = "https://en.wikipedia.org/w/api.php" params = { @@ -38,13 +43,14 @@ async def fetch_content(session: aiohttp.ClientSession, pageid: str) -> str: "format": "json", "prop": "extracts", "explaintext": "", - "pageids": pageid + "pageids": pageid, } async with session.get(url, params=params) as response: data = await response.json() - content = data['query']['pages'][str(pageid)]['extract'] + content = data["query"]["pages"][str(pageid)]["extract"] return content + async def fetch_random_page(session: aiohttp.ClientSession) -> str: url = "https://en.wikipedia.org/w/api.php" params = { @@ -52,13 +58,16 @@ async def fetch_random_page(session: aiohttp.ClientSession) -> str: "format": "json", "list": "random", "rnnamespace": "0", - "rnlimit": "1" + "rnlimit": "1", } async with session.get(url, params=params) as response: data = await response.json() - return data['query']['random'][0]['id'] + return data["query"]["random"][0]["id"] + -async def fetch_page_details(session: aiohttp.ClientSession, pageid: str) -> Dict[str, any]: +async def fetch_page_details( + session: aiohttp.ClientSession, pageid: str +) -> Dict[str, any]: url = "https://en.wikipedia.org/w/api.php" params = { "action": "parse", @@ -66,14 +75,16 @@ async def fetch_page_details(session: aiohttp.ClientSession, pageid: str) -> Dic "pageid": pageid, "prop": "sections|links|categories|externallinks", "disabletoc": "1", - "disableeditsection": "1" + "disableeditsection": "1", } async with session.get(url, params=params) as response: data = await response.json() - return data['parse'] + return data["parse"] -async def fetch_random_section_context(session: aiohttp.ClientSession, progress: tqdm) -> Context: +async def fetch_random_section_context( + session: aiohttp.ClientSession, progress: tqdm +) -> Context: max_attempts = 10 for attempt in range(1, max_attempts + 1): try: @@ -82,29 +93,39 @@ async def fetch_random_section_context(session: aiohttp.ClientSession, progress: page_details = await fetch_page_details(session, pageid) content = await fetch_content(session, pageid) - # Filter sections here - filtered_sections = [section for section in page_details['sections'] if section['line'] not in EXCLUDE_HEADERS] + # Filter sections here + filtered_sections = [ + section + for section in page_details["sections"] + if section["line"] not in EXCLUDE_HEADERS + ] if not filtered_sections: bt.logging.error("No valid sections found.") raise SectionNotFoundException("No valid sections found.") selected_section = random.choice(filtered_sections) - - internal_links = [link['*'] for link in page_details['links'] if link['ns'] == 0] - external_links = page_details.get('externallinks', []) - tags = [category['*'] for category in 
page_details['categories'] if not any(excl in category['*'].lower() for excl in EXCLUDE_CATEGORIES)] - + + internal_links = [ + link["*"] for link in page_details["links"] if link["ns"] == 0 + ] + external_links = page_details.get("externallinks", []) + tags = [ + category["*"] + for category in page_details["categories"] + if not any(excl in category["*"].lower() for excl in EXCLUDE_CATEGORIES) + ] + context = Context( - title=page_details['title'], - topic=selected_section.get('line', 'No Topic'), - subtopic=selected_section['line'], + title=page_details["title"], + topic=selected_section.get("line", "No Topic"), + subtopic=selected_section["line"], content=content, internal_links=internal_links, external_links=external_links, tags=tags, source="Wikipedia", - extra={} + extra={}, ) progress.update(1) return context @@ -113,20 +134,26 @@ async def fetch_random_section_context(session: aiohttp.ClientSession, progress: bt.logging.warning(f"Attempt {attempt} failed: {e}") if attempt == max_attempts: bt.logging.error("Maximum retry attempts reached, failing...") - raise MaxRetriesReachedException(f"Maximum retry attempts reached: {max_attempts}") - + raise MaxRetriesReachedException( + f"Maximum retry attempts reached: {max_attempts}" + ) + -async def get_batch_random_sections(batch_size: int = 16) -> List[Context]: - async with aiohttp.ClientSession() as session: +async def get_batch_random_sections(batch_size: int = 16) -> List[Context]: + async with aiohttp.ClientSession() as session: tasks: List[asyncio.Task] = [] - progress = tqdm(total=batch_size, desc=f"Fetching {batch_size} random wikipedia sections", unit="section") # Total is the number of tasks - + progress = tqdm( + total=batch_size, + desc=f"Fetching {batch_size} random wikipedia sections", + unit="section", + ) # Total is the number of tasks + # Creates a list of tasks to be executed concurrently for _ in range(batch_size): task = asyncio.create_task(fetch_random_section_context(session, progress)) tasks.append(task) - + results = await asyncio.gather(*tasks) - progress.close() # Ensure the progress bar closes after all tasks complete - - return results \ No newline at end of file + progress.close() # Ensure the progress bar closes after all tasks complete + + return results diff --git a/prompting/utils/config.py b/prompting/utils/config.py index daabbb11..59481761 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -285,7 +285,7 @@ def add_validator_args(cls, parser): type=float, nargs="+", help="The probability of sampling each task.", - default=[.25, .25, 0, .25, .25], + default=[0.25, 0.25, 0, 0.25, 0.25], ) parser.add_argument( @@ -381,14 +381,14 @@ def add_validator_args(cls, parser): help="Only query a single hotkey per ip.", default=False, ) - + parser.add_argument( "--neuron.forward_max_time", type=int, help="Max time to wait for a forward call to complete in seconds.", default=120, ) - + parser.add_argument( "--neuron.batch_size", type=int, diff --git a/prompting/utils/misc.py b/prompting/utils/misc.py index fbfbf06a..20d65ca9 100644 --- a/prompting/utils/misc.py +++ b/prompting/utils/misc.py @@ -134,12 +134,12 @@ async def wrapper(*args, **kwargs): return wrapper -def serialize_exception_to_string(e): +def serialize_exception_to_string(e): if isinstance(e, BaseException): # Format the traceback - tb_str = ''.join(traceback.format_exception(type(e), e, e.__traceback__)) + tb_str = "".join(traceback.format_exception(type(e), e, e.__traceback__)) # Combine type, message, and traceback into one 
string serialized_str = f"Exception Type: {type(e).__name__}, Message: {str(e)}, Traceback: {tb_str}" return serialized_str - else: + else: return e diff --git a/prompting/utils/wiki.py b/prompting/utils/wiki.py index 233a2bd7..011b4751 100644 --- a/prompting/utils/wiki.py +++ b/prompting/utils/wiki.py @@ -6,6 +6,7 @@ from typing import Dict, List from functools import lru_cache + # speed up page loading @lru_cache(maxsize=1000) def _get_page( diff --git a/tests/fixtures/dataset.py b/tests/fixtures/dataset.py index 253b5496..631ddf3a 100644 --- a/tests/fixtures/dataset.py +++ b/tests/fixtures/dataset.py @@ -4,7 +4,7 @@ WikiDataset, WikiDateDataset, MathDataset, - BatchWikiDataset + BatchWikiDataset, ) DATASETS = [ @@ -16,7 +16,7 @@ ] BATCH_DATASETS = [ - BatchWikiDataset, + BatchWikiDataset, ] MOCK_CONTEXT = MockDataset().next() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index e2f97e8d..7c52d69a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,6 +3,7 @@ from prompting.tools.datasets import Dataset from prompting.tools import Context + @pytest.mark.parametrize("dataset", DATASETS) def test_create_dataset(dataset: Dataset): ds = dataset() @@ -69,20 +70,19 @@ def test_context_stats_field_contains_expected_keys(dataset: Dataset, field: str @pytest.mark.asyncio @pytest.mark.parametrize("dataset", BATCH_DATASETS) @pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) -async def test_batch_size_parameter(dataset, batch_size): +async def test_batch_size_parameter(dataset, batch_size): batch_context = await dataset(batch_size=batch_size).next() results = batch_context.results # Check if results match expected batch size - assert len(results) == batch_size + assert len(results) == batch_size @pytest.mark.asyncio @pytest.mark.parametrize("dataset", BATCH_DATASETS) -async def test_random_batch_retrieval(dataset): +async def test_random_batch_retrieval(dataset): # Fetch batches batch1_results = (await dataset(batch_size=2).next()).results - batch2_results = (await dataset(batch_size=2).next()).results + batch2_results = (await dataset(batch_size=2).next()).results # Check that batches have different elements assert batch1_results != batch2_results - diff --git a/tests/test_forward.py b/tests/test_forward.py index 0b5a58b9..baa10771 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -24,14 +24,14 @@ def generate_reference(x, delay=1): async def mock_dendrite_call(delay=1, **kwargs): asyncio.run(asyncio.sleep(delay)) - - async def async_fn_mock(): + + async def async_fn_mock(): mock_synapse = StreamPromptingSynapse(roles=["user"], messages=[""]) mock_synapse.completion = "Fake response" - + yield mock_synapse - - mock_stream_synapse = async_fn_mock() + + mock_stream_synapse = async_fn_mock() return [mock_stream_synapse] @@ -57,5 +57,5 @@ def test_generate_reference_parallel_to_dendrite( # TODO: Fix unit test to work with abs=0.1 assert network_and_reference_gen_time == pytest.approx( - expected_forward_time, abs=1#0.1 + expected_forward_time, abs=1 # 0.1 ) From 6d4d14ac17f665374a95392fc39459e1bd27d125 Mon Sep 17 00:00:00 2001 From: p-ferreira <38992619+p-ferreira@users.noreply.github.com> Date: Wed, 24 Apr 2024 15:58:33 -0400 Subject: [PATCH 08/21] Update prompting/tools/datasets/base.py Co-authored-by: Steffen Cruz --- prompting/tools/datasets/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index 959b533c..c6f27f90 100644 --- a/prompting/tools/datasets/base.py 
+++ b/prompting/tools/datasets/base.py @@ -87,7 +87,7 @@ def next( class BatchDataset(ABC): - """Base class for datasets.""" + """Base class for batch datasets.""" max_tries: int = 10 From d7157b81df081b8694d968b04a6c79b5e371b4b5 Mon Sep 17 00:00:00 2001 From: p-ferreira <38992619+p-ferreira@users.noreply.github.com> Date: Wed, 24 Apr 2024 15:58:49 -0400 Subject: [PATCH 09/21] Update prompting/tools/datasets/batch_wiki.py Co-authored-by: Steffen Cruz --- prompting/tools/datasets/batch_wiki.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py index 907cd821..f228a798 100644 --- a/prompting/tools/datasets/batch_wiki.py +++ b/prompting/tools/datasets/batch_wiki.py @@ -14,8 +14,7 @@ def __init__( ): """ Args: - min_length_words (int, optional): Minimum section length. Defaults to 50. - max_links (int, optional): _description_. Defaults to 10. + max_links (int, optional): _description_. Defaults to 16. """ self.batch_size = batch_size From 8e4af7ca45ad5bc7f252c561dba5430723e32598 Mon Sep 17 00:00:00 2001 From: p-ferreira <38992619+p-ferreira@users.noreply.github.com> Date: Wed, 24 Apr 2024 15:58:58 -0400 Subject: [PATCH 10/21] Update prompting/tools/datasets/batch_wiki.py Co-authored-by: Steffen Cruz --- prompting/tools/datasets/batch_wiki.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py index f228a798..ac78742f 100644 --- a/prompting/tools/datasets/batch_wiki.py +++ b/prompting/tools/datasets/batch_wiki.py @@ -6,7 +6,7 @@ class BatchWikiDataset(BatchDataset): - """Wikipedia dataset. Uses the wikipedia python api to fetch articles and sections.""" + """Wikipedia batch dataset. 
Uses the wikipedia python api to fetch articles and sections.""" def __init__( self, From d886fb54d62a090abcb9a1270f120d8295bb7e19 Mon Sep 17 00:00:00 2001 From: p-ferreira <38992619+p-ferreira@users.noreply.github.com> Date: Wed, 24 Apr 2024 16:04:41 -0400 Subject: [PATCH 11/21] Update prompting/utils/config.py Co-authored-by: Steffen Cruz --- prompting/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/utils/config.py b/prompting/utils/config.py index 3ebb2203..a7bec286 100644 --- a/prompting/utils/config.py +++ b/prompting/utils/config.py @@ -393,7 +393,7 @@ def add_validator_args(cls, parser): parser.add_argument( "--neuron.batch_size", type=int, - help="Max time to wait for a forward call to complete in seconds.", + help="Number of concurrent queries to create in each forward.", default=16, ) From 6ddf250b851bd18cd8d666e124427d1d24d78956 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 20:12:46 +0000 Subject: [PATCH 12/21] drops redundant class --- prompting/utils/async_wiki_utils.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/prompting/utils/async_wiki_utils.py b/prompting/utils/async_wiki_utils.py index 6b0ebbfe..509f05c6 100644 --- a/prompting/utils/async_wiki_utils.py +++ b/prompting/utils/async_wiki_utils.py @@ -2,7 +2,7 @@ import asyncio import random import bittensor as bt -from dataclasses import dataclass, field +from prompting.tools.datasets import Context from typing import List, Dict from tqdm.asyncio import tqdm @@ -12,30 +12,14 @@ class SectionNotFoundException(Exception): """Exception raised when no valid section is found.""" - pass class MaxRetriesReachedException(Exception): """Exception raised when maximum retry attempts are reached.""" - pass -@dataclass -class Context: - title: str - topic: str - subtopic: str - content: str - internal_links: List[str] - external_links: List[str] - source: str - tags: List[str] = field(default_factory=list) - extra: Dict[str, any] = field(default_factory=dict) - stats: Dict[str, any] = field(default_factory=dict) - - async def fetch_content(session: aiohttp.ClientSession, pageid: str) -> str: url = "https://en.wikipedia.org/w/api.php" params = { From 52a0309a04779a3979849a59e1e228313b56ad3f Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 20:12:58 +0000 Subject: [PATCH 13/21] complements unit tests assertions --- tests/test_dataset.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 552ffdc9..9d0d8dfd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -75,6 +75,8 @@ async def test_batch_size_parameter(dataset, batch_size): results = batch_context.results # Check if results match expected batch size assert len(results) == batch_size + assert type(results) == list + assert all(type(result)==Context for result in results) @pytest.mark.asyncio @@ -86,3 +88,10 @@ async def test_random_batch_retrieval(dataset): # Check that batches have different elements assert batch1_results != batch2_results + + # Check that results are of expected type + assert type(batch1_results) == list + assert all(type(result)==Context for result in batch1_results) + + assert type(batch2_results) == list + assert all(type(result)==Context for result in batch2_results) From 7b4d343be0297b5cb11e4933cf1db5fbd9f6609b Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 20:25:45 +0000 Subject: [PATCH 14/21] adds stats for context obj --- 
prompting/utils/async_wiki_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/prompting/utils/async_wiki_utils.py b/prompting/utils/async_wiki_utils.py index 509f05c6..fd9739f6 100644 --- a/prompting/utils/async_wiki_utils.py +++ b/prompting/utils/async_wiki_utils.py @@ -1,3 +1,4 @@ +import time import aiohttp import asyncio import random @@ -72,6 +73,7 @@ async def fetch_random_section_context( max_attempts = 10 for attempt in range(1, max_attempts + 1): try: + request_time_start = time.time() bt.logging.info("Fetching random section context...") pageid = await fetch_random_page(session) page_details = await fetch_page_details(session, pageid) @@ -110,6 +112,11 @@ async def fetch_random_section_context( tags=tags, source="Wikipedia", extra={}, + stats = { + "creator": fetch_random_section_context.__name__, + "fetch_time": time.time() - request_time_start, + "num_tries": attempt, + } ) progress.update(1) return context From e7129b98aa9cdba18f390b6ea57e9a6217fb19e4 Mon Sep 17 00:00:00 2001 From: p-ferreira <38992619+p-ferreira@users.noreply.github.com> Date: Wed, 24 Apr 2024 16:26:44 -0400 Subject: [PATCH 15/21] Update prompting/tools/datasets/base.py Co-authored-by: Steffen Cruz --- prompting/tools/datasets/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index c6f27f90..0efeddfe 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -90,7 +90,7 @@ class BatchDataset(ABC): """Base class for batch datasets.""" max_tries: int = 10 - + batch_size: int = 16 # ensure that child classes contain batch_size attrib @abstractmethod async def random_batch(self, name): ... From f50d547c7cbafdde963600a8788b88eae18555c4 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 20:29:28 +0000 Subject: [PATCH 16/21] modifies random_batch naming to just random --- prompting/tools/datasets/base.py | 6 +++--- prompting/tools/datasets/batch_wiki.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index 0efeddfe..e7a33d87 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -90,9 +90,9 @@ class BatchDataset(ABC): """Base class for batch datasets.""" max_tries: int = 10 - batch_size: int = 16 # ensure that child classes contain batch_size attrib + batch_size: int = 16 # ensure that child classes contain batch_size attrib @abstractmethod - async def random_batch(self, name): + async def random(self, name): ... 
async def next( @@ -102,7 +102,7 @@ async def next( for tries in range(1, self.max_tries + 1): if method == "random": - results = await self.random_batch() + results = await self.random() stats = { "creator": self.__class__.__name__, "fetch_time": time.time() - t0, diff --git a/prompting/tools/datasets/batch_wiki.py b/prompting/tools/datasets/batch_wiki.py index ac78742f..13df4dd0 100644 --- a/prompting/tools/datasets/batch_wiki.py +++ b/prompting/tools/datasets/batch_wiki.py @@ -18,6 +18,6 @@ def __init__( """ self.batch_size = batch_size - async def random_batch(self) -> List[Context]: + async def random(self) -> List[Context]: contexts = await get_batch_random_sections(self.batch_size) return contexts From 6f33c6d8c3a3b7f9dccf9eedf81e7070fec384c5 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 20:41:33 +0000 Subject: [PATCH 17/21] adds docs strings --- prompting/utils/async_wiki_utils.py | 69 +++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/prompting/utils/async_wiki_utils.py b/prompting/utils/async_wiki_utils.py index fd9739f6..3efbd629 100644 --- a/prompting/utils/async_wiki_utils.py +++ b/prompting/utils/async_wiki_utils.py @@ -22,6 +22,19 @@ class MaxRetriesReachedException(Exception): async def fetch_content(session: aiohttp.ClientSession, pageid: str) -> str: + """ + Asynchronously fetches the plain text content of a Wikipedia page given its page ID. + + Args: + session (aiohttp.ClientSession): The session used to make HTTP requests. + pageid (str): The Wikipedia page ID of the page from which to fetch content. + + Returns: + str: The plain text content of the Wikipedia page. + + Raises: + aiohttp.ClientError: If there's an HTTP related error during the request. + """ url = "https://en.wikipedia.org/w/api.php" params = { "action": "query", @@ -37,6 +50,18 @@ async def fetch_content(session: aiohttp.ClientSession, pageid: str) -> str: async def fetch_random_page(session: aiohttp.ClientSession) -> str: + """ + Asynchronously fetches the page ID of a random Wikipedia page. + + Args: + session (aiohttp.ClientSession): The session used to make HTTP requests. + + Returns: + str: The page ID of a randomly selected Wikipedia page. + + Raises: + aiohttp.ClientError: If there's an HTTP related error during the request. + """ url = "https://en.wikipedia.org/w/api.php" params = { "action": "query", @@ -53,6 +78,19 @@ async def fetch_random_page(session: aiohttp.ClientSession) -> str: async def fetch_page_details( session: aiohttp.ClientSession, pageid: str ) -> Dict[str, any]: + """ + Asynchronously fetches detailed information about a Wikipedia page, including sections, links, categories, and external links. + + Args: + session (aiohttp.ClientSession): The session used to make HTTP requests. + pageid (str): The Wikipedia page ID from which to fetch details. + + Returns: + Dict[str, Any]: A dictionary containing detailed information about the Wikipedia page. + + Raises: + aiohttp.ClientError: If there's an HTTP related error during the request. + """ url = "https://en.wikipedia.org/w/api.php" params = { "action": "parse", @@ -70,6 +108,20 @@ async def fetch_page_details( async def fetch_random_section_context( session: aiohttp.ClientSession, progress: tqdm ) -> Context: + """ + Asynchronously fetches the context of a random section from a random Wikipedia page, including title, topic, subtopic, content, and links. + + Args: + session (aiohttp.ClientSession): The session used to make HTTP requests. 
+ progress (tqdm): A tqdm progress bar instance to update progress. + + Returns: + Any: A context object containing various details about the section. + + Raises: + SectionNotFoundException: If no valid section is found after filtering. + MaxRetriesReachedException: If the maximum number of retry attempts is reached. + """ max_attempts = 10 for attempt in range(1, max_attempts + 1): try: @@ -131,6 +183,23 @@ async def fetch_random_section_context( async def get_batch_random_sections(batch_size: int = 16) -> List[Context]: + """ + Asynchronously fetches a batch of random sections from Wikipedia pages. This function utilizes concurrency to fetch multiple sections in parallel. + + Args: + batch_size (int, optional): The number of random sections to fetch. Defaults to 16. + + Returns: + List[Context]: A list of context objects, each containing details about a random section of a Wikipedia page. + + Details: + The function creates an asynchronous session and a number of tasks equal to the batch size. Each task fetches a random section context from a Wikipedia page. All tasks are run concurrently, and the function waits for all tasks to complete before returning the results. A progress bar is displayed to track the progress of fetching the sections. + + Raises: + aiohttp.ClientError: If there's an HTTP related error during any request in the tasks. + SectionNotFoundException: If no valid section is found after filtering in any task. + MaxRetriesReachedException: If the maximum number of retry attempts is reached in any task. + """ async with aiohttp.ClientSession() as session: tasks: List[asyncio.Task] = [] progress = tqdm( From 2bb5f3f26d839968c80b5fab99293498fec94ab0 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 20:57:26 +0000 Subject: [PATCH 18/21] runs black --- neurons/validator.py | 2 +- prompting/conversation.py | 2 +- prompting/task_registry.py | 33 +++++++++++++++---- prompting/tasks/__init__.py | 2 +- prompting/tasks/date_qa.py | 2 +- prompting/tasks/generic_instruction.py | 4 +-- prompting/tasks/mock.py | 9 ++--- prompting/tasks/summarization.py | 1 - prompting/tasks/task.py | 4 ++- prompting/tools/__init__.py | 9 ++--- prompting/tools/datasets/base.py | 3 +- prompting/tools/datasets/code.py | 2 ++ .../tools/datasets/generic_instruction.py | 2 +- prompting/tools/datasets/math.py | 14 +++++--- prompting/tools/datasets/mock.py | 1 + prompting/tools/datasets/wiki.py | 1 + prompting/utils/async_wiki_utils.py | 6 ++-- tests/fixtures/task.py | 8 ++++- tests/test_dataset.py | 10 +++--- tests/test_registry.py | 5 ++- 20 files changed, 78 insertions(+), 42 deletions(-) diff --git a/neurons/validator.py b/neurons/validator.py index 6f8370d2..d8901943 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -40,7 +40,7 @@ def __init__(self, config=None): mock=self.config.mock, ) - if abs(1-sum(self.config.neuron.task_p)) > 0.001: + if abs(1 - sum(self.config.neuron.task_p)) > 0.001: raise ValueError("Task probabilities do not sum to 1.") # Filter out tasks with 0 probability diff --git a/prompting/conversation.py b/prompting/conversation.py index e17f5cd3..3ddd761a 100644 --- a/prompting/conversation.py +++ b/prompting/conversation.py @@ -47,4 +47,4 @@ def create_task( llm_pipeline=llm_pipeline, context=dataset.next(), create_reference=create_reference, - ) \ No newline at end of file + ) diff --git a/prompting/task_registry.py b/prompting/task_registry.py index 2e006694..ab8c6b52 100644 --- a/prompting/task_registry.py +++ b/prompting/task_registry.py @@ -1,19 +1,38 @@ 
-from .tasks import Task, MockTask, SummarizationTask, QuestionAnsweringTask, DebuggingTask, MathTask, DateQuestionAnsweringTask, GenericInstructionTask -from .tools import MockDataset, WikiDataset, HFCodingDataset, StackOverflowDataset, MathDataset, WikiDateDataset, GenericInstructionDataset +from .tasks import ( + Task, + MockTask, + SummarizationTask, + QuestionAnsweringTask, + DebuggingTask, + MathTask, + DateQuestionAnsweringTask, + GenericInstructionTask, +) +from .tools import ( + MockDataset, + WikiDataset, + HFCodingDataset, + StackOverflowDataset, + MathDataset, + WikiDateDataset, + GenericInstructionDataset, +) # TODO: Expand this to include extra information beyond just the task and dataset names summarization_task, summarization_dataset = SummarizationTask.name, [WikiDataset.name] qa_task, qa_dataset = QuestionAnsweringTask.name, [WikiDataset.name] -#debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name] +# debugging_task, debugging_dataset = DebuggingTask.name, [HFCodingDataset.name] math_task, math_dataset = MathTask.name, [MathDataset.name] date_qa_task, date_qa_dataset = DateQuestionAnsweringTask.name, [WikiDateDataset.name] -generic_instruction_task, generic_instruction_dataset = GenericInstructionTask.name, [GenericInstructionDataset.name] +generic_instruction_task, generic_instruction_dataset = GenericInstructionTask.name, [ + GenericInstructionDataset.name +] TASK_REGISTRY = { summarization_task: summarization_dataset, qa_task: qa_dataset, - #debugging_task: debugging_dataset, + # debugging_task: debugging_dataset, math_task: math_dataset, date_qa_task: date_qa_dataset, - generic_instruction_task: generic_instruction_dataset -} \ No newline at end of file + generic_instruction_task: generic_instruction_dataset, +} diff --git a/prompting/tasks/__init__.py b/prompting/tasks/__init__.py index 01292ded..ba3bb28a 100644 --- a/prompting/tasks/__init__.py +++ b/prompting/tasks/__init__.py @@ -12,7 +12,7 @@ QuestionAnsweringTask.name: QuestionAnsweringTask, DateQuestionAnsweringTask.name: DateQuestionAnsweringTask, SummarizationTask.name: SummarizationTask, - #DebuggingTask.name: DebuggingTask, + # DebuggingTask.name: DebuggingTask, GenericInstructionTask.name: GenericInstructionTask, MathTask.name: MathTask, } diff --git a/prompting/tasks/date_qa.py b/prompting/tasks/date_qa.py index bee5776b..0f13015e 100644 --- a/prompting/tasks/date_qa.py +++ b/prompting/tasks/date_qa.py @@ -22,7 +22,7 @@ class DateQuestionAnsweringTask(Task): static_reference = True static_query = True - def __init__(self, llm_pipeline, context, create_reference =True): + def __init__(self, llm_pipeline, context, create_reference=True): self.context = context self.query = ( diff --git a/prompting/tasks/generic_instruction.py b/prompting/tasks/generic_instruction.py index d5b706a6..b7cac0dc 100644 --- a/prompting/tasks/generic_instruction.py +++ b/prompting/tasks/generic_instruction.py @@ -13,7 +13,7 @@ class GenericInstructionTask(Task): - challenge_type = 'query' + challenge_type = "query" name = "generic" desc = "get help on answering a general instruction" goal = "to get the answer to the following instruction" @@ -38,7 +38,7 @@ def __init__(self, llm_pipeline, context, create_reference=True): self.query_prompt = QUERY_PROMPT_TEMPLATE.format(context=context.content) self.query = self.generate_query(llm_pipeline) - self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(query = self.query) + self.reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(query=self.query) if 
create_reference: self.reference = self.generate_reference(llm_pipeline) diff --git a/prompting/tasks/mock.py b/prompting/tasks/mock.py index 668827fb..5f79964f 100644 --- a/prompting/tasks/mock.py +++ b/prompting/tasks/mock.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from prompting.tasks import Task + @dataclass class MockTask(Task): name = "mock" @@ -18,12 +19,8 @@ class MockTask(Task): def __init__(self, llm_pipeline, context, create_reference=True): self.context = context - self.query = ( - "How can I solve the following problem, " - + context.content - + "?" - ) + self.query = "How can I solve the following problem, " + context.content + "?" self.reference = "This is the reference answer" self.topic = context.title self.subtopic = context.topic - self.tags = context.tags \ No newline at end of file + self.tags = context.tags diff --git a/prompting/tasks/summarization.py b/prompting/tasks/summarization.py index 8e413f82..e832e426 100644 --- a/prompting/tasks/summarization.py +++ b/prompting/tasks/summarization.py @@ -2,7 +2,6 @@ from prompting.tasks import Task - # TODO: introduce criteria for the query and reference answer (length, layout, etc.) and make these arguments # TODO: Also add a query system prompt and a query prompt template diff --git a/prompting/tasks/task.py b/prompting/tasks/task.py index cfa03a3d..3618e882 100644 --- a/prompting/tasks/task.py +++ b/prompting/tasks/task.py @@ -17,6 +17,8 @@ def make_system_prompt(): return CHATTENSOR_SYSTEM_PROMPT.format(date=time.strftime("%B %d, %Y")) + + class TaskEvaluationType(Enum): REWARD_STACK = "reward" FILTER_STACK = "filter" @@ -108,7 +110,7 @@ def generate_query(self, pipeline: BasePipeline, clean=True) -> str: if not self.static_query: bt.logging.info("🤖 Generating query...") self.query = self.generate( - system=self.query_system_prompt, #Could possibly add the chattensor system prompt to query but I don't think it adds anything + system=self.query_system_prompt, # Could possibly add the chattensor system prompt to query but I don't think it adds anything prompt=self.query_prompt, pipeline=pipeline, clean=clean, diff --git a/prompting/tools/__init__.py b/prompting/tools/__init__.py index c296ce17..a4f634a2 100644 --- a/prompting/tools/__init__.py +++ b/prompting/tools/__init__.py @@ -12,13 +12,10 @@ from .selector import Selector DATASETS = { - #HFCodingDataset.name: HFCodingDataset, + # HFCodingDataset.name: HFCodingDataset, WikiDataset.name: WikiDataset, - #StackOverflowDataset.name: StackOverflowDataset, + # StackOverflowDataset.name: StackOverflowDataset, MathDataset.name: MathDataset, WikiDateDataset.name: WikiDateDataset, - GenericInstructionDataset.name: GenericInstructionDataset, + GenericInstructionDataset.name: GenericInstructionDataset, } - - - \ No newline at end of file diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index e7a33d87..0dd789c3 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -90,7 +90,8 @@ class BatchDataset(ABC): """Base class for batch datasets.""" max_tries: int = 10 - batch_size: int = 16 # ensure that child classes contain batch_size attrib + batch_size: int = 16 # ensure that child classes contain batch_size attrib + @abstractmethod async def random(self, name): ... 
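A minimal usage sketch of the BatchDataset flow formatted above, assuming BatchWikiDataset is exported from prompting.tools.datasets the same way the test fixtures import it; the asyncio entry point and variable names are illustrative only:

    import asyncio
    from prompting.tools.datasets import BatchWikiDataset

    async def main():
        # Fetch four random Wikipedia sections concurrently via the batch dataset.
        batch = await BatchWikiDataset(batch_size=4).next()
        for context in batch.results:  # each result is a Context dataclass
            print(context.title, "->", context.subtopic)
        print(batch.stats)  # creator, fetch_time, num_tries, fetch_method, next_kwargs

    asyncio.run(main())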
diff --git a/prompting/tools/datasets/code.py b/prompting/tools/datasets/code.py index bdc2948b..ada3cdc6 100644 --- a/prompting/tools/datasets/code.py +++ b/prompting/tools/datasets/code.py @@ -524,6 +524,7 @@ def filter_comments(code, language): # TODO: why not define the chain_in, chain_out logic in the class itself? class HFCodingDataset(Dataset): name = "hf_coding" + def __init__( self, dataset_id="codeparrot/github-code", @@ -617,6 +618,7 @@ def get_special_contents(self, code, language, remove_comments=True): class StackOverflowDataset: name = "stack_overflow" + def __init__(self): # Stack Overflow API endpoint for a random article self.url = "https://api.stackexchange.com/2.3/questions" diff --git a/prompting/tools/datasets/generic_instruction.py b/prompting/tools/datasets/generic_instruction.py index b4d1e1d2..e24dea22 100644 --- a/prompting/tools/datasets/generic_instruction.py +++ b/prompting/tools/datasets/generic_instruction.py @@ -49,4 +49,4 @@ class GenericInstructionDataset(TemplateDataset): "tech", "history", ], - ) \ No newline at end of file + ) diff --git a/prompting/tools/datasets/math.py b/prompting/tools/datasets/math.py index b0dfc7ef..02071fc0 100644 --- a/prompting/tools/datasets/math.py +++ b/prompting/tools/datasets/math.py @@ -30,7 +30,7 @@ class MathDataset(Dataset): - name = 'math' + name = "math" topics_list = mathgenerator.getGenList() def __init__(self, seed=None): @@ -60,7 +60,7 @@ def get( max_tries = 10 for _ in range(max_tries): info = mathgenerator.generate_context(name, **kwargs) - if info["reward_type"] != "float" or info["topic"] == "computer_science": + if info["reward_type"] != "float" or info["topic"] == "computer_science": pass else: math_words = [ @@ -83,10 +83,16 @@ def get( "topic": info["topic"], # title of problem topic "subtopic": info["subtopic"], # title of problem subtopic "content": info["problem"], # problem statement - "internal_links": [info["topic"], info["subtopic"]], # internal links + "internal_links": [ + info["topic"], + info["subtopic"], + ], # internal links "external_links": external_links, "tags": info["forward_words"], - "extra": {"reward_type": info["reward_type"], "solution": info["solution"]}, + "extra": { + "reward_type": info["reward_type"], + "solution": info["solution"], + }, } def search( diff --git a/prompting/tools/datasets/mock.py b/prompting/tools/datasets/mock.py index 54008269..9b18070d 100644 --- a/prompting/tools/datasets/mock.py +++ b/prompting/tools/datasets/mock.py @@ -5,6 +5,7 @@ class MockDataset(Dataset): name = "mock" + def get(self, name, exclude=None, selector=None): return { "title": name, diff --git a/prompting/tools/datasets/wiki.py b/prompting/tools/datasets/wiki.py index 5f8ee6bc..83f88690 100644 --- a/prompting/tools/datasets/wiki.py +++ b/prompting/tools/datasets/wiki.py @@ -28,6 +28,7 @@ class WikiDataset(Dataset): """Wikipedia dataset. 
Uses the wikipedia python api to fetch articles and sections.""" + name = "wiki" EXCLUDE_HEADERS = ("See also", "References", "Further reading", "External links") EXCLUDE_CATEGORIES = ("articles", "wiki", "pages", "cs1") diff --git a/prompting/utils/async_wiki_utils.py b/prompting/utils/async_wiki_utils.py index 3efbd629..b9cf4f41 100644 --- a/prompting/utils/async_wiki_utils.py +++ b/prompting/utils/async_wiki_utils.py @@ -13,11 +13,13 @@ class SectionNotFoundException(Exception): """Exception raised when no valid section is found.""" + pass class MaxRetriesReachedException(Exception): """Exception raised when maximum retry attempts are reached.""" + pass @@ -164,11 +166,11 @@ async def fetch_random_section_context( tags=tags, source="Wikipedia", extra={}, - stats = { + stats={ "creator": fetch_random_section_context.__name__, "fetch_time": time.time() - request_time_start, "num_tries": attempt, - } + }, ) progress.update(1) return context diff --git a/tests/fixtures/task.py b/tests/fixtures/task.py index 3eca818e..2a289575 100644 --- a/tests/fixtures/task.py +++ b/tests/fixtures/task.py @@ -8,7 +8,13 @@ DateQuestionAnsweringTask, ) from prompting.tools import Context -from .dataset import WIKI_CONTEXT, CODING_CONTEXT, MATH_CONTEXT, DATEQA_CONTEXT, MOCK_CONTEXT +from .dataset import ( + WIKI_CONTEXT, + CODING_CONTEXT, + MATH_CONTEXT, + DATEQA_CONTEXT, + MOCK_CONTEXT, +) TASKS = [ MockTask, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 9d0d8dfd..9b1e860c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -76,7 +76,7 @@ async def test_batch_size_parameter(dataset, batch_size): # Check if results match expected batch size assert len(results) == batch_size assert type(results) == list - assert all(type(result)==Context for result in results) + assert all(type(result) == Context for result in results) @pytest.mark.asyncio @@ -88,10 +88,10 @@ async def test_random_batch_retrieval(dataset): # Check that batches have different elements assert batch1_results != batch2_results - + # Check that results are of expected type assert type(batch1_results) == list - assert all(type(result)==Context for result in batch1_results) - + assert all(type(result) == Context for result in batch1_results) + assert type(batch2_results) == list - assert all(type(result)==Context for result in batch2_results) + assert all(type(result) == Context for result in batch2_results) diff --git a/tests/test_registry.py b/tests/test_registry.py index 49c72fd2..825749a9 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -10,7 +10,10 @@ def test_task_registry(): assert ( not registry_missing_task ), f"Missing tasks in TASK_REGISTRY: {registry_missing_task}" - assert not registry_extra_task, f"Extra tasks in TASK_REGISTRY: {registry_extra_task}" + assert ( + not registry_extra_task + ), f"Extra tasks in TASK_REGISTRY: {registry_extra_task}" + def test_task_registry_datasets(): registry_datasets = set( From 5db2e370ec85c6ef280ac5b6db2461ba49b1960a Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 21:02:05 +0000 Subject: [PATCH 19/21] adds black to gh action --- .github/workflows/python-package.yml | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 41e1fbc6..08592466 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -16,7 +16,7 @@ jobs: python-version: ["3.9", "3.10"] steps: - - uses: actions/checkout@v3 + - 
uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: @@ -25,7 +25,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest black - bash install.sh + bash install.sh - name: Lint with flake8 run: | @@ -33,9 +33,15 @@ jobs: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Black - run: | - black . + uses: psf/black@stable + with: + options: "--check --verbose" + src: "." + jupyter: true + version: "21.5b1" + - name: Test with pytest run: | # run tests in tests/ dir and only fail if there are failures or errors From 19c4b11240ba97cee1f60ac9cb6d092fd9fb58a4 Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 21:14:31 +0000 Subject: [PATCH 20/21] fix black command on gh action --- .github/workflows/python-package.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 08592466..4911bb08 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -39,8 +39,6 @@ jobs: with: options: "--check --verbose" src: "." - jupyter: true - version: "21.5b1" - name: Test with pytest run: | From 968688d662cfd66fbece47d01564e43375a7862a Mon Sep 17 00:00:00 2001 From: p-ferreira Date: Wed, 24 Apr 2024 21:27:51 +0000 Subject: [PATCH 21/21] runs latest version of black --- prompting/base/neuron.py | 6 ++---- prompting/forward.py | 6 +++++- prompting/llms/base_llm.py | 9 +++------ prompting/rewards/reward.py | 3 +-- prompting/tools/datasets/base.py | 12 ++++-------- 5 files changed, 15 insertions(+), 21 deletions(-) diff --git a/prompting/base/neuron.py b/prompting/base/neuron.py index 91b00c63..f85e01c1 100644 --- a/prompting/base/neuron.py +++ b/prompting/base/neuron.py @@ -104,12 +104,10 @@ def __init__(self, config=None): self.step = 0 @abstractmethod - def forward(self, synapse: bt.Synapse) -> bt.Synapse: - ... + def forward(self, synapse: bt.Synapse) -> bt.Synapse: ... @abstractmethod - def run(self): - ... + def run(self): ... def sync(self): """ diff --git a/prompting/forward.py b/prompting/forward.py index a0ee2f19..cd78d078 100644 --- a/prompting/forward.py +++ b/prompting/forward.py @@ -60,7 +60,11 @@ async def process_response(uid: int, async_generator: Awaitable): """Process a single response asynchronously.""" try: chunk = None # Initialize chunk with a default value - async for chunk in async_generator: # most important loop, as this is where we acquire the final synapse. + async for ( + chunk + ) in ( + async_generator + ): # most important loop, as this is where we acquire the final synapse. bt.logging.debug(f"\nchunk for uid {uid}: {chunk}") if chunk is not None: diff --git a/prompting/llms/base_llm.py b/prompting/llms/base_llm.py index 0ed7b139..8d275e50 100644 --- a/prompting/llms/base_llm.py +++ b/prompting/llms/base_llm.py @@ -6,8 +6,7 @@ class BasePipeline(ABC): @abstractmethod - def __call__(self, composed_prompt: str, **kwargs: dict) -> Any: - ... + def __call__(self, composed_prompt: str, **kwargs: dict) -> Any: ... class BaseLLM(ABC): @@ -29,11 +28,9 @@ def query( role: str = "user", disregard_system_prompt: bool = False, cleaner: CleanerPipeline = None, - ) -> str: - ... + ) -> str: ... - def forward(self, messages: List[Dict[str, str]]): - ... 
+ def forward(self, messages: List[Dict[str, str]]): ... def clean_response(self, cleaner: CleanerPipeline, response: str) -> str: if cleaner is not None: diff --git a/prompting/rewards/reward.py b/prompting/rewards/reward.py index 1adf35a7..feb29f07 100644 --- a/prompting/rewards/reward.py +++ b/prompting/rewards/reward.py @@ -146,8 +146,7 @@ def __post_init__(self): class BaseRewardModel(ABC): @property @abstractmethod - def name(self) -> str: - ... + def name(self) -> str: ... @abstractmethod def __init__(self, **kwargs): diff --git a/prompting/tools/datasets/base.py b/prompting/tools/datasets/base.py index 0dd789c3..0127d243 100644 --- a/prompting/tools/datasets/base.py +++ b/prompting/tools/datasets/base.py @@ -34,16 +34,13 @@ class Dataset(ABC): max_tries: int = 10 @abstractmethod - def search(self, name): - ... + def search(self, name): ... @abstractmethod - def random(self, name): - ... + def random(self, name): ... @abstractmethod - def get(self, name): - ... + def get(self, name): ... def next( self, method: str = "random", selector: Selector = Selector(), **kwargs @@ -93,8 +90,7 @@ class BatchDataset(ABC): batch_size: int = 16 # ensure that child classes contain batch_size attrib @abstractmethod - async def random(self, name): - ... + async def random(self, name): ... async def next( self, method: str = "random", selector: Selector = Selector(), **kwargs