Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

parallel merge index #590

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
84 changes: 58 additions & 26 deletions streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import tempfile
import urllib.parse
from collections import OrderedDict
from multiprocessing import Pool
from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
from pathlib import Path
from time import sleep, time
Expand Down Expand Up @@ -253,6 +254,50 @@ def merge_index(*args: Any, **kwargs: Any):
raise ValueError(f'Invalid arguments to merge_index: {args}, {kwargs}')


def _download_url(url_info):
"""Download a file given URL information."""
from streaming.base.storage.download import download_file
src, dest, download_timeout = url_info
try:
download_file(src, dest, download_timeout)
except Exception as ex:
return f'Failed to download index.json: {src} to {dest}: {str(ex)}', ex
return dest, None
snarayan21 marked this conversation as resolved.
Show resolved Hide resolved


def _merge_partition_indices(partition_indices):
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
"""Function to be executed by each process to merge a subset of partition indices."""
shards = []
for partition_index in partition_indices:
p = Path(partition_index)
with open(partition_index, 'r') as f:
obj = json.load(f)
for shard in obj['shards']:
for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we really ought to make this a Shard method, which is subject to inheritance and so on

this code won't work for parquet shards :/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific suggestion how to deal with this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work for json/xsv or just for mds index files? Could you test that as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do json/xsv index files have the same file format? @knighton

if shard.get(key):
basename = shard[key]['basename']
shard[key]['basename'] = os.path.join(os.path.basename(p.parent), basename)
shards.extend(obj['shards'])
return shards


def _parallel_merge_partitions(partitions, n_processes=4):
"""Divide the list of partitions among multiple processes and merge them in parallel."""
with Pool(processes=n_processes) as pool:
# Split the list of partitions into N chunks where N is the number of processes
chunk_size = len(partitions) // n_processes + (len(partitions) % n_processes > 0)
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
partition_chunks = [
partitions[i:i + chunk_size] for i in range(0, len(partitions), chunk_size)
]

# Process each chunk in parallel
results = pool.map(_merge_partition_indices, partition_chunks)
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved

# Combine the results from all processes
final_shards = [shard for result in results for shard in result]
return final_shards


