diff --git a/python/cugraph-dgl/examples/dataset_from_disk_cudf.ipynb b/python/cugraph-dgl/examples/dataset_from_disk_cudf.ipynb index dd99d78..66ee205 100644 --- a/python/cugraph-dgl/examples/dataset_from_disk_cudf.ipynb +++ b/python/cugraph-dgl/examples/dataset_from_disk_cudf.ipynb @@ -22,7 +22,7 @@ "import rmm\n", "import torch\n", "from rmm.allocators.torch import rmm_torch_allocator\n", - "rmm.reinitialize(pool_allocator=True, initial_pool_size=15e9)\n", + "rmm.reinitialize(pool_allocator=True, initial_pool_size=\"14GiB\")\n", "#Switch to async pool in case of memory issues due to fragmentation of the pool\n", "#rmm.mr.set_current_device_resource(rmm.mr.CudaAsyncMemoryResource(initial_pool_size=15e9))\n", "torch.cuda.memory.change_current_allocator(rmm_torch_allocator)" @@ -51,8 +51,10 @@ "source": [ "def load_dgl_dataset(dataset_name='ogbn-products'):\n", " from ogb.nodeproppred import DglNodePropPredDataset\n", - " dataset_root = '/raid/vjawa/gnn/'\n", - " dataset = DglNodePropPredDataset(name = dataset_name, root=dataset_root)\n", + " from unittest.mock import patch\n", + " dataset_root = '/tmp/'\n", + " with patch(\"builtins.input\", return_value=\"y\"):\n", + " dataset = DglNodePropPredDataset(name = dataset_name, root=dataset_root)\n", " split_idx = dataset.get_idx_split()\n", " train_idx, valid_idx, test_idx = split_idx[\"train\"], split_idx[\"valid\"], split_idx[\"test\"]\n", " g, label = dataset[0]\n", @@ -111,7 +113,7 @@ " g, \n", " train_idx.to('cuda'), # train_nid must be on GPU.\n", " sampler,\n", - " sampling_output_dir=\"/raid/vjawa/obgn_products_sampling/\", # Path to save sampling results to, Change to the fastest IO path available\n", + " sampling_output_dir=\"/tmp/\", # Path to save sampling results to, Change to the fastest IO path available\n", " device=torch.device('cuda'), # The device argument must be GPU.\n", " num_workers=0, # Number of workers must be 0.\n", " batch_size=batch_size,\n", @@ -254,7 +256,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.12.7" }, "vscode": { "interpreter": {