Skip to content

Commit

Permalink
Add flush operation (#689)
Browse files Browse the repository at this point in the history
*Description of changes:*
When writing data to distributed tensors, we need to flush data to
ensure all data have been written to distributed tensors so that we can
perform the next operation (e.g., read).

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
3 people authored Dec 19, 2023
1 parent 5e7cbb8 commit 1bbb300
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflow_scripts/lint_check.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ python3 -m pip install --upgrade prospector pip
yes | pip3 install astroid==v3.0.0
FORCE_CUDA=1 python3 -m pip install -e '.[test]' --no-build-isolation
pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/data/*.py
pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/distributed/
pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/dataloading/
pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/gconstruct/
pylint --rcfile=./tests/lint/pylintrc ./python/graphstorm/config/
Expand Down
17 changes: 17 additions & 0 deletions python/graphstorm/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Copyright 2023 Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from .primitives import flush_data
91 changes: 91 additions & 0 deletions python/graphstorm/distributed/primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Copyright 2023 Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
The primitives required for distributed computations.
"""

from dgl.distributed import rpc

from ..utils import barrier

FLUSH_DATA = 1000001

class FlushRequest(rpc.Request):
"""This request flushes data in DGL's distributed computation components.
When DGL performs writing to distributed tensors, it returns without data
being fully written to the distributed tensors. This operation is to ensure
that all data has been written to the distributed tensors on the server
when the operation returns. In practice, we don't need to perform any operations
in the request, except just sending responses to the client. The reason is
that when servers receive requests from clients, they processes them in
the FIFO order. When a server gets the opportunities to process the Flush request,
it means that the server has processed all requests before it and has written
data to the distributed tensors.
"""

def __init__(self):
pass

def __getstate__(self):
return None

def __setstate__(self, state):
pass

def process_request(self, server_state):
""" Process the request.
Here we don't need to do anything except returning the flush response.
"""
return FlushResponse()

class FlushResponse(rpc.Response):
"""Ack the flush request"""

def __init__(self):
pass

def __getstate__(self):
return None

def __setstate__(self, state):
pass

rpc.register_service(FLUSH_DATA, FlushRequest, FlushResponse)

def flush_data():
""" Flush data in distributed writes of DGL.
All processes need to talk to all server processes and make sure
all server processes complete processing the write requests issued by
the trainer processes. The reason that we need to have all processes
to communicate with all servers is that there are N*M communication channels,
where N is the number of trainer processes and M is the number of servers.
We need to make sure we flush data in all communication channels.
This function is called after trainer processes have finished issuing write
requests to servers and have written data to shared memory.
We can guarantee that all data are written to distributed tensors when
this function returns.
"""
request = FlushRequest()
# send request to all the server nodes
server_count = rpc.get_num_server()
for server_id in range(server_count):
rpc.send_request(server_id, request)
# recv response from all the server nodes
for _ in range(server_count):
_ = rpc.recv_response()
barrier()
2 changes: 2 additions & 0 deletions python/graphstorm/model/gnn_encoder_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .gs_layer import GSLayer

from ..utils import get_rank, barrier, is_distributed, create_dist_tensor, is_wholegraph
from ..distributed import flush_data

class GraphConvEncoder(GSLayer): # pylint: disable=abstract-method
r"""General encoder for graph data.
Expand Down Expand Up @@ -336,6 +337,7 @@ def dist_inference_one_layer(layer_id, g, dataloader, target_ntypes, layer, get_
if k in output_nodes:
assert k in y, "All mini-batch outputs should have the same tensor names."
y[k][output_nodes[k]] = h[k].cpu()
flush_data()
return y

def dist_inference(g, gnn_encoder, get_input_embeds, batch_size, fanout,
Expand Down
3 changes: 2 additions & 1 deletion python/graphstorm/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..gconstruct.file_io import stream_dist_tensors_to_hdf5
from ..utils import get_rank, barrier, get_world_size, create_dist_tensor
from ..data.utils import alltoallv_cpu, alltoallv_nccl
from ..distributed import flush_data

# placeholder of the ntype for homogeneous graphs
NTYPE = dgl.NTYPE
Expand Down Expand Up @@ -1039,7 +1040,7 @@ def _load_id_mapping(self, g, ntype, id_mappings):
f"Expect {id_mapping.shape[0]}, but get {num_nodes}"
# Save ID mapping into dist tensor
id_mapping_info[th.arange(num_nodes)] = id_mapping
barrier()
flush_data()
return id_mapping_info

def shuffle_nids(self, ntype, nids):
Expand Down
8 changes: 4 additions & 4 deletions tests/end2end-tests/graphstorm-nc/mgpu_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ python3 $GS_HOME/tests/end2end-tests/check_np_infer_emb.py --train-embout /data/
error_and_exit $?

echo "**************dataset: Movielens, do inference on saved model, remap without shared file system"
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml/infer-emb-nosfs/ --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml/prediction-nosfs/ --logging-file /tmp/log.txt --preserve-input True --with-shared-fs False
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 2 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml/infer-emb-nosfs/ --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml/prediction-nosfs/ --logging-file /tmp/log.txt --preserve-input True --with-shared-fs False

error_and_exit $?
rm /tmp/log.txt

python3 $GS_HOME/tests/end2end-tests/check_np_infer_emb.py --train-embout /data/gsgnn_nc_ml/emb/ --infer-embout /data/gsgnn_nc_ml/infer-emb-nosfs/

echo "**************dataset: Movielens, use gen_embeddings to generate embeddings on node classification"
python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc.yaml --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-embed-path /data/gsgnn_nc_ml/save-emb/ --use-mini-batch-infer false --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True
python3 -m graphstorm.run.gs_gen_node_embedding --workspace $GS_HOME/training_scripts/gsgnn_np/ --num-trainers $NUM_TRAINERS --num-servers 2 --num-samplers 0 --part-config /data/movielen_100k_train_val_1p_4t/movie-lens-100k.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc.yaml --restore-model-path /data/gsgnn_nc_ml/epoch-$best_epoch/ --save-embed-path /data/gsgnn_nc_ml/save-emb/ --use-mini-batch-infer false --logging-file /tmp/train_log.txt --logging-level debug --preserve-input True

error_and_exit $?

Expand Down Expand Up @@ -250,7 +250,7 @@ error_and_exit $?
rm /tmp/train_log.txt

echo "**************dataset: Movielens, do inference on saved model, RGCN layer: 1, node feat: BERT nodes: movie, user"
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_text_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml_text/infer-emb/ --restore-model-path /data/gsgnn_nc_ml_text/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml_text/prediction/ --logging-file /tmp/log.txt --logging-level debug --preserve-input True
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 2 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_text_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml_text/infer-emb/ --restore-model-path /data/gsgnn_nc_ml_text/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml_text/prediction/ --logging-file /tmp/log.txt --logging-level debug --preserve-input True

error_and_exit $?

Expand Down Expand Up @@ -326,7 +326,7 @@ echo "The best model is saved in epoch $best_epoch"
rm /tmp/train_log.txt

echo "**************dataset: Movielens, do inference on saved model, node feat: BERT nodes: movie, user, with warmup"
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 1 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_text_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml_text/infer-emb/ --restore-model-path /data/gsgnn_nc_ml_text/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml_text/prediction/ --logging-file /tmp/log.txt --logging-level debug --preserve-input True
python3 -m graphstorm.run.gs_node_classification --inference --workspace $GS_HOME/inference_scripts/np_infer/ --num-trainers $NUM_INFERs --num-servers 2 --num-samplers 0 --part-config /data/movielen_100k_lm_encoder_train_val_1p_4t/movie-lens-100k-text.json --ip-config ip_list.txt --ssh-port 2222 --cf ml_nc_text_infer.yaml --use-mini-batch-infer false --save-embed-path /data/gsgnn_nc_ml_text/infer-emb/ --restore-model-path /data/gsgnn_nc_ml_text/epoch-$best_epoch/ --save-prediction-path /data/gsgnn_nc_ml_text/prediction/ --logging-file /tmp/log.txt --logging-level debug --preserve-input True

error_and_exit $?

Expand Down

0 comments on commit 1bbb300

Please sign in to comment.