Skip to content

Commit

Permalink
correct pool_size variable type, mock input
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed Nov 7, 2024
1 parent 349014c commit c102f46
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions python/cugraph-dgl/examples/dataset_from_disk_cudf.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"import rmm\n",
"import torch\n",
"from rmm.allocators.torch import rmm_torch_allocator\n",
"rmm.reinitialize(pool_allocator=True, initial_pool_size=15e9)\n",
"rmm.reinitialize(pool_allocator=True, initial_pool_size=\"14GiB\")\n",
"#Switch to async pool in case of memory issues due to fragmentation of the pool\n",
"#rmm.mr.set_current_device_resource(rmm.mr.CudaAsyncMemoryResource(initial_pool_size=15e9))\n",
"torch.cuda.memory.change_current_allocator(rmm_torch_allocator)"
Expand Down Expand Up @@ -51,8 +51,10 @@
"source": [
"def load_dgl_dataset(dataset_name='ogbn-products'):\n",
" from ogb.nodeproppred import DglNodePropPredDataset\n",
" dataset_root = '/raid/vjawa/gnn/'\n",
" dataset = DglNodePropPredDataset(name = dataset_name, root=dataset_root)\n",
" from unittest.mock import patch\n",
" dataset_root = '/tmp/'\n",
" with patch(\"builtins.input\", return_value=\"y\"):\n",
" dataset = DglNodePropPredDataset(name = dataset_name, root=dataset_root)\n",
" split_idx = dataset.get_idx_split()\n",
" train_idx, valid_idx, test_idx = split_idx[\"train\"], split_idx[\"valid\"], split_idx[\"test\"]\n",
" g, label = dataset[0]\n",
Expand Down Expand Up @@ -111,7 +113,7 @@
" g, \n",
" train_idx.to('cuda'), # train_nid must be on GPU.\n",
" sampler,\n",
" sampling_output_dir=\"/raid/vjawa/obgn_products_sampling/\", # Path to save sampling results to, Change to the fastest IO path available\n",
" sampling_output_dir=\"/tmp/\", # Path to save sampling results to, Change to the fastest IO path available\n",
" device=torch.device('cuda'), # The device argument must be GPU.\n",
" num_workers=0, # Number of workers must be 0.\n",
" batch_size=batch_size,\n",
Expand Down Expand Up @@ -254,7 +256,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.12.7"
},
"vscode": {
"interpreter": {
Expand Down

0 comments on commit c102f46

Please sign in to comment.