From e413c8ac20affa003848770357cb12ad8c4f0831 Mon Sep 17 00:00:00 2001
From: Alexandria Barghi
Date: Fri, 20 Sep 2024 08:33:05 -0700
Subject: [PATCH] skip snmg example

---
 .../examples/rgcn_link_class_snmg.py          | 66 ++++++++++---------
 1 file changed, 35 insertions(+), 31 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py
index 1cc7c29f467..2c0ae53a08e 100644
--- a/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py
+++ b/python/cugraph-pyg/cugraph_pyg/examples/rgcn_link_class_snmg.py
@@ -15,6 +15,7 @@
 
 import os
 import argparse
+import warnings
 
 from typing import Tuple, Any
 
@@ -275,42 +276,45 @@ def get_eval_loader(stage: str):
 
 
 if __name__ == "__main__":
-    args = parse_args()
+    if "CI_RUN" in os.environ and os.environ["CI_RUN"] == "1":
+        warnings.warn("Skipping SNMG example in CI due to memory limit")
+    else:
+        args = parse_args()
 
-    # change the allocator before any allocations are made
-    from rmm.allocators.torch import rmm_torch_allocator
+        # change the allocator before any allocations are made
+        from rmm.allocators.torch import rmm_torch_allocator
 
-    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
+        torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
 
-    # import ogb here to stop it from creating a context and breaking pytorch/rmm
-    from ogb.linkproppred import PygLinkPropPredDataset
+        # import ogb here to stop it from creating a context and breaking pytorch/rmm
+        from ogb.linkproppred import PygLinkPropPredDataset
 
-    data = PygLinkPropPredDataset(args.dataset, root=args.dataset_root)
-    dataset = data[0]
+        data = PygLinkPropPredDataset(args.dataset, root=args.dataset_root)
+        dataset = data[0]
 
-    splits = data.get_edge_split()
+        splits = data.get_edge_split()
 
-    meta = {}
-    meta["num_nodes"] = dataset.num_nodes
-    meta["num_rels"] = dataset.edge_reltype.max() + 1
+        meta = {}
+        meta["num_nodes"] = dataset.num_nodes
+        meta["num_rels"] = dataset.edge_reltype.max() + 1
 
-    model = RGCNEncoder(
-        meta["num_nodes"],
-        hidden_channels=args.hidden_channels,
-        num_relations=meta["num_rels"],
-    )
+        model = RGCNEncoder(
+            meta["num_nodes"],
+            hidden_channels=args.hidden_channels,
+            num_relations=meta["num_rels"],
+        )
 
-    print("Data =", data)
-    if args.n_devices == -1:
-        world_size = torch.cuda.device_count()
-    else:
-        world_size = args.n_devices
-    print("Using", world_size, "GPUs...")
-
-    uid = cugraph_comms_create_unique_id()
-    torch.multiprocessing.spawn(
-        run_train,
-        (world_size, uid, model, data, meta, splits, args),
-        nprocs=world_size,
-        join=True,
-    )
+        print("Data =", data)
+        if args.n_devices == -1:
+            world_size = torch.cuda.device_count()
+        else:
+            world_size = args.n_devices
+        print("Using", world_size, "GPUs...")
+
+        uid = cugraph_comms_create_unique_id()
+        torch.multiprocessing.spawn(
+            run_train,
+            (world_size, uid, model, data, meta, splits, args),
+            nprocs=world_size,
+            join=True,
+        )
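
The guard added above keys off the CI_RUN environment variable. Below is a minimal standalone sketch of that pattern (the main() wrapper and the fallback print are illustrative assumptions, not part of the patched example); it shows how setting CI_RUN=1 makes the entry point warn and return before any GPU work starts:

    import os
    import warnings


    def main():
        # Assumption for illustration: CI sets CI_RUN=1 on memory-constrained runners.
        if os.environ.get("CI_RUN") == "1":
            # Same behavior as the patched example: warn and skip the SNMG run.
            warnings.warn("Skipping SNMG example in CI due to memory limit")
            return
        # Outside CI, the full single-node multi-GPU workflow would run here.
        print("CI_RUN != 1; proceeding with the SNMG workflow")


    if __name__ == "__main__":
        main()

Invoking the example as `CI_RUN=1 python rgcn_link_class_snmg.py` should therefore emit only the warning, while leaving CI_RUN unset (or set to any other value) takes the normal training path.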