-
Notifications
You must be signed in to change notification settings - Fork 148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
parallel merge index #590
base: main
Are you sure you want to change the base?
parallel merge index #590
Changes from all commits
70d8e8f
feee52d
a0605a2
e7edd52
8112858
a583329
d6206f0
21c591a
168d3dd
f82d47e
6add8ea
18a2f97
eb5a16f
1a0c458
9c05860
6004e81
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,11 +13,14 @@ | |
import tempfile | ||
import urllib.parse | ||
from collections import OrderedDict | ||
from multiprocessing import Pool | ||
from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory | ||
from pathlib import Path | ||
from time import sleep, time | ||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union, cast, overload | ||
|
||
import numpy as np | ||
import psutil | ||
import torch.distributed as dist | ||
|
||
from streaming.base.constant import SHM_TO_CLEAN | ||
|
@@ -217,26 +220,26 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str: | |
|
||
|
||
def merge_index(*args: Any, **kwargs: Any): | ||
r"""Merge index.json from partitions to form a global index.json. | ||
r"""Merge index.json from streams to form a global index.json. | ||
|
||
This can be called as | ||
|
||
merge_index(index_file_urls, out, keep_local, download_timeout) | ||
|
||
merge_index(out, keep_local, download_timeout) | ||
|
||
The first signature takes in a list of index files URLs of MDS partitions. | ||
The second takes the root of a MDS dataset and parse the partition folders from there. | ||
The first signature takes in a list of index files URLs of MDS streams. | ||
The second takes the root of a MDS dataset and parse the streams folders from there. | ||
|
||
Args: | ||
index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the partitions. | ||
index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the streams. | ||
Each element can take the form of a single path string or a tuple string. | ||
|
||
1. If ``index_file_urls`` is a List of local URLs, merge locally without download. | ||
2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging. | ||
3. If ``index_file_urls`` is a List of remote URLs, download all and merge. | ||
|
||
out (Union[str, Tuple[str,str]]): folder that contain MDS partitions and to put the merged index file | ||
out (Union[str, Tuple[str,str]]): folder that contain MDS streams and to put the merged index file | ||
|
||
1. A local directory, merge index happens locally. | ||
2. A remote directory, download all the sub-directories index.json, merge locally and upload. | ||
|
@@ -253,14 +256,59 @@ def merge_index(*args: Any, **kwargs: Any): | |
raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}') | ||
|
||
|
||
def _download_url(url_info: Tuple[str, str, int]): | ||
"""Download a file given URL information.""" | ||
from streaming.base.storage.download import download_file | ||
src, dst, download_timeout = url_info | ||
try: | ||
download_file(src, dst, download_timeout) | ||
except Exception as ex: | ||
return f'Failed to download index.json: {src} to {dst}: {str(ex)}', ex | ||
return dst, None | ||
|
||
|
||
def _merge_stream_indices(stream_indices: List[str]): | ||
"""Function to be executed by each process to merge a subset of stream indices.""" | ||
shards = [] | ||
for stream_index in stream_indices: | ||
p = Path(stream_index) | ||
with open(stream_index, 'r') as f: | ||
obj = json.load(f) | ||
for shard in obj['shards']: | ||
for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we really ought to make this a Shard method, which is subject to inheritance and so on this code won't work for parquet shards :/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any specific suggestion how to deal with this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this work for json/xsv or just for mds index files? Could you test that as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do json/xsv index files have the same file format? @knighton |
||
if shard.get(key): | ||
basename = shard[key]['basename'] | ||
shard[key]['basename'] = os.path.join(os.path.basename(p.parent), basename) | ||
shards.extend(obj['shards']) | ||
return shards | ||
|
||
|
||
def _parallel_merge_streams(streams: List[str], n_processes: int = 1): | ||
"""Divide the list of streams among multiple processes and merge their shards in parallel.""" | ||
with Pool(processes=n_processes) as pool: | ||
# Split the list of streams into N chunks where N is the number of processes | ||
chunk_size = int(np.ceil(len(streams) / n_processes)) | ||
stream_chunks = [streams[i:i + chunk_size] for i in range(0, len(streams), chunk_size)] | ||
|
||
# Process each chunk in parallel | ||
results = pool.map(_merge_stream_indices, stream_chunks) | ||
pool.close() | ||
pool.join() | ||
|
||
# Combine the results from all processes | ||
final_shards = [shard for result in results for shard in result] | ||
return final_shards | ||
|
||
|
||
def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]], | ||
out: Union[str, Tuple[str, str]], | ||
keep_local: bool = True, | ||
download_timeout: int = 60) -> None: | ||
download_timeout: int = 60, | ||
n_processes: int = 1) -> None: | ||
"""Merge index.json from a list of index files of MDS directories to create joined index. | ||
|
||
Args: | ||
index_file_urls (Union[str, Tuple[str,str]]): index.json from all the partitions | ||
index_file_urls (Union[str, Tuple[str,str]]): index.json from all the streams | ||
each element can take the form of a single path string or a tuple string. | ||
|
||
The pattern of index_file_urls and corresponding reaction is one of: | ||
|
@@ -272,8 +320,8 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] | |
out (Union[str, Tuple[str, str]]): path to put the merged index file | ||
keep_local (bool): Keep local copy of the merged index file. Defaults to ``True`` | ||
download_timeout (int): The allowed time for downloading each json file. Defaults to 60. | ||
n_processes (int): The number of cores to run the function in parallel | ||
""" | ||
from streaming.base.storage.download import download_file | ||
from streaming.base.storage.upload import CloudUploader | ||
|
||
if not index_file_urls or not out: | ||
|
@@ -295,12 +343,17 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] | |
else: | ||
urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip())) | ||
|
||
cpu_count = max(psutil.cpu_count() - 2, 1) | ||
XiaohanZhangCMU marked this conversation as resolved.
Show resolved
Hide resolved
|
||
n_processes = n_processes if (1 <= n_processes <= cpu_count) else 1 | ||
|
||
logger.warning(f'Using n_processes = {n_processes} to download and merge index in parallel') | ||
|
||
# Prepare a temp folder to download index.json from remote if necessary. Removed in the end. | ||
with tempfile.TemporaryDirectory() as temp_root: | ||
logging.warning(f'A temporary folder {temp_root} is created to store index files') | ||
logging.info(f'Created temporary folder {temp_root} to store index files') | ||
|
||
# Copy files to a temporary directory. Download if necessary | ||
partitions = [] | ||
download_tasks = [] | ||
for url in urls: | ||
if isinstance(url, tuple): | ||
src = url[0] if os.path.exists(url[0]) else url[1] | ||
|
@@ -313,31 +366,21 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]] | |
raise FileNotFoundError( | ||
f'Check data availability! local index {url[0]} is not accessible.' + | ||
f'remote index {url[1]} does not have a valid url format') | ||
dest = os.path.join(temp_root, path.lstrip('/')) | ||
dst = os.path.join(temp_root, path.lstrip('/')) | ||
download_tasks.append((src, dst, download_timeout)) | ||
|
||
try: | ||
download_file(src, dest, download_timeout) | ||
except Exception as ex: | ||
raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex | ||
|
||
if not os.path.exists(dest): | ||
raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.') | ||
|
||
partitions.append(dest) | ||
|
||
# merge shards from all index files | ||
shards = [] | ||
for partition_index in partitions: | ||
p = Path(partition_index) | ||
obj = json.load(open(partition_index)) | ||
for i in range(len(obj['shards'])): | ||
shard = obj['shards'][i] | ||
for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'): | ||
if shard.get(key): | ||
basename = shard[key]['basename'] | ||
obj['shards'][i][key]['basename'] = os.path.join( | ||
os.path.basename(p.parent), basename) | ||
shards += obj['shards'] | ||
with Pool(processes=n_processes) as pool: | ||
results = pool.map(_download_url, download_tasks) | ||
XiaohanZhangCMU marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pool.close() | ||
pool.join() | ||
|
||
streams = [] | ||
for stream_index, error in results: | ||
if error: | ||
raise RuntimeError(stream_index) | ||
streams.append(stream_index) | ||
|
||
shards = _parallel_merge_streams(streams, n_processes) | ||
|
||
# Save merged index locally | ||
obj = { | ||
|
@@ -403,7 +446,7 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], | |
"""Merge index.json given the root of MDS dataset. Write merged index to the root folder. | ||
|
||
Args: | ||
out (Union[str, Tuple[str,str]]): folder that contain MDS partitions. | ||
out (Union[str, Tuple[str,str]]): folder that contain MDS shards. | ||
:A local directory, merge index happens locally | ||
:A remote directory, download all the sub-directories index.json in a temporary | ||
sub-directories, merge locally, and then upload it to out location | ||
|
@@ -424,28 +467,38 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]], | |
|
||
local_index_files = [] | ||
cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) | ||
|
||
logger.warning( | ||
f'We will be listing objects from {out}, which may take a long time if the number of stream folders is large. Consider provide the list of path/to/index.json directly.' | ||
) | ||
|
||
for file in cl.list_objects(): | ||
if file.endswith('.json') and _not_merged_index(file, cu.local): | ||
local_index_files.append(file) | ||
|
||
cpu_count = max(psutil.cpu_count() - 2, 1) | ||
|
||
if cu.remote: | ||
remote_index_files = _format_remote_index_files(cu.remote, cu.list_objects()) | ||
if len(local_index_files) == len(remote_index_files): | ||
_merge_index_from_list(list(zip(local_index_files, remote_index_files)), | ||
out, | ||
keep_local=keep_local, | ||
download_timeout=download_timeout) | ||
download_timeout=download_timeout, | ||
n_processes=cpu_count) | ||
else: | ||
_merge_index_from_list(remote_index_files, | ||
out, | ||
keep_local=keep_local, | ||
download_timeout=download_timeout) | ||
download_timeout=download_timeout, | ||
n_processes=cpu_count) | ||
return | ||
|
||
_merge_index_from_list(local_index_files, | ||
out, | ||
keep_local=keep_local, | ||
download_timeout=download_timeout) | ||
download_timeout=download_timeout, | ||
n_processes=cpu_count) | ||
|
||
|
||
@overload | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this function had typed arguments before, could you add that back instead of using generic
*args, **kwargs
? Would help clarify what the function is expecting, since otherwise, people have to go looking for the docstringThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is an entry function and it branches to either merge_index from list or merge_index from a root folder. Are you recommending to not remove this entry function, and just keep the two actual implementations? Can you elaborate a bit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I mean that the signature of the
merge_index
function isn't typed -- it just has*args, **kwargs
, but the docstring has types for the arguments, for example,index_file_urls
has type(List[Union[str, Tuple[str,str]]])
. I'm suggesting we add types to themerge_index
function so that users will easily be able to see what themerge_index
function takes in (for example, IDEs all support this, would be better for our docs, etc). Otherwise they have to reference the docstring, which doesn't match up with the function signature. lmk if i should elaborate moreThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keeping this entry function is fine