Fixes
alexbarghi-nv committed Dec 8, 2023
1 parent e08c46c commit a9fc5af
Showing 4 changed files with 48 additions and 276 deletions.
270 changes: 29 additions & 241 deletions benchmarks/cugraph-pyg/bench_cugraph_pyg.py
@@ -12,30 +12,17 @@
# limitations under the License.


import re
import json
import time
import argparse
import gc
import os
import socket
import json
import warnings

import torch
import numpy as np
import pandas

import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as tmp
from torch.nn.parallel import DistributedDataParallel as ddp
from torch.distributed.optim import ZeroRedundancyOptimizer

from typing import Union, List

from models_cugraph import CuGraphSAGE
from trainers_cugraph import PyGCuGraphTrainer
from trainers_native import PyGNativeTrainer

from datasets import OGBNPapers100MDataset

@@ -47,7 +34,7 @@ def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool=False) -> None:
from pynvml.smi import nvidia_smi

smi = nvidia_smi.getInstance()
pool_size=8e9 # FIXME calculate this
pool_size=16e9 # FIXME calculate this
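# One possible way to address the FIXME above (hypothetical sketch, not part of
# this commit): derive the pool size from the device's total framebuffer memory
# reported through the `smi` handle created above, e.g. reserving about half of it.
# fb_total_mib = smi.DeviceQuery("memory.total")["gpu"][rank]["fb_memory_usage"]["total"]
# pool_size = int(fb_total_mib * (2**20) * 0.5)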

rmm.reinitialize(
devices=[rank],
@@ -56,8 +43,13 @@ def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool=False) -> None:
)

if use_rmm_torch_allocator:
from rmm.allocators.torch import rmm_torch_allocator
torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
warnings.warn(
"Using the rmm pytorch allocator is currently unsupported."
" The default allocator will be used instead."
)
# FIXME somehow get the pytorch allocator to work
#from rmm.allocators.torch import rmm_torch_allocator
#torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

from rmm.allocators.cupy import rmm_cupy_allocator
cupy.cuda.set_allocator(rmm_cupy_allocator)
@@ -68,229 +60,6 @@ def init_pytorch_worker(rank: int, use_rmm_torch_allocator: bool=False) -> None:
# Pytorch training worker initialization
torch.distributed.init_process_group(backend="nccl")

def train(bulk_samples_dir: str, output_dir:str, native_times:List[float], device: int, features_device: Union[str, int] = "cpu", world_size=1, num_epochs=1) -> None:
"""
Parameters
----------
device: int
The CUDA device where the model, graph data, and node labels will be stored.
features_device: Union[str, int]
The device (CUDA device or CPU) where features will be stored.
"""

import cudf
import cugraph
from cugraph_pyg.data import CuGraphStore
from cugraph_pyg.loader import BulkSampleLoader

with torch.cuda.device(device):

with open(os.path.join(bulk_samples_dir, 'output_meta.json'), 'r') as f:
output_meta = json.load(f)

dataset_path = os.path.join(output_meta['dataset_dir'], output_meta['dataset'])
with open(os.path.join(dataset_path, 'meta.json'), 'r') as f:
input_meta = json.load(f)

replication_factor = output_meta['replication_factor']
G = {tuple(edge_type.split('__')): t * replication_factor for edge_type, t in input_meta['num_edges'].items()}
N = {node_type: t * replication_factor for node_type, t in input_meta['num_nodes'].items()}
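# For illustration (hypothetical numbers): with replication_factor=2 and
# input_meta['num_edges'] == {'paper__cites__paper': 1000}, G becomes
# {('paper', 'cites', 'paper'): 2000}; N scales num_nodes the same way per node type.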

fs = cugraph.gnn.FeatureStore(backend="torch")

num_input_features = 0
num_output_features = 0
for node_type in input_meta['num_nodes'].keys():
feature_data = load_disk_features(output_meta, node_type, replication_factor=replication_factor)
print(f'features shape: {feature_data.shape}')
fs.add_data(
torch.as_tensor(feature_data, device=features_device),
node_type,
"x",
)
if feature_data.shape[1] > num_input_features:
num_input_features = feature_data.shape[1]

label_path = os.path.join(dataset_path, 'parquet', node_type, 'node_label.parquet')
if os.path.exists(label_path):
node_label = cudf.read_parquet(label_path)
if replication_factor > 1:
base_num_nodes = input_meta['num_nodes'][node_type]
print('base num nodes:', base_num_nodes)
dfr = cudf.DataFrame({
'node': cudf.concat([node_label.node + (r * base_num_nodes) for r in range(1, replication_factor)]),
'label': cudf.concat([node_label.label for r in range(1, replication_factor)]),
})
node_label = cudf.concat([node_label, dfr]).reset_index(drop=True)

node_label_tensor = torch.full((N[node_type],), -1, dtype=torch.float32, device='cuda')
node_label_tensor[torch.as_tensor(node_label.node.values, device='cuda')] = \
torch.as_tensor(node_label.label.values, device='cuda')
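# node_label_tensor is now a dense per-node label vector; entries left at -1
# mark unlabeled nodes and are used below to derive the train mask.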

del node_label
gc.collect()

fs.add_data((node_label_tensor > -1).contiguous(), node_type, 'train')
fs.add_data(node_label_tensor.contiguous(), node_type, 'y')
num_classes = int(node_label_tensor.max()) + 1
if num_classes > num_output_features:
num_output_features = num_classes
print('done loading data')
dist.barrier()

print(f"num input features: {num_input_features}; num output features: {num_output_features}; fanout: {output_meta['fanout']}")

num_hidden_channels = 64
num_layers = len(output_meta['fanout'])
model = CuGraphSAGE(
in_channels=num_input_features,
hidden_channels=num_hidden_channels,
out_channels=num_output_features,
num_layers=num_layers
).to(torch.float32).to(device)

model = ddp(model, device_ids=[device])

print('done creating model')
dist.barrier()

cugraph_store = CuGraphStore(fs, G, N)
print('done creating store')
dist.barrier()

#optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.Adam, lr=0.01)
dist.barrier()

for epoch in range(num_epochs):
start_time_train = time.perf_counter_ns()
model.train()

print('creating loader...')
samples_dir = os.path.join(bulk_samples_dir, 'samples')
input_files = np.array(os.listdir(samples_dir))
input_files = np.array_split(
input_files, world_size
)[device].tolist()
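# For example (hypothetical numbers): with 64 sample files and world_size=8,
# each worker reads its own contiguous slice of 8 files.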

cugraph_loader = BulkSampleLoader(
cugraph_store,
cugraph_store,
input_nodes=None,
input_files=input_files,
directory=samples_dir,
)
print('done creating loader')
dist.barrier()

total_loss, num_batches, mean_total_time, mean_time_fw, mean_time_bw, mean_time_loader, mean_additional_feature_time = train_epoch(model, cugraph_loader, optimizer)

end_time_train = time.perf_counter_ns()
train_time = (end_time_train - start_time_train) / 1e9
print(
f"epoch {epoch} time: "
f"{train_time:3.4f} s"
f"\n trained {num_batches} batches"
)
print(f"loss after epoch {epoch}: {total_loss / num_batches}")

train_time = mean_total_time * num_batches
output_result_filename = f'results[{device}].csv'
results_native = {
'Dataset': f"{output_meta['dataset']} x {replication_factor}",
'Framework': 'PyG',
'Setup Details': f"GraphSAGE, {num_layers} layers",
'Batch Size': output_meta['batch_size'],
'Fanout': str(output_meta['fanout']),
'Machine Details': socket.gethostname(),
'Sampling per epoch': native_times[4] * num_batches,
'MFG Creation': 0.0,
'Feature Loading': native_times[3] * num_batches,
'Model FWD': native_times[1] * num_batches,
'Model BWD': native_times[2] * num_batches,
'Time Per Epoch': native_times[0] * num_batches,
'Time Per Batch': native_times[0],
'Speedup': 1,
}
results_cugraph = {
'Dataset': f"{output_meta['dataset']} x {replication_factor}",
'Framework': 'cuGraph-PyG',
'Setup Details': f"GraphSAGE, {num_layers} layers",
'Batch Size': output_meta['batch_size'],
'Fanout': str(output_meta['fanout']),
'Machine Details': socket.gethostname(),
'Sampling per epoch': output_meta['execution_time'],
'MFG Creation': cugraph_loader._total_convert_time + cugraph_loader._total_read_time,
'Feature Loading': cugraph_loader._total_feature_time + (mean_additional_feature_time * num_batches),
'Model FWD': mean_time_fw * num_batches,
'Model BWD': mean_time_bw * num_batches,
'Time Per Epoch': train_time + output_meta['execution_time'],
'Time Per Batch': (train_time + output_meta['execution_time']) / num_batches,
'Speedup': (native_times[0] * num_batches) / (train_time + output_meta['execution_time']),
}
results = {
'Machine': socket.gethostname(),
'Comms': output_meta['comms'] if 'comms' in output_meta else 'tcp',
'Dataset': output_meta['dataset'],
'Replication Factor': replication_factor,
'Model': 'GraphSAGE',
'# Layers': num_layers,
'# Input Channels': num_input_features,
'# Output Channels': num_output_features,
'# Hidden Channels': num_hidden_channels,
'# Vertices': output_meta['total_num_nodes'],
'# Edges': output_meta['total_num_edges'],
'# Vertex Types': len(N.keys()),
'# Edge Types': len(G.keys()),
'Sampling # GPUs': output_meta['num_sampling_gpus'],
'Seeds Per Call': output_meta['seeds_per_call'],
'Batch Size': output_meta['batch_size'],
'# Train Batches': num_batches,
'Batches Per Partition': output_meta['batches_per_partition'],
'Fanout': str(output_meta['fanout']),
'Training # GPUs': 1,
'Feature Storage': 'cpu' if features_device == 'cpu' else 'gpu',
'Memory Type': 'Device', # could be managed if configured

'Total Time': train_time + output_meta['execution_time'],
'Native Equivalent Time': native_times[0] * num_batches,
'Total Speedup': (native_times[0] * num_batches) / (train_time + output_meta['execution_time']),

'Bulk Sampling Time': output_meta['execution_time'],
'Bulk Sampling Time Per Batch': output_meta['execution_time'] / num_batches,

'Parquet Read Time': cugraph_loader._total_read_time,
'Parquet Read Time Per Batch': cugraph_loader._total_read_time / num_batches,

'Minibatch Conversion Time': cugraph_loader._total_convert_time,
'Minibatch Conversion Time Per Batch': cugraph_loader._total_convert_time / num_batches,

'Feature Fetch Time': cugraph_loader._total_feature_time,
'Feature Fetch Time Per Batch': cugraph_loader._total_feature_time / num_batches,

'Forward Time': mean_time_fw * num_batches,
'Native Forward Time': native_times[1] * num_batches,

'Forward Time Per Batch': mean_time_fw,
'Native Forward Time Per Batch': native_times[1],

'Backward Time': mean_time_bw * num_batches,
'Native Backward Time': native_times[2] * num_batches,

'Backward Time Per Batch': mean_time_bw,
'Native Backward Time Per Batch': native_times[2],
}
df = pandas.DataFrame(results, index=[0])
df.to_csv(os.path.join(output_dir, output_result_filename),header=True, sep=',', index=False, mode='a')

df_n = pandas.DataFrame(results_native, index=[0])
df_c = pandas.DataFrame(results_cugraph, index=[1])
pandas.concat([df_n, df_c]).to_csv(os.path.join(output_dir, output_result_filename),header=True, sep=',', index=False, mode='a')

print('convert:', cugraph_loader._total_convert_time)
print('read:', cugraph_loader._total_read_time)


def parse_args():
parser = argparse.ArgumentParser()

@@ -392,18 +161,23 @@ def main(args):
local_rank = int(os.environ['LOCAL_RANK'])
global_rank = int(os.environ["RANK"])

init_pytorch_worker(local_rank, use_rmm_torch_allocator=(args.framework == "cuGraph"))
init_pytorch_worker(local_rank, use_rmm_torch_allocator=(args.framework=="cuGraph"))
enable_spilling()
print(f'worker initialized')
dist.barrier()

# Have to import here to avoid creating CUDA context
from trainers_cugraph import PyGCuGraphTrainer
from trainers_native import PyGNativeTrainer

world_size = int(os.environ['SLURM_JOB_NUM_NODES']) * int(os.environ['SLURM_GPUS_PER_NODE'])
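# e.g. 4 nodes x 8 GPUs per node -> world_size == 32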

dataset = OGBNPapers100MDataset(
replication_factor=args.replication_factor,
dataset_dir=args.dataset_dir,
train_split=args.train_split,
val_split=args.val_split,
load_edge_index=(args.framework=="Native"),
)

if args.framework == "Native":
@@ -419,6 +193,20 @@ def main(args):
num_neighbors=[int(f) for f in args.fanout.split('_')],
batch_size=args.batch_size,
)
elif args.framework == "cuGraph":
trainer = PyGCuGraphTrainer(
model=args.model,
dataset=dataset,
sample_dir=args.sample_dir,
device=local_rank,
rank=global_rank,
world_size=world_size,
num_epochs=args.num_epochs,
shuffle=True,
replace=False,
num_neighbors=[int(f) for f in args.fanout.split('_')],
batch_size=args.batch_size,
)
else:
raise ValueError("unsupported framework")

2 changes: 1 addition & 1 deletion benchmarks/cugraph-pyg/datasets/ogbn_papers100M.py
@@ -90,7 +90,7 @@ def edge_index_dict(self) -> Dict[Tuple[str, str, str], Union[Dict[str, torch.Te
logger.info(f"# edges: {len(ei['src'])}")
self.__edge_index = {('paper','cites','paper'): ei}
else:
self.__edge_index = {('paper','cites','paper'): self.__num_edges(('paper','cites','paper'))}
self.__edge_index = {('paper','cites','paper'): self.num_edges(('paper','cites','paper'))}

return self.__edge_index
