Skip to content

Commit

Permalink
skip snmg example
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Sep 20, 2024
1 parent 4cae976 commit e413c8a
Showing 1 changed file with 35 additions and 31 deletions.
66 changes: 35 additions & 31 deletions python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import os
import argparse
import warnings

from typing import Tuple, Any

Expand Down Expand Up @@ -275,42 +276,45 @@ def get_eval_loader(stage: str):


if __name__ == "__main__":
args = parse_args()
if "CI_RUN" in os.environ and os.environ["CI_RUN"] == "1":
warnings.warn("Skipping SMNG example in CI due to memory limit")
else:
args = parse_args()

# change the allocator before any allocations are made
from rmm.allocators.torch import rmm_torch_allocator
# change the allocator before any allocations are made
from rmm.allocators.torch import rmm_torch_allocator

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

# import ogb here to stop it from creating a context and breaking pytorch/rmm
from ogb.linkproppred import PygLinkPropPredDataset
# import ogb here to stop it from creating a context and breaking pytorch/rmm
from ogb.linkproppred import PygLinkPropPredDataset

data = PygLinkPropPredDataset(args.dataset, root=args.dataset_root)
dataset = data[0]
data = PygLinkPropPredDataset(args.dataset, root=args.dataset_root)
dataset = data[0]

splits = data.get_edge_split()
splits = data.get_edge_split()

meta = {}
meta["num_nodes"] = dataset.num_nodes
meta["num_rels"] = dataset.edge_reltype.max() + 1
meta = {}
meta["num_nodes"] = dataset.num_nodes
meta["num_rels"] = dataset.edge_reltype.max() + 1

model = RGCNEncoder(
meta["num_nodes"],
hidden_channels=args.hidden_channels,
num_relations=meta["num_rels"],
)
model = RGCNEncoder(
meta["num_nodes"],
hidden_channels=args.hidden_channels,
num_relations=meta["num_rels"],
)

print("Data =", data)
if args.n_devices == -1:
world_size = torch.cuda.device_count()
else:
world_size = args.n_devices
print("Using", world_size, "GPUs...")

uid = cugraph_comms_create_unique_id()
torch.multiprocessing.spawn(
run_train,
(world_size, uid, model, data, meta, splits, args),
nprocs=world_size,
join=True,
)
print("Data =", data)
if args.n_devices == -1:
world_size = torch.cuda.device_count()
else:
world_size = args.n_devices
print("Using", world_size, "GPUs...")

uid = cugraph_comms_create_unique_id()
torch.multiprocessing.spawn(
run_train,
(world_size, uid, model, data, meta, splits, args),
nprocs=world_size,
join=True,
)

0 comments on commit e413c8a

Please sign in to comment.