diff --git a/python/cugraph-dgl/examples/graphsage/node-classification.py b/python/cugraph-dgl/examples/graphsage/node-classification.py index 320890b0312..539fd86d136 100644 --- a/python/cugraph-dgl/examples/graphsage/node-classification.py +++ b/python/cugraph-dgl/examples/graphsage/node-classification.py @@ -243,7 +243,9 @@ def train(args, device, g, dataset, model): else: g = g.to("cuda" if args.mode == "gpu_dgl" else "cpu") - device = torch.device("cpu" if args.mode == "cpu" else "cuda") + device = torch.device( + "cpu" if args.mode == "cpu" or args.mode == "mixed" else "cuda" + ) # create GraphSAGE model feat_shape = (