Skip to content

Commit

Permalink
[Partition Image] Add dependency for gs partition (#790)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
jalencato authored Mar 29, 2024
1 parent 6e63e4c commit 8ee4db3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docker/Dockerfile.local
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ RUN apt-get install -y python3-pip git wget psmisc
RUN apt-get install -y cmake

# Install Pytorch
RUN pip3 install networkx==3.1
RUN pip3 install networkx==3.1 pydantic
RUN pip3 install torch==2.1.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118

# Install DGL
RUN pip3 install dgl==1.0.4+cu117 -f https://data.dgl.ai/wheels/cu117/repo.html
ENV PYTHONPATH="/root/dgl/tools/:${PYTHONPATH}"

# Install related Python packages
RUN pip3 install ogb==1.3.6 scipy pyarrow boto3 scikit-learn transformers
Expand Down
11 changes: 8 additions & 3 deletions python/graphstorm/gpartition/dist_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def run_build_dglgraph(
ip_list,
output_path,
metadata_filename,
dgl_tool_path):
dgl_tool_path,
ssh_port):
""" Build DistDGL Graph
Parameters
Expand All @@ -54,6 +55,8 @@ def run_build_dglgraph(
Output Path
metadata_filename: str
The filename for the graph partitioning metadata file we'll use to determine data sources.
ssh_port: int
SSH port
"""
# Get the python interpreter used right now.
# If we can not get it we go with the default `python3`
Expand All @@ -68,7 +71,7 @@ def run_build_dglgraph(
"--partitions-dir", partitions_dir,
"--ip-config", ip_list,
"--out-dir", output_path,
"--ssh-port", "22",
"--ssh-port", f"{ssh_port}",
"--python-path", f"{python_bin}",
"--log-level", logging.getLevelName(logging.root.getEffectiveLevel()),
"--save-orig-nids",
Expand Down Expand Up @@ -134,7 +137,8 @@ def main():
args.ip_list,
os.path.join(output_path, "dist_graph"),
args.metadata_filename,
args.dgl_tool_path)
args.dgl_tool_path,
args.ssh_port)

logging.info("DGL graph building took %f sec", part_end - time.time())

Expand All @@ -153,6 +157,7 @@ def parse_args() -> argparse.Namespace:
help="Path to store the partitioned data")
argparser.add_argument("--num-parts", type=int, required=True,
help="Number of partitions to generate")
argparser.add_argument("--ssh-port", type=int, default=22, help="SSH Port")
argparser.add_argument("--dgl-tool-path", type=str,
help="The path to dgl/tools")
argparser.add_argument("--partition-algorithm", type=str, default="random",
Expand Down

0 comments on commit 8ee4db3

Please sign in to comment.