diff --git a/pretrain/dataset.py b/pretrain/dataset.py
index de19704..fbc50bf 100644
--- a/pretrain/dataset.py
+++ b/pretrain/dataset.py
@@ -12,17 +12,17 @@
 from pprint import pprint
 import os
-from dotenv import load_dotenv 
+from dotenv import load_dotenv
 load_dotenv()
 
 
 class SubsetLoader(IterableDataset):
     """Base class for data-specific subset loader classes."""
-    
+
     name: str = None  # Dataset name
     rows_base_url: str = "https://datasets-server.huggingface.co/rows"
     size_base_url: str = "https://datasets-server.huggingface.co/size"
     max_pages: int = None
-    
+
     def __init__(
         self,
         batch_size=None,
@@ -78,9 +78,9 @@ def __init__(
             self._initialize_pages()
             fetch_attempt += 1
 
-            # Exit if the buffer has at least one batch 
+            # Exit if the buffer has at least one batch
             if len(self.buffer) >= self.sequence_length:
-                break 
+                break
 
         bt.logging.warning(
             f"All fetched pages seem to be empty or have an extremely low token count. "
@@ -139,14 +139,14 @@ def _fetch_data_for_page(self, page):
             })
         else:
             self.params["offset"] = page
-        
+
         self.params["length"] = self.num_rows_per_page
-        
+
         attempt = 0
         while attempt < self.retry_limit:
             try:
                 response = requests.get(
-                    self.rows_base_url, 
+                    self.rows_base_url,
                     params=self.params,
                     headers=self._get_request_headers()
                 )
@@ -183,9 +183,9 @@
     def get_page_names(self):
         """Get page names in consistent format"""
        if not hasattr(self, 'pages'):
            return []
-        
+
        if isinstance(self.pages[0], tuple):
-            return [f"{cfg_name}_{num_rows}_{split}" 
+            return [f"{cfg_name}_{num_rows}_{split}"
                    for cfg_name, num_rows, split in self.pages]
        return self.pages
@@ -257,15 +257,15 @@ def __init__(self, **kwargs):
                               aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"])
         self.s3_sess = session.client("s3")
-        
+
         super().__init__(requires_auth=True, **kwargs)
-        
+
 
     def _download_row_content(self, blob_id, src_encoding):
         """Download the row content from S3.
         """
-        
-        s3_url = f"s3://softwareheritage/content/{blob_id}"
+
+        s3_url = f"https://softwareheritage.s3.amazonaws.com/content/{blob_id}"
 
         with smart_open.open(s3_url, "rb", compression=".gz",
                              transport_params={"client": self.s3_sess}) as fin:
             content = fin.read().decode(src_encoding)
@@ -277,7 +277,7 @@ def _get_content_from_row(self, row):
         content = self._download_row_content(row['row']['blob_id'],
                                              row['row']['src_encoding'])
         return content
-    
+
 
 class SubsetFalconLoader(SubsetLoader):
     max_pages: int = 968000015
@@ -286,14 +286,14 @@
 
 class SubsetFineWebEdu2Loader(SubsetLoader):
     name: str = "HuggingFaceFW/fineweb-edu-score-2"
-    
+
     def fetch_dataset_configs(self) -> typing.Dict[str, typing.Dict]:
         """
         Fetch dataset configs and their metadata.
         Returns a dictionary with config names as keys and metadata as values.
         """
         params = dict(dataset=self.name)
-        
+
         attempt = 0
         while attempt < self.retry_limit:
             try:
@@ -385,7 +385,7 @@ def get_random_pages(self, num_pages, initial_offset):
             split = self.configs_data[config_name]["split"]
             pages.append((config_name, selected_page_start, split))
         return pages
-    
+
     def fetch_data_to_rows(self, num_pages):
         """Fetch data and return raw text rows instead of adding to buffer."""
         downloaded_pages = set()