Add warning
XiaohanZhangCMU committed Feb 24, 2024
1 parent 6add8ea commit 18a2f97
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions streaming/base/util.py
@@ -220,26 +220,26 @@ def get_import_exception_message(package_name: str, extra_deps: str) -> str:


 def merge_index(*args: Any, **kwargs: Any):
-    r"""Merge index.json from shards to form a global index.json.
+    r"""Merge index.json from streams to form a global index.json.

     This can be called as

         merge_index(index_file_urls, out, keep_local, download_timeout)
         merge_index(out, keep_local, download_timeout)

-    The first signature takes in a list of index files URLs of MDS shards.
-    The second takes the root of a MDS dataset and parse the shards folders from there.
+    The first signature takes in a list of index file URLs of MDS streams.
+    The second takes the root of an MDS dataset and parses the stream folders from there.

     Args:
-        index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the shards.
+        index_file_urls (List[Union[str, Tuple[str,str]]]): index.json from all the streams.
             Each element can take the form of a single path string or a tuple string.

             1. If ``index_file_urls`` is a List of local URLs, merge locally without download.
             2. If ``index_file_urls`` is a List of tuple (local, remote) URLs, check if local index.json are missing, download before merging.
             3. If ``index_file_urls`` is a List of remote URLs, download all and merge.

-        out (Union[str, Tuple[str,str]]): folder that contain MDS shards and to put the merged index file
+        out (Union[str, Tuple[str,str]]): folder that contains the MDS streams and receives the merged index file.

             1. A local directory, merge index happens locally.
             2. A remote directory, download all the sub-directories index.json, merge locally and upload.
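For reference, a minimal sketch of the two call signatures described above (the bucket and paths are hypothetical, not from the commit):

    from streaming.base.util import merge_index

    # Signature 1: an explicit list of per-stream index files (no bucket listing needed).
    index_file_urls = [
        's3://my-bucket/dataset/stream0/index.json',
        's3://my-bucket/dataset/stream1/index.json',
    ]
    merge_index(index_file_urls, out='s3://my-bucket/dataset', keep_local=True, download_timeout=60)

    # Signature 2: the dataset root; the stream folders are discovered by listing it.
    merge_index('s3://my-bucket/dataset', keep_local=True, download_timeout=60)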
@@ -267,12 +267,12 @@ def _download_url(url_info: Tuple[str, str, int]):
     return dst, None


-def _merge_shard_indices(shard_indices: List[str]):
-    """Function to be executed by each process to merge a subset of shard indices."""
+def _merge_stream_indices(stream_indices: List[str]):
+    """Function to be executed by each process to merge a subset of stream indices."""
     shards = []
-    for shard_index in shard_indices:
-        p = Path(shard_index)
-        with open(shard_index, 'r') as f:
+    for stream_index in stream_indices:
+        p = Path(stream_index)
+        with open(stream_index, 'r') as f:
             obj = json.load(f)
             for shard in obj['shards']:
                 for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
@@ -283,15 +283,15 @@ def _merge_shard_indices(shard_indices: List[str]):
     return shards
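The hunk above is collapsed mid-loop, so the rest of the body is not shown here. Conceptually, each shard entry's file paths must be re-prefixed with its stream's folder name so the merged index resolves files relative to the dataset root. A hedged sketch of that step (an illustration under that assumption; _prefix_shard_paths is a hypothetical helper, not the collapsed code verbatim):

    import os
    from pathlib import Path

    def _prefix_shard_paths(shard: dict, index_path: Path) -> dict:
        # Hypothetical helper: for an index at 'root/stream0/index.json', rewrite
        # e.g. 'shard.00000.mds' -> 'stream0/shard.00000.mds'.
        stream_dir = index_path.parent.name
        for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
            if shard.get(key):
                shard[key]['basename'] = os.path.join(stream_dir, shard[key]['basename'])
        return shard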


-def _parallel_merge_shards(shards: List[str], n_processes: int = 1):
-    """Divide the list of shards among multiple processes and merge them in parallel."""
+def _parallel_merge_streams(streams: List[str], n_processes: int = 1):
+    """Divide the list of streams among multiple processes and merge their shards in parallel."""
     with Pool(processes=n_processes) as pool:
-        # Split the list of shards into N chunks where N is the number of processes
-        chunk_size = int(np.ceil(len(shards) / n_processes))
-        shard_chunks = [shards[i:i + chunk_size] for i in range(0, len(shards), chunk_size)]
+        # Split the list of streams into N chunks where N is the number of processes
+        chunk_size = int(np.ceil(len(streams) / n_processes))
+        stream_chunks = [streams[i:i + chunk_size] for i in range(0, len(streams), chunk_size)]

         # Process each chunk in parallel
-        results = pool.imap_unordered(_merge_shard_indices, shard_chunks)
+        results = pool.map(_merge_stream_indices, stream_chunks)
         pool.close()
         pool.join()
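The chunking above hands each worker at most chunk_size streams; for example, 10 streams over 4 processes gives chunk_size = ceil(10 / 4) = 3, i.e. slices of lengths 3, 3, 3, 1. A standalone check of the arithmetic (inputs are made up):

    import numpy as np

    streams = [f'stream{i}/index.json' for i in range(10)]  # made-up inputs
    n_processes = 4
    chunk_size = int(np.ceil(len(streams) / n_processes))  # ceil(10 / 4) = 3
    chunks = [streams[i:i + chunk_size] for i in range(0, len(streams), chunk_size)]
    print([len(c) for c in chunks])  # [3, 3, 3, 1]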

Expand All @@ -308,7 +308,7 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]
"""Merge index.json from a list of index files of MDS directories to create joined index.
Args:
index_file_urls (Union[str, Tuple[str,str]]): index.json from all the shards
index_file_urls (Union[str, Tuple[str,str]]): index.json from all the streams
each element can take the form of a single path string or a tuple string.
The pattern of index_file_urls and corresponding reaction is one of:
@@ -370,17 +370,17 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]
         download_tasks.append((src, dst, download_timeout))

     with Pool(processes=n_processes) as pool:
-        results = pool.imap_unordered(_download_url, download_tasks)
+        results = pool.map(_download_url, download_tasks)
         pool.close()
         pool.join()

-    shards = []
-    for shard_index, error in results:
+    streams = []
+    for stream_index, error in results:
         if error:
-            raise RuntimeError(shard_index)
-        shards.append(shard_index)
+            raise RuntimeError(stream_index)
+        streams.append(stream_index)

-    shards = _parallel_merge_shards(shards, n_processes)
+    shards = _parallel_merge_streams(streams, n_processes)

     # Save merged index locally
     obj = {
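Both hunks also replace pool.imap_unordered with pool.map. One plausible reading (the commit message does not say): map blocks until every task finishes and returns a fully materialized, input-ordered list, so the results remain usable after the with-block terminates the pool, whereas imap_unordered yields a lazy iterator in completion order. A minimal illustration:

    from multiprocessing import Pool

    def square(x: int) -> int:
        return x * x

    if __name__ == '__main__':
        with Pool(processes=2) as pool:
            results = pool.map(square, range(5))  # fully evaluated before the pool shuts down
        print(results)  # [0, 1, 4, 9, 16] -- input order preserved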
@@ -467,6 +467,9 @@ def _merge_index_from_root(out: Union[str, Tuple[str, str]],

     local_index_files = []
     cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True)
+
+    logger.warning(f"We will be listing objects from {out}, which may take a long time if the number of stream folders is large. Consider providing the list of path/to/index.json directly.")
+
     for file in cl.list_objects():
         if file.endswith('.json') and _not_merged_index(file, cu.local):
             local_index_files.append(file)
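Per the new warning, a caller who already knows where the per-stream indexes live can skip the listing entirely by using the first merge_index signature (paths below are hypothetical):

    # Pass the index files directly instead of letting merge_index list the remote folder.
    index_file_urls = [f's3://my-bucket/dataset/stream{i}/index.json' for i in range(100)]
    merge_index(index_file_urls, out='s3://my-bucket/dataset', keep_local=False, download_timeout=60)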