def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
out: Union[str, Tuple[str, str]],
keep_local: bool = True,
Expand All @@ -273,7 +318,6 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]
keep_local (bool): Keep local copy of the merged index file. Defaults to ``True``
download_timeout (int): The allowed time for downloading each json file. Defaults to 60.
"""
from streaming.base.storage.download import download_file
from streaming.base.storage.upload import CloudUploader

if not index_file_urls or not out:
Expand All @@ -297,10 +341,10 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]

# Prepare a temp folder to download index.json from remote if necessary. Removed in the end.
with tempfile.TemporaryDirectory() as temp_root:
logging.warning(f'A temporary folder {temp_root} is created to store index files')
logging.info(f'A temporary folder {temp_root} is created to store index files')
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved

# Copy files to a temporary directory. Download if necessary
partitions = []
download_tasks = []
for url in urls:
if isinstance(url, tuple):
src = url[0] if os.path.exists(url[0]) else url[1]
Expand All @@ -314,30 +358,18 @@ def _merge_index_from_list(index_file_urls: Sequence[Union[str, Tuple[str, str]]
f'Check data availability! local index {url[0]} is not accessible.' +
f'remote index {url[1]} does not have a valid url format')
dest = os.path.join(temp_root, path.lstrip('/'))
download_tasks.append((src, dest, download_timeout))
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved

try:
download_file(src, dest, download_timeout)
except Exception as ex:
raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex

if not os.path.exists(dest):
raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.')

partitions.append(dest)

# merge shards from all index files
shards = []
for partition_index in partitions:
p = Path(partition_index)
obj = json.load(open(partition_index))
for i in range(len(obj['shards'])):
shard = obj['shards'][i]
for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
if shard.get(key):
basename = shard[key]['basename']
obj['shards'][i][key]['basename'] = os.path.join(
os.path.basename(p.parent), basename)
shards += obj['shards']
with Pool(processes=os.cpu_count()) as pool:
results = pool.map(_download_url, download_tasks)
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved

partitions = []
for partition_index, error in results:
if error:
raise RuntimeError(partition_index)
partitions.append(partition_index)

shards = _parallel_merge_partitions(partitions)

# Save merged index locally
obj = {
Expand Down
138 changes: 123 additions & 15 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
import urllib.parse
from multiprocessing.shared_memory import SharedMemory as BuiltinSharedMemory
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import pytest

Expand Down Expand Up @@ -194,9 +194,9 @@ def test_format_remote_index_files(scheme: str):
assert obj.scheme == scheme


@pytest.mark.parametrize('index_file_urls_pattern', [1, 2, 3])
@pytest.mark.parametrize('keep_local', [True, False])
@pytest.mark.parametrize('scheme', ['gs://', 's3://', 'oci://'])
@pytest.mark.parametrize('index_file_urls_pattern', [1]) # , 2, 3])
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
@pytest.mark.parametrize('keep_local', [True]) # , False])
@pytest.mark.parametrize('scheme', ['gs://']) # , 's3://', 'oci://'])
def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_local: bool,
index_file_urls_pattern: int, scheme: str):
"""Validate the final merge index json for following patterns of index_file_urls:
Expand All @@ -206,10 +206,10 @@ def test_merge_index_from_list_local(local_remote_dir: Tuple[str, str], keep_loc
4. All URLs are tuple (local, remote). At least one url is not accessible locally -> download all
5. All URLs are str (remote) -> download all
"""
from decimal import Decimal
import random
import string

from pyspark.sql import SparkSession
from pyspark.sql.types import DecimalType, IntegerType, StringType, StructField, StructType

from streaming.base.converters import dataframeToMDS

Expand All @@ -223,15 +223,18 @@ def not_merged_index(index_file_path: str, out: str):
mds_out = out = local

spark = SparkSession.builder.getOrCreate() # pyright: ignore
schema = StructType([
StructField('id', IntegerType(), nullable=False),
StructField('name', StringType(), nullable=False),
StructField('amount', DecimalType(10, 2), nullable=False)
])
data = [(1, 'Alice', Decimal('123.45')), (2, 'Bob', Decimal('67.89')),
(3, 'Charlie', Decimal('987.65'))]
df = spark.createDataFrame(data=data, schema=schema).repartition(3)
mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True}

def random_string(length=1000):
"""Generate a random string of fixed length."""
letters = string.ascii_letters + string.digits + string.punctuation + ' '
return ''.join(random.choice(letters) for _ in range(length))
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved

# Generate a DataFrame with 10000 rows of random text
num_rows = 100
data = [(i, random_string(), random_string()) for i in range(num_rows)]
df = spark.createDataFrame(data, ['id', 'name', 'amount'])

mds_kwargs = {'out': mds_out, 'columns': {'id': 'int64', 'name': 'str'}, 'keep_local': True}
dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs)

local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True)
Expand All @@ -241,6 +244,16 @@ def not_merged_index(index_file_path: str, out: str):

if index_file_urls_pattern == 1:
merge_index(local_index_files, out, keep_local=keep_local)
d1 = json.load(open(os.path.join(out, 'index.json')))

_merge_index_from_list_serial(local_index_files, out, keep_local=keep_local)
d2 = json.load(open(os.path.join(out, 'index.json')))

print('d1 = ', d1)
print('d2 = ', d2)

assert len(d1['shards']) == len(d2['shards']), 'parallel and serial results different'
assert d1['shards'] == d2['shards'], 'parallel and serial results different'

if index_file_urls_pattern == 2:
with tempfile.TemporaryDirectory() as a_temporary_folder:
Expand Down Expand Up @@ -323,3 +336,98 @@ def flaky_function():
return "Third time's a charm"

assert flaky_function() == "Third time's a charm"


def _merge_index_from_list_serial(index_file_urls: Sequence[Union[str, Tuple[str, str]]],
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
out: Union[str, Tuple[str, str]],
keep_local: bool = True,
download_timeout: int = 60) -> None:
import logging
import shutil
import urllib.parse
from collections import OrderedDict
from pathlib import Path

from streaming.base.format.index import get_index_basename
from streaming.base.storage.download import download_file
from streaming.base.storage.upload import CloudUploader

if not index_file_urls or not out:
return

# This is the index json file name, e.g., it is index.json as of 0.6.0
index_basename = get_index_basename()

cu = CloudUploader.get(out, keep_local=True, exist_ok=True)

# Remove duplicates, and strip '/' from right if any
index_file_urls = list(OrderedDict.fromkeys(index_file_urls))
urls = []
for url in index_file_urls:
if isinstance(url, str):
urls.append(url.rstrip('/').strip())
else:
urls.append((url[0].rstrip('/').strip(), url[1].rstrip('/').strip()))

# Prepare a temp folder to download index.json from remote if necessary. Removed in the end.
with tempfile.TemporaryDirectory() as temp_root:
logging.warning(f'A temporary folder {temp_root} is created to store index files')

# Copy files to a temporary directory. Download if necessary
partitions = []
for url in urls:
if isinstance(url, tuple):
src = url[0] if os.path.exists(url[0]) else url[1]
else:
src = url

obj = urllib.parse.urlparse(src)
scheme, bucket, path = obj.scheme, obj.netloc, obj.path
if scheme == '' and bucket == '' and path == '':
raise FileNotFoundError(
f'Check data availability! local index {url[0]} is not accessible.' +
f'remote index {url[1]} does not have a valid url format')
dest = os.path.join(temp_root, path.lstrip('/'))

try:
download_file(src, dest, download_timeout)
except Exception as ex:
raise RuntimeError(f'Failed to download index.json: {src} to {dest}') from ex

if not os.path.exists(dest):
raise FileNotFoundError(f'Index file {dest} does not exist or not accessible.')

partitions.append(dest)

# merge shards from all index files
shards = []
for partition_index in partitions:
p = Path(partition_index)
obj = json.load(open(partition_index))
for i in range(len(obj['shards'])):
shard = obj['shards'][i]
for key in ('raw_data', 'zip_data', 'raw_meta', 'zip_meta'):
if shard.get(key):
basename = shard[key]['basename']
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait why are you just taking the basename of the child file here? and to be clear, why the basename of the parent as well, what if the dir to merge is >1 hops deep?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if anyone takes basename in the format literally, that would be a mistake lol, those are actually always relative paths, was originally named wrong

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@XiaohanZhangCMU was this resolved?

obj['shards'][i][key]['basename'] = os.path.join(
os.path.basename(p.parent), basename)
shards += obj['shards']

# Save merged index locally
obj = {
'version': 2,
'shards': shards,
}
merged_index_path = os.path.join(temp_root, index_basename)
with open(merged_index_path, 'w') as outfile:
json.dump(obj, outfile)

# Move merged index from temp path to local part in out
# Upload merged index to remote if out has remote part
shutil.move(merged_index_path, cu.local)
if cu.remote is not None:
cu.upload_file(index_basename)

# Clean up
if not keep_local:
shutil.rmtree(cu.local, ignore_errors=True)
Loading