Skip to content

Commit

Permalink
update examples
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Aug 27, 2024
1 parent 5e1b986 commit a365a24
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 31 deletions.
24 changes: 18 additions & 6 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_mnmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def run_train(
wall_clock_start,
tempdir=None,
num_layers=3,
in_memory=False,
seeds_per_call=-1,
):
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)

Expand All @@ -196,20 +198,23 @@ def run_train(
from cugraph_pyg.loader import NeighborLoader

ix_train = split_idx["train"].cuda()
train_path = os.path.join(tempdir, f"train_{global_rank}")
os.mkdir(train_path)
train_path = None if in_memory else os.path.join(tempdir, f"train_{global_rank}")
if train_path:
os.mkdir(train_path)
train_loader = NeighborLoader(
data,
input_nodes=ix_train,
directory=train_path,
shuffle=True,
drop_last=True,
local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None,
**kwargs,
)

ix_test = split_idx["test"].cuda()
test_path = os.path.join(tempdir, f"test_{global_rank}")
os.mkdir(test_path)
test_path = None if in_memory else os.path.join(tempdir, f"test_{global_rank}")
if test_path:
os.mkdir(test_path)
test_loader = NeighborLoader(
data,
input_nodes=ix_test,
Expand All @@ -221,14 +226,16 @@ def run_train(
)

ix_valid = split_idx["valid"].cuda()
valid_path = os.path.join(tempdir, f"valid_{global_rank}")
os.mkdir(valid_path)
valid_path = None if in_memory else os.path.join(tempdir, f"valid_{global_rank}")
if valid_path:
os.mkdir(valid_path)
valid_loader = NeighborLoader(
data,
input_nodes=ix_valid,
directory=valid_path,
shuffle=True,
drop_last=True,
local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None,
**kwargs,
)

Expand Down Expand Up @@ -347,6 +354,9 @@ def parse_args():
parser.add_argument("--skip_partition", action="store_true")
parser.add_argument("--wg_mem_type", type=str, default="distributed")

parser.add_argument("--in_memory", action="store_true", default=False)
parser.add_argument("--seeds_per_call", type=int, default=-1)

return parser.parse_args()


Expand Down Expand Up @@ -429,6 +439,8 @@ def parse_args():
wall_clock_start,
tempdir,
args.num_layers,
args.in_memory,
args.seeds_per_call,
)
else:
warnings.warn("This script should be run with 'torchrun`. Exiting.")
24 changes: 20 additions & 4 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_sg.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,28 @@ def test(loader: NeighborLoader, val_steps: Optional[int] = None):


def create_loader(
data, num_neighbors, input_nodes, replace, batch_size, samples_dir, stage_name
data,
num_neighbors,
input_nodes,
replace,
batch_size,
samples_dir,
stage_name,
local_seeds_per_call,
):
directory = os.path.join(samples_dir, stage_name)
os.mkdir(directory)
if samples_dir is not None:
directory = os.path.join(samples_dir, stage_name)
os.mkdir(directory)
else:
directory = None
return NeighborLoader(
data,
num_neighbors=num_neighbors,
input_nodes=input_nodes,
replace=replace,
batch_size=batch_size,
directory=directory,
local_seeds_per_call=local_seeds_per_call,
)


Expand Down Expand Up @@ -147,6 +158,8 @@ def parse_args():
parser.add_argument("--tempdir_root", type=str, default=None)
parser.add_argument("--dataset_root", type=str, default="dataset")
parser.add_argument("--dataset", type=str, default="ogbn-products")
parser.add_argument("--in_memory", action="store_true", default=False)
parser.add_argument("--seeds_per_call", type=int, default=-1)

return parser.parse_args()

Expand All @@ -170,7 +183,10 @@ def parse_args():
"num_neighbors": [args.fan_out] * args.num_layers,
"replace": False,
"batch_size": args.batch_size,
"samples_dir": samples_dir,
"samples_dir": None if args.in_memory else samples_dir,
"local_seeds_per_call": None
if args.seeds_per_call <= 0
else args.seeds_per_call,
}

