From cd0dad016b6c82f6f13bfde08d44a4cc5dbd1667 Mon Sep 17 00:00:00 2001
From: Theodore Vasiloudis <thvasilo@amazon.com>
Date: Wed, 8 Nov 2023 07:11:52 +0000
Subject: [PATCH] [GSProcessing] Add thread-parallelism to in-memory
 repartition implementation.

Also add more detailed docs about re-partitioning step.
---
 .../gs-processing/usage/amazon-sagemaker.rst  |   1 +
 .../gs-processing/usage/emr-serverless.rst    |   4 +-
 docs/source/gs-processing/usage/example.rst   |   3 +
 .../usage/row-count-alignment.rst             | 143 ++++++++++++++
 .../repartition_files.py                      | 181 +++++++++++-------
 .../scripts/run_repartitioning.py             |  19 ++
 .../tests/test_repartition_files.py           |  25 ++-
 7 files changed, 304 insertions(+), 72 deletions(-)
 create mode 100644 docs/source/gs-processing/usage/row-count-alignment.rst

diff --git a/docs/source/gs-processing/usage/amazon-sagemaker.rst b/docs/source/gs-processing/usage/amazon-sagemaker.rst
index 96522d9cf1..78621c4909 100644
--- a/docs/source/gs-processing/usage/amazon-sagemaker.rst
+++ b/docs/source/gs-processing/usage/amazon-sagemaker.rst
@@ -75,6 +75,7 @@ job, followed by the re-partitioning job, both on SageMaker:
     The re-partitioning job runs on a single instance, so for large graphs you will
     want to scale up to an instance with more memory to avoid memory errors. `ml.r5` instances
     should allow you to re-partition graph data with billions of nodes and edges.
+    For more details on the re-partitioning step see :doc:`row-count-alignment`.
 
 The ``--num-output-files`` parameter
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/source/gs-processing/usage/emr-serverless.rst b/docs/source/gs-processing/usage/emr-serverless.rst
index af987a2519..adef4a4a05 100644
--- a/docs/source/gs-processing/usage/emr-serverless.rst
+++ b/docs/source/gs-processing/usage/emr-serverless.rst
@@ -243,10 +243,12 @@ and building the GSProcessing SageMaker ECR image:
         --instance-type ${INSTANCE_TYPE} --wait-for-job
 
 
-
 Note that ``${OUTPUT_PREFIX}`` here will need to match the value assigned when launching
 the EMR-S job, i.e. ``"s3://${OUTPUT_BUCKET}/gsprocessing/emr-s/small-graph/4files/"``
 
+For more details on the re-partitioning step see
+:doc:`row-count-alignment`.
+
 Examine the output
 ------------------
 
diff --git a/docs/source/gs-processing/usage/example.rst b/docs/source/gs-processing/usage/example.rst
index 0034e6aa8c..6fe435f355 100644
--- a/docs/source/gs-processing/usage/example.rst
+++ b/docs/source/gs-processing/usage/example.rst
@@ -188,6 +188,9 @@ guarantees the data conform to the expectations of DGL:
 
     gs-repartition --input-prefix /tmp/gsprocessing-example/
 
+For more details on the re-partitioning step see
+:doc:`row-count-alignment`.
+
 .. _gsp-examining-output:
 
 Examining the job output
diff --git a/docs/source/gs-processing/usage/row-count-alignment.rst b/docs/source/gs-processing/usage/row-count-alignment.rst
new file mode 100644
index 0000000000..3c73b57953
--- /dev/null
+++ b/docs/source/gs-processing/usage/row-count-alignment.rst
@@ -0,0 +1,143 @@
+Row count alignment
+===================
+
+After the data processing step we need to perform an additional step
+to ensure that our processed data conform to the assumptions of the distributed
+partitioning pipeline. In particular, DistPartitioning expects that:
+
+* For each node/edge type:
+    * Every file output has the same number of files.
+        * For example, for an edge type ``x:to:y`` that has
+          two features, ``feat1`` and ``feat2``, the number
+          of files (the partition count) produced separately
+          for ``feat1``, ``feat2``, and the edge structure
+          needs to be the same.
+    * Each respective file in the output has the same row count.
+        * For example, assuming ``feat1``, ``feat2``, and ``edges``
+          had 2 part-files each, the number of rows in file-part-1
+          needs to be the same across all three file sources, and the
+          number of rows in file-part-2 needs to be the same
+          across all three file sources.
+
+
+In code the above means:
+
+.. code-block:: python
+
+    import os
+
+    files_for_feat1 = os.listdir("edge_data/x:to:y-feat1/")
+    files_for_feat2 = os.listdir("edge_data/x:to:y-feat2/")
+    files_for_edges = os.listdir("edges/x:to:y/")
+
+    num_feat1_files = len(files_for_feat1)
+    num_feat2_files = len(files_for_feat2)
+    num_edges_files = len(files_for_edges)
+
+    assert num_feat1_files == num_feat2_files == num_edges_files
+
+In addition, for each node/edge type, the row counts of the respective files
+in their output need to match, i.e.:
+
+.. code-block:: python
+
+    from pyarrow import parquet as pq
+
+    row_counts_feat1 = [pq.read_metadata(fpath).num_rows for fpath in files_for_feat1]
+    row_counts_feat2 = [pq.read_metadata(fpath).num_rows for fpath in files_for_feat2]
+    row_counts_edges = [pq.read_metadata(fpath).num_rows for fpath in files_for_edges]
+
+    assert row_counts_feat1 == row_counts_feat2 == row_counts_edges
+
+Note that these assumptions only apply `per type`; file counts and per-file
+row counts do not need to match between different node/edge types.
+
+Because of the distributed and speculative nature of Spark execution, it's
+not possible to guarantee that the row counts will match between the file
+outputs we create for each node type's features, or between the structure
+and features of an edge type.
+
+Therefore an additional step, which we call `repartitioning`, is necessary
+after the processing step. This step performs two functions:
+
+1. Align the row counts for each edge/node type.
+2. Ensure that the data shapes for masks and labels match what
+   DistPartitioning expects, i.e. flat ``(N,)`` arrays, instead of
+   the ``(N, 1)`` shaped Parquet output that Spark produces, as
+   illustrated in the sketch below.
+
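+As an illustration of the second point, here is a minimal sketch of the shape
+change, using NumPy and a hypothetical ``train_mask`` column purely for
+illustration:
+
+.. code-block:: python
+
+    import numpy as np
+
+    # Spark writes one value per row, which corresponds to an (N, 1)-shaped
+    # array once the Parquet column is stacked.
+    train_mask_from_spark = np.array([[1], [0], [1]])  # shape (3, 1)
+
+    # DistPartitioning expects a flat (N,) array instead.
+    train_mask_flat = train_mask_from_spark.squeeze(axis=1)
+    assert train_mask_flat.shape == (3,)
+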
+Local repartitioning
+--------------------
+
+The simplest way to apply the re-partitioning step is to use a local
+installation of GSProcessing:
+
+.. code-block:: bash
+
+    gs-repartition --input-prefix local_or_s3_path_to_processed_data
+
+The ``gs-repartition`` command calls the ``graphstorm_processing/repartition_files.py``
+Python script and executes the step locally. The script only requires the
+``--input-prefix`` argument to function, but provides optional arguments
+to customize the input/output file names and to choose between an
+in-memory or a file-streaming implementation for the row-count alignment.
+
+You can use ``gs-repartition --help`` for more details on the arguments.
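+
+For example, a call that spells out the metadata file names (shown here with
+their assumed default values) would look like:
+
+.. code-block:: bash
+
+    gs-repartition --input-prefix local_or_s3_path_to_processed_data \
+        --metadata-file-name "metadata.json" \
+        --updated-metadata-file-name "updated_row_counts_metadata.json"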
+
+Repartitioning on SageMaker
+---------------------------
+
+To avoid local processing, it is also possible to run the re-partitioning
+job on SageMaker. You will need to complete the steps described in
+:doc:`distributed-processing-setup` to build and push a GSProcessing
+SageMaker ECR image, after which you can launch the re-partitioning job
+on SageMaker:
+
+.. code-block:: bash
+
+    bash docker/build_gsprocessing_image.sh --environment sagemaker --region ${REGION}
+    bash docker/push_gsprocessing_image.sh --environment sagemaker --region ${REGION}
+
+    SAGEMAKER_ROLE_NAME="enter-your-sagemaker-execution-role-name-here"
+    IMAGE_URI="${ACCOUNT}.dkr.ecr.${REGION}.amazonaws.com/graphstorm-processing-sagemaker:0.2.1"
+    ROLE="arn:aws:iam::${ACCOUNT}:role/service-role/${SAGEMAKER_ROLE_NAME}"
+    INSTANCE_TYPE="ml.t3.xlarge"
+
+    python scripts/run_repartitioning.py --s3-input-prefix ${PROCESSED_OUTPUT} \
+        --role ${ROLE} --image ${IMAGE_URI}  --config-filename "metadata.json" \
+        --instance-type ${INSTANCE_TYPE} --wait-for-job
+
+File streaming repartitioning
+-----------------------------
+
+The default implementation of re-partitioning will load the data for each
+feature/edge type into memory and perform the row-count alignment there.
+Using SageMaker Processing with instances such as ``ml.r5.24xlarge``
+with 768GB of memory, you should be able to process data with
+billions of nodes/edges and hundreds of features.
+
+If, however, your data are so large that they cause out-of-memory
+errors even on SageMaker, you can use the file streaming implementation
+of re-partitioning, which should allow you to scale to any file size.
+
+To do so, simply modify your call to include:
+
+.. code-block:: bash
+
+    gs-repartition --input-prefix local_or_s3_path_to_processed_data \
+        --streaming-repartitioning True
+
+A similar modification can be applied to the SageMaker launch call:
+
+.. code-block:: bash
+
+    python scripts/run_repartitioning.py --s3-input-prefix ${PROCESSED_OUTPUT} \
+        --role ${ROLE} --image ${IMAGE_URI}  --config-filename "metadata.json" \
+        --instance-type ${INSTANCE_TYPE} --wait-for-job \
+        --streaming-repartitioning True
+
+The file streaming implementation holds at most two files' worth of data
+in memory, so by choosing an appropriate number of output files during
+processing you should be able to handle any data size.
+
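+As a rough illustration (a simplified sketch, not the actual implementation),
+the streaming approach buffers rows from the input files and emits exactly the
+desired number of rows for each output file, so it never needs more than the
+current buffer plus one input file in memory:
+
+.. code-block:: python
+
+    import pyarrow as pa
+
+    def stream_repartition(tables, desired_counts):
+        """Yield tables with the desired row counts from an iterator of tables.
+
+        Assumes sum(desired_counts) equals the total number of input rows.
+        """
+        tables = iter(tables)
+        remainder = None  # rows read from the input but not yet written out
+        for desired in desired_counts:
+            # Keep reading input files until enough rows are buffered.
+            while remainder is None or remainder.num_rows < desired:
+                next_table = next(tables)
+                remainder = (
+                    next_table
+                    if remainder is None
+                    else pa.concat_tables([remainder, next_table])
+                )
+            # Emit exactly `desired` rows and carry the rest over.
+            yield remainder.slice(0, desired)
+            remainder = remainder.slice(desired)
+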
+.. note::
+
+    The file streaming implementation will be much slower than the in-memory
+    one, so only use it when no instance size can handle your data.
diff --git a/graphstorm-processing/graphstorm_processing/repartition_files.py b/graphstorm-processing/graphstorm_processing/repartition_files.py
index 1d13d34253..bddff0e5ce 100644
--- a/graphstorm-processing/graphstorm_processing/repartition_files.py
+++ b/graphstorm-processing/graphstorm_processing/repartition_files.py
@@ -34,8 +34,10 @@
 import sys
 import tempfile
 import time
-from pathlib import Path
+import uuid
 from collections import Counter, defaultdict
+from itertools import accumulate
+from pathlib import Path
 from typing import Collection, Dict, List, Optional
 
 import boto3
@@ -71,6 +73,7 @@ def __init__(
         filesystem_type: str,
         region: Optional[str] = None,
         verify_outputs: bool = True,
+        streaming_repartitioning=False,
     ):
         assert filesystem_type in [
             "local",
@@ -86,6 +89,7 @@ def __init__(
         else:
             self.pyarrow_fs = fs.LocalFileSystem()
         self.verify_outputs = verify_outputs
+        self.streaming_repartitioning = streaming_repartitioning
 
     def read_dataset_from_relative_path(self, relative_path: str) -> ds.Dataset:
         """
@@ -125,8 +129,7 @@ def write_parquet_to_relative_path(
         # this is called to ensure consistency?
         file_path = os.path.join(self.input_prefix, relative_path)
         if self.filesystem_type == "local":
-            if not os.path.exists(Path(file_path).parent):
-                os.makedirs(Path(file_path).parent)
+            os.makedirs(Path(file_path).parent, exist_ok=True)
         pq.write_table(table, file_path, filesystem=self.pyarrow_fs, compression="snappy")
         if self.verify_outputs:
             expected_rows = desired_count if desired_count else table.num_rows
@@ -135,17 +138,45 @@ def write_parquet_to_relative_path(
 
     @staticmethod
     def create_new_relative_path_from_existing(
-        original_relative_path: str, repartitioned_file_index: int
+        original_relative_path: str, repartitioned_file_index: int, suffix: Optional[str] = None
     ) -> str:
-        """
+        """Changes the index in the `original_relative_path` to `original_relative_path`.
+
         Given a path of the form 'path/to/parquet/part-00001-filename.snappy.parquet', changes the
         numerical part of the filename to match the provided `repartitioned_file_index`,
-        and changes the path prefix to 'path/to/parquet-repartitioned/'.
+        and changes the path prefix to 'path/to/parquet-repartitioned/', or
+        'path/to/parquet-repartitioned-{suffix}/' if `suffix` is provided.
+
+
+        Parameters
+        ----------
+        original_relative_path : str
+            Filepath of the form 'path/to/parquet/part-00001-filename.snappy.parquet'.
+        repartitioned_file_index : int
+            The new index to assign to the file
+        suffix : str, optional
+            Suffix to add to the returned path, by default None
+
+        Returns
+        -------
+        str
+            The `original_relative_path` with the part index modified, and `parquet/`
+            modified to `parquet-repartitioned/` or `parquet-repartitioned-{suffix}/`.
 
-        Example:
-            > create_new_relative_path_from_existing(
+        Raises
+        ------
+        RuntimeError
+            If the filename does not conform to the ``r"^part-[0-9]{5}"`` regex,
+            which is the expected Spark filename output.
+
+        Examples
+        --------
+            >>> create_new_relative_path_from_existing(
                 "path/to/parquet/part-00001-filename.snappy.parquet", 3)
-            > "path/to/parquet-repartitioned/part-00003-filename.snappy.parquet"
+            "path/to/parquet-repartitioned-{uuid}/part-00003-filename.snappy.parquet"
+            >>> create_new_relative_path_from_existing(
+                "path/to/parquet/part-00001-filename.snappy.parquet", 3, "my-suffix")
+            "path/to/parquet-repartitioned-my-suffix/part-00003-filename.snappy.parquet"
         """
         original_relative_path_obj = Path(original_relative_path)
         # We expect files to have a path of the form /path/to/parquet/part-00001.snappy.parquet
@@ -164,18 +195,21 @@ def create_new_relative_path_from_existing(
             r"^part-[0-9]{5}", padded_file_idx, original_relative_path_obj.parts[-1]
         )
 
+        new_sub_path = (
+            "parquet-repartitioned" if suffix is None else f"parquet-repartitioned-{suffix}"
+        )
         new_relative_path = "/".join(
-            [*original_relative_path_obj.parts[:-2], "parquet-repartitioned", new_file_name]
+            [*original_relative_path_obj.parts[:-2], new_sub_path, new_file_name]
         )
 
         return new_relative_path
 
-    def repartition_parquet_files_in_memory(
+    def repartition_parquet_files(
         self, data_entry_dict: Dict, desired_counts: Collection[int]
     ) -> Dict:
         """
-        Re-partitions the parquet files in `data_entry_dict` so that their row count
-        matches the one provided in desired_counts. We assume that the file counts between the
+        Re-partitions the Parquet files in `data_entry_dict` so that their row count
+        matches the one provided in `desired_counts`. We assume that the number of files between the
         input and output will remain the same.
 
         The output is written to storage and the `data_entry_dict` dictionary file is
@@ -208,17 +242,28 @@ def repartition_parquet_files_in_memory(
         Dict
             A data format dictionary with the row
             counts updated to match desired_counts.
+        """
+        if self.streaming_repartitioning:
+            return self._repartition_parquet_files_streaming(data_entry_dict, desired_counts)
+        else:
+            return self._repartition_parquet_files_in_memory(data_entry_dict, desired_counts)
 
-        Raises
-        ------
-        RuntimeError
-            In cases where the sum of the desired counts does not match
-            the sum of actual file row counts, or the files are not in Parquet format.
+    def _repartition_parquet_files_in_memory(
+        self, data_entry_dict: Dict, desired_counts: Collection[int]
+    ) -> Dict:
+        """
+        In-memory, thread-parallel implementation of Parquet file repartitioning.
 
         Notes
         -----
         This function assumes the entire dataset described in `data_entry_dict`
         can be held in memory.
+
+        Raises
+        ------
+        RuntimeError
+            In cases where the sum of the desired counts does not match
+            the sum of actual file row counts, or the files are not in Parquet format.
         """
         if sum(desired_counts) != sum(data_entry_dict["row_counts"]):
             raise RuntimeError(
@@ -260,31 +305,37 @@ def repartition_parquet_files_in_memory(
         logging.debug("Desired counts: %s", desired_counts)
         logging.debug("Row counts: %s", data_entry_dict["row_counts"])
 
-        offset = 0
-        new_data_entries = []
         # From the dataset we read into memory, we slice a part according to desired_counts and
         # write a new file to S3.
-        for idx, desired_count in enumerate(desired_counts):
-            sliced_data = table.slice(offset=offset, length=desired_count)
-            new_relative_path = self.create_new_relative_path_from_existing(datafile_list[0], idx)
-            self.write_parquet_to_relative_path(new_relative_path, sliced_data, desired_count)
-            new_data_entries.append(new_relative_path)
-            offset += desired_count
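+        # Compute each slice's starting offset as a running sum of the desired counts.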
+        offsets = accumulate([0] + desired_counts)
+        zero_copy_slices = [
+            table.slice(offset=offset, length=desired_count)
+            for offset, desired_count in zip(offsets, desired_counts)
+        ]
+        uid_for_entry = uuid.uuid4().hex[:8]
+        relative_paths = [
+            self.create_new_relative_path_from_existing(datafile_list[0], idx, uid_for_entry)
+            for idx in range(len(desired_counts))
+        ]
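+        # Write the slices out from a pool of threads; the writes are I/O-bound,
+        # so thread-level parallelism speeds up this step.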
+        with Parallel(n_jobs=min(16, os.cpu_count()), verbose=10, prefer="threads") as parallel:
+            parallel(
+                delayed(self.write_parquet_to_relative_path)(
+                    relative_path,
+                    slice,
+                )
+                for slice, relative_path in zip(zero_copy_slices, relative_paths)
+            )
 
-        data_entry_dict["data"] = new_data_entries
+        data_entry_dict["data"] = relative_paths
         data_entry_dict["row_counts"] = desired_counts
 
         return data_entry_dict
 
-    def repartition_parquet_files_streaming(
+    def _repartition_parquet_files_streaming(
         self, data_entry_dict: Dict, desired_counts: Collection[int]
     ) -> Dict:
         """Repartition parquet files using file streaming.
 
-        Re-partitions the parquet files in data_entry_dict so that their row count
-        matches the one provided in desired_counts. We assume that the file counts between the
-        input and output will remain the same.
-
         This function will maintain at most 2 files worth of data in memory.
 
         We iterate over the desired counts and compare against the existing file counts.
@@ -297,33 +348,6 @@ def repartition_parquet_files_streaming(
         The output is written to storage and the `data_entry_dict` dictionary file is
         modified in-place and returned.
 
-        Parameters
-        ----------
-        data_entry_dict : Dict
-            A data format dictionary formatted as:
-            {
-                "format": {
-                    "name": "parquet"
-                },
-                "data": [
-                    "relative/path/to/file1.parquet",
-                    "relative/path/to/file2.parquet",
-                    ...
-                ] # n files
-                "row_counts": [
-                    10,
-                    12,
-                    ...
-                ] # n row counts
-            }
-        desired_counts : Collection[int]
-            A list of desired row counts.
-
-        Returns
-        -------
-            A data format dictionary with the row
-            count of each file updated to match desired_counts.
-
         Raises
         ------
         RuntimeError
@@ -345,7 +369,11 @@ def repartition_parquet_files_streaming(
         remainder_table = None  # pyarrow.Table
         new_data_entries = []
 
+        # TODO: Instead of limiting to two tables in memory, we could monitor memory and load files
+        # until we run out of memory and process together to speed up the process.
+
         # TODO: Zip with original counts, if num rows match, no need to read the file into memory
+        uid_for_entry = uuid.uuid4().hex[:8]
         for repartitioned_file_index, desired_count in enumerate(desired_counts):
             logging.debug(
                 "At start of iter: repartitioned_file_index: %d, original_file_index: %d",
@@ -358,7 +386,7 @@ def repartition_parquet_files_streaming(
             # and rename to
             # relative/path/to/file/parquet-repartitioned/part-00000.snappy.parquet
             new_relative_path = self.create_new_relative_path_from_existing(
-                original_relative_path, repartitioned_file_index
+                original_relative_path, repartitioned_file_index, uid_for_entry
             )
 
             remainder_used = False
@@ -698,6 +726,14 @@ def parse_args(args):
         "from the distributed processing pipeline. "
         "Can be a local path (starting with '/') or S3 prefix (starting with 's3://').",
     )
+    parser.add_argument(
+        "--streaming-repartitioning",
+        type=lambda x: (str(x).lower() in ["true", "1"]),
+        default=False,
+        help="When True will use low-memory file-streaming repartitioning. "
+        "Note that this option is much slower than the in-memory default.",
+    )
     parser.add_argument(
         "--metadata-file-name",
         default="metadata.json",
@@ -741,6 +777,9 @@ def main():
         Prefix path to where the output was generated
         from the distributed processing pipeline.
         Can be a local path or S3 prefix (starting with 's3://').
+    streaming_repartitioning: bool
+        When True will use low-memory file-streaming repartitioning.
+        Note that this option is much slower than the in-memory default.
     metadata_file_name : str
         Name of the original partitioning pipeline metadata file.
     updated_metadata_file_name : str
@@ -848,7 +887,11 @@ def main():
             reverse_edge_type_name = f"{dst}:{relation}-rev:{src}"
             most_frequent_counts = list(edge_row_counts_frequencies[type_name].most_common(1)[0][0])
             repartitioner = ParquetRepartitioner(
-                input_prefix, filesystem_type, region, verify_outputs=True
+                input_prefix,
+                filesystem_type,
+                region,
+                verify_outputs=True,
+                streaming_repartitioning=args.streaming_repartitioning,
             )
 
             structure_counts = edge_structure_meta[type_name]["row_counts"]
@@ -862,7 +905,7 @@ def main():
                     type_name,
                 )
 
-                edge_structure_meta[type_name] = repartitioner.repartition_parquet_files_in_memory(
+                edge_structure_meta[type_name] = repartitioner.repartition_parquet_files(
                     edge_structure_meta[type_name], most_frequent_counts
                 )
             else:
@@ -886,7 +929,7 @@ def main():
                     )
                     edge_structure_meta[
                         reverse_edge_type_name
-                    ] = repartitioner.repartition_parquet_files_in_memory(
+                    ] = repartitioner.repartition_parquet_files(
                         edge_structure_meta[reverse_edge_type_name], most_frequent_counts
                     )
 
@@ -900,7 +943,7 @@ def main():
                         len(type_data_dict),
                         feature_name,
                     )
-                    feature_dict = repartitioner.repartition_parquet_files_in_memory(
+                    feature_dict = repartitioner.repartition_parquet_files(
                         feature_dict, most_frequent_counts
                     )
                     if (
@@ -962,7 +1005,11 @@ def main():
         for type_idx, (type_name, type_data_dict) in enumerate(node_data_meta.items()):
             most_frequent_counts = list(node_row_counts_frequencies[type_name].most_common(1)[0][0])
             repartitioner = ParquetRepartitioner(
-                input_prefix, filesystem_type, region, verify_outputs=True
+                input_prefix,
+                filesystem_type,
+                region,
+                verify_outputs=True,
+                streaming_repartitioning=args.streaming_repartitioning,
             )
 
             for feature_idx, (feature_name, feature_dict) in enumerate(type_data_dict.items()):
@@ -974,9 +1021,7 @@ def main():
                         feature_idx + 1,
                         len(type_data_dict),
                     )
-                    repartitioner.repartition_parquet_files_in_memory(
-                        feature_dict, most_frequent_counts
-                    )
+                    repartitioner.repartition_parquet_files(feature_dict, most_frequent_counts)
                 else:
                     logging.info(
                         "Skipping repartitioning feature files for node type '%s',"
diff --git a/graphstorm-processing/scripts/run_repartitioning.py b/graphstorm-processing/scripts/run_repartitioning.py
index 36263f5865..6335f25449 100644
--- a/graphstorm-processing/scripts/run_repartitioning.py
+++ b/graphstorm-processing/scripts/run_repartitioning.py
@@ -73,6 +73,21 @@ def parse_args() -> argparse.Namespace:
     """Parse repartitioning args"""
     parser = script_utils.get_common_parser()  # type: argparse.ArgumentParser
 
+    parser.add_argument(
+        "--streaming-repartitioning",
+        type=lambda x: (str(x).lower() in ["true", "1"]),
+        default=False,
+        help="When True will use low-memory file-streaming repartitioning. "
+        "Note that this option is much slower than the in-memory default.",
+    )
+    parser.add_argument(
+        "--updated-metadata-file-name",
+        type=str,
+        help="The name for the updated metadata file.",
+        default="updated_row_counts_metadata.json",
+    )
+
     return parser.parse_args()
 
 
@@ -90,10 +105,14 @@ def main():
     container_args = [
         "--input-prefix",
         s3_input_prefix,
+        "--streaming-repartitioning",
+        "True" if args.streaming_repartitioning else "False",
         "--metadata-file-name",
         args.config_filename,
         "--log-level",
         args.container_log_level,
+        "--updated-metadata-file-name",
+        args.updated_metadata_file_name,
     ]
 
     if args.job_name is None:
diff --git a/graphstorm-processing/tests/test_repartition_files.py b/graphstorm-processing/tests/test_repartition_files.py
index eb9dc45bba..01d3ee08e9 100644
--- a/graphstorm-processing/tests/test_repartition_files.py
+++ b/graphstorm-processing/tests/test_repartition_files.py
@@ -14,11 +14,13 @@
 limitations under the License.
 """
 import json
+from pathlib import Path
 import os
 import shutil
 import sys
 from typing import Callable, List
 
+from numpy.testing import assert_array_equal
 import pytest
 import pyarrow as pa
 from pyarrow import parquet as pq
@@ -161,10 +163,13 @@ def create_parquet_files_fixture():
 )
 @pytest.mark.parametrize(
     "partition_function_name",
-    ["repartition_parquet_files_in_memory", "repartition_parquet_files_streaming"],
+    [
+        "_repartition_parquet_files_in_memory",
+        "_repartition_parquet_files_streaming",
+    ],
 )
 def test_repartition_functions(desired_counts: List[int], partition_function_name: str):
-    """Test the repartition function, streaming and in-memory"""
+    """Test the repartition functions, streaming and in-memory"""
     assert sum(desired_counts) == 50
 
     my_partitioner = ParquetRepartitioner(TEMP_DATA_PREFIX, filesystem_type="local")
@@ -186,13 +191,27 @@ def test_repartition_functions(desired_counts: List[int], partition_function_nam
     assert updated_meta["row_counts"] == desired_counts
     assert len(updated_meta["data"]) == len(desired_counts)
 
-    # Ensure actual rows match to expectation
+    # Ensure actual row counts match the expectation
     for expected_count, result_filepath in zip(desired_counts, updated_meta["data"]):
         assert (
             expected_count
             == pq.read_metadata(os.path.join(TEMP_DATA_PREFIX, result_filepath)).num_rows
         )
 
+    # Ensure the order/content of rows matches the expectation
+    original_table = (
+        pq.read_table(os.path.join(TEMP_DATA_PREFIX, Path(edge_type_meta["data"][0]).parent))
+        .to_pandas()
+        .to_numpy()
+    )
+    repartitioned_table = (
+        pq.read_table(os.path.join(TEMP_DATA_PREFIX, Path(updated_meta["data"][0]).parent))
+        .to_pandas()
+        .to_numpy()
+    )
+
+    assert_array_equal(original_table, repartitioned_table)
+
 
 # TODO: Add simple tests for the load functions