
Commit 2b25449
Changed s3 url for the stack v2 data loading
cryptal-mc committed Dec 11, 2024
1 parent 4e49418 commit 2b25449
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions pretrain/dataset.py
@@ -12,17 +12,17 @@
from pprint import pprint

import os
from dotenv import load_dotenv
load_dotenv()

class SubsetLoader(IterableDataset):
    """Base class for data-specific subset loader classes."""

    name: str = None  # Dataset name
    rows_base_url: str = "https://datasets-server.huggingface.co/rows"
    size_base_url: str = "https://datasets-server.huggingface.co/size"
    max_pages: int = None

    def __init__(
        self,
        batch_size=None,
@@ -78,9 +78,9 @@ def __init__(
            self._initialize_pages()
            fetch_attempt += 1

            # Exit if the buffer has at least one batch
            if len(self.buffer) >= self.sequence_length:
                break

        bt.logging.warning(
            f"All fetched pages seem to be empty or have an extremely low token count. "
@@ -139,14 +139,14 @@ def _fetch_data_for_page(self, page):
            })
        else:
            self.params["offset"] = page

        self.params["length"] = self.num_rows_per_page

        attempt = 0
        while attempt < self.retry_limit:
            try:
                response = requests.get(
                    self.rows_base_url,
                    params=self.params,
                    headers=self._get_request_headers()
                )
@@ -183,9 +183,9 @@ def get_page_names(self):
"""Get page names in consistent format"""
if not hasattr(self, 'pages'):
return []

if isinstance(self.pages[0], tuple):
return [f"{cfg_name}_{num_rows}_{split}"
return [f"{cfg_name}_{num_rows}_{split}"
for cfg_name, num_rows, split in self.pages]
return self.pages

@@ -257,15 +257,15 @@ def __init__(self, **kwargs):
            aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"])

        self.s3_sess = session.client("s3")

        super().__init__(requires_auth=True, **kwargs)

    def _download_row_content(self, blob_id, src_encoding):
        """Download the row content from S3."""
-        s3_url = f"s3://softwareheritage/content/{blob_id}"
+        s3_url = f"https://softwareheritage.s3.amazonaws.com/content/{blob_id}"

        with smart_open.open(s3_url, "rb", compression=".gz", transport_params={"client": self.s3_sess}) as fin:
            content = fin.read().decode(src_encoding)
@@ -277,7 +277,7 @@ def _get_content_from_row(self, row):

        content = self._download_row_content(row['row']['blob_id'], row['row']['src_encoding'])
        return content


class SubsetFalconLoader(SubsetLoader):
    max_pages: int = 968000015
@@ -286,14 +286,14 @@ class SubsetFalconLoader(SubsetLoader):

class SubsetFineWebEdu2Loader(SubsetLoader):
    name: str = "HuggingFaceFW/fineweb-edu-score-2"

    def fetch_dataset_configs(self) -> typing.Dict[str, typing.Dict]:
        """
        Fetch dataset configs and their metadata.
        Returns a dictionary with config names as keys and metadata as values.
        """
        params = dict(dataset=self.name)

        attempt = 0
        while attempt < self.retry_limit:
            try:
@@ -385,7 +385,7 @@ def get_random_pages(self, num_pages, initial_offset):
            split = self.configs_data[config_name]["split"]
            pages.append((config_name, selected_page_start, split))
        return pages

    def fetch_data_to_rows(self, num_pages):
        """Fetch data and return raw text rows instead of adding to buffer."""
        downloaded_pages = set()
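For context, a minimal sketch of how the updated download path can be exercised end to end. This is an illustration rather than part of the commit: the blob_id and src_encoding values are hypothetical placeholders, and AWS credentials are assumed to be present in the environment, as the loader's __init__ above expects.

import os

import boto3
import smart_open
from dotenv import load_dotenv

load_dotenv()

# Authenticated S3 client, mirroring the loader __init__ in the diff above.
session = boto3.Session(
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
)
s3_client = session.client("s3")

blob_id = "0123abcd"      # hypothetical Software Heritage blob id
src_encoding = "utf-8"    # assumed encoding for this sketch

# The new-style HTTPS URL introduced by this commit, replacing
# the previous s3://softwareheritage/content/{blob_id} form.
s3_url = f"https://softwareheritage.s3.amazonaws.com/content/{blob_id}"

# Blob contents are gzip-compressed, so compression=".gz" makes smart_open
# decompress transparently; transport_params is kept identical to the call
# in _download_row_content above.
with smart_open.open(s3_url, "rb", compression=".gz",
                     transport_params={"client": s3_client}) as fin:
    content = fin.read().decode(src_encoding)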