train_loader = create_loader(
Expand Down
23 changes: 17 additions & 6 deletions python/cugraph-pyg/cugraph_pyg/examples/gcn_dist_snmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def run_train(
wall_clock_start,
tempdir=None,
num_layers=3,
in_memory=False,
seeds_per_call=-1,
):

init_pytorch_worker(
Expand Down Expand Up @@ -119,20 +121,23 @@ def run_train(
dist.barrier()

ix_train = torch.tensor_split(split_idx["train"], world_size)[rank].cuda()
train_path = os.path.join(tempdir, f"train_{rank}")
os.mkdir(train_path)
train_path = None if in_memory else os.path.join(tempdir, f"train_{rank}")
if train_path:
os.mkdir(train_path)
train_loader = NeighborLoader(
(feature_store, graph_store),
input_nodes=ix_train,
directory=train_path,
shuffle=True,
drop_last=True,
local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None,
**kwargs,
)

ix_test = torch.tensor_split(split_idx["test"], world_size)[rank].cuda()
test_path = os.path.join(tempdir, f"test_{rank}")
os.mkdir(test_path)
test_path = None if in_memory else os.path.join(tempdir, f"test_{rank}")
if test_path:
os.mkdir(test_path)
test_loader = NeighborLoader(
(feature_store, graph_store),
input_nodes=ix_test,
Expand All @@ -144,14 +149,16 @@ def run_train(
)

ix_valid = torch.tensor_split(split_idx["valid"], world_size)[rank].cuda()
valid_path = os.path.join(tempdir, f"valid_{rank}")
os.mkdir(valid_path)
valid_path = None if in_memory else os.path.join(tempdir, f"valid_{rank}")
if valid_path:
os.mkdir(valid_path)
valid_loader = NeighborLoader(
(feature_store, graph_store),
input_nodes=ix_valid,
directory=valid_path,
shuffle=True,
drop_last=True,
local_seeds_per_call=seeds_per_call if seeds_per_call > 0 else None,
**kwargs,
)

Expand Down Expand Up @@ -269,6 +276,8 @@ def run_train(
parser.add_argument("--tempdir_root", type=str, default=None)
parser.add_argument("--dataset_root", type=str, default="dataset")
parser.add_argument("--dataset", type=str, default="ogbn-products")
parser.add_argument("--in_memory", action="store_true", default=False)
parser.add_argument("--seeds_per_call", type=int, default=-1)

parser.add_argument(
"--n_devices",
Expand Down Expand Up @@ -322,6 +331,8 @@ def run_train(
wall_clock_start,
tempdir,
args.num_layers,
args.in_memory,
args.seeds_per_call,
),
nprocs=world_size,
join=True,
Expand Down
28 changes: 13 additions & 15 deletions python/cugraph-pyg/cugraph_pyg/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# limitations under the License.

import warnings
import tempfile

from typing import Union, Tuple, Optional, Callable, List, Dict

Expand Down Expand Up @@ -123,14 +122,14 @@ def __init__(
The number of input nodes per output minibatch.
See torch.utils.dataloader.
directory: str (optional, default=None)
The directory where samples will be temporarily stored.
It is recommend that this be set by the user, usually
setting it to a tempfile.TemporaryDirectory with a context
The directory where samples will be temporarily stored,
if spilling samples to disk. If None, this loader
will perform buffered in-memory sampling.
If writing to disk, setting this argument
to a tempfile.TemporaryDirectory with a context
manager is a good option but depending on the filesystem,
you may want to choose an alternative location with fast I/O
intead.
If not set, this will create a TemporaryDirectory that will
persist until this object is garbage collected.
See cugraph.gnn.DistSampleWriter.
batches_per_partition: int (optional, default=256)
The number of batches per partition if writing samples to
Expand Down Expand Up @@ -182,20 +181,19 @@ def __init__(
# Will eventually automatically convert these objects to cuGraph objects.
raise NotImplementedError("Currently can't accept non-cugraph graphs")

if directory is None:
warnings.warn("Setting a directory to store samples is recommended.")
self._tempdir = tempfile.TemporaryDirectory()
directory = self._tempdir.name

if compression is None:
compression = "CSR"
elif compression not in ["CSR", "COO"]:
raise ValueError("Invalid value for compression (expected 'CSR' or 'COO')")

writer = DistSampleWriter(
directory=directory,
batches_per_partition=batches_per_partition,
format=format,
writer = (
None
if directory is None
else DistSampleWriter(
directory=directory,
batches_per_partition=batches_per_partition,
format=format,
)
)

feature_store, graph_store = data
Expand Down
3 changes: 3 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/dist_io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def __init__(
if rank is None:
self.__batch_count = batch_count
else:
# TODO maybe remove this in favor of warning users that they are
# probably going to cause a hang, instead of attempting to resolve
# the hang for them by dropping batches.
batch_count = torch.tensor([batch_count], device="cuda")
torch.distributed.all_reduce(batch_count, torch.distributed.ReduceOp.MIN)
self.__batch_count = int(batch_count)
Expand Down
5 changes: 5 additions & 0 deletions python/cugraph/cugraph/gnn/data_loading/dist_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ def __sample_from_nodes_func(
# rename renumber_map -> map to match unbuffered format
minibatch_dict["map"] = minibatch_dict["renumber_map"]
del minibatch_dict["renumber_map"]
minibatch_dict = {
k: torch.as_tensor(v, device="cuda")
for k, v in minibatch_dict.items()
if v is not None
}

return iter([(minibatch_dict, current_batches[0], current_batches[-1])])
else:
Expand Down

0 comments on commit a365a24

Please sign in to comment.