From da064ca1fb32ce9909f0f658e933d8c89b3bf583 Mon Sep 17 00:00:00 2001
From: Zecheng Zhang
Date: Sat, 28 Sep 2024 22:40:46 +0000
Subject: [PATCH] Fix LightGCN

---
 examples/light_gcn.py | 73 ++++++++++++++++++++-----------------------
 1 file changed, 34 insertions(+), 39 deletions(-)

diff --git a/examples/light_gcn.py b/examples/light_gcn.py
index 1c0df02..4393c5a 100644
--- a/examples/light_gcn.py
+++ b/examples/light_gcn.py
@@ -1,6 +1,7 @@
 """Example script to run the models in this repository.
 
 python3 light_gcn.py --dataset rel-hm --task user-item-purchase --val_loss
+python3 light_gcn.py --dataset rel-avito --task user-ad-visit --val_loss
 """
 from __future__ import annotations
 
@@ -37,7 +38,7 @@
 parser.add_argument("--dataset", type=str, default="rel-trial")
 parser.add_argument("--task", type=str, default="site-sponsor-run")
 parser.add_argument("--lr", type=float, default=0.0001)
-parser.add_argument("--epochs", type=int, default=10)
+parser.add_argument("--epochs", type=int, default=5)
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
 parser.add_argument("--batch_size", type=int, default=1024)
 parser.add_argument("--channels", type=int, default=64)
@@ -87,7 +88,7 @@
 num_total_nodes = num_src_nodes + num_dst_nodes
 
 split_edge_index_dict: Dict[str, Tensor] = {}
-split_edge_attr_dict: Dict[str, Tensor] = {}
+split_edge_weight_dict: Dict[str, Tensor] = {}
 n_id_dict: Dict[str, Tensor] = {}
 for split in ["train", "val", "test"]:
     table = task.get_table(split)
@@ -108,31 +109,32 @@
     # Get edge_index using src and column indices
     edge_index = torch.stack([src, dst_csr.col_indices()], dim=0)
     # Convert to bipartite graph
-    edge_index[1, :] += num_dst_nodes
+    edge_index[1, :] += num_src_nodes
     # Remove duplicated edges but use edge weight for message passing
-    edge_attr = torch.ones(edge_index.size(1)).to(edge_index.device)
-    edge_index, edge_attr = coalesce(edge_index, edge_attr=edge_attr,
-                                     num_nodes=num_total_nodes)
+    edge_weight = torch.ones(edge_index.size(1)).to(edge_index.device)
+    edge_index, edge_weight = coalesce(edge_index, edge_attr=edge_weight,
+                                       num_nodes=num_total_nodes)
     split_edge_index_dict[split] = edge_index
-    split_edge_attr_dict[split] = edge_attr
+    split_edge_weight_dict[split] = edge_weight
 
 model = LightGCN(num_total_nodes, embedding_dim=args.channels,
                  num_layers=args.num_layers).to(device)
 
 train_edge_index = split_edge_index_dict["train"].to("cpu")
-train_edge_weight = split_edge_attr_dict["train"].to("cpu")
-# Shuffle train edges to avoid only using earlier edges
+train_edge_weight = split_edge_weight_dict["train"].to("cpu")
+# Shuffle train edges to avoid only using same edges for supervision each time
 perm = torch.randperm(train_edge_index.size(1), device="cpu")
 train_edge_index = train_edge_index[:, perm][:, :args.max_num_train_edges].to(
     device)
 train_edge_weight = train_edge_weight[perm][:args.max_num_train_edges].to(
     device)
+# Convert to undirected graph
 train_mp_edge_index, train_mp_edge_weight = to_undirected(
     train_edge_index, train_edge_weight)
 train_mp_edge_index = train_mp_edge_index.to(device)
 train_mp_edge_weight = train_mp_edge_weight.to(device)
 val_edge_index = split_edge_index_dict["val"].to(device)
-val_edge_weight = split_edge_attr_dict["val"].to(device)
+val_edge_weight = split_edge_weight_dict["val"].to(device)
 val_mp_edge_index, val_mp_edge_weight = to_undirected(val_edge_index,
                                                       val_edge_weight)
 val_mp_edge_index = val_mp_edge_index.to(device)
@@ -149,6 +151,24 @@
 writer = SummaryWriter()
 
 
+def get_edge_label_index(sup_edge_index: Tensor, index: Tensor) -> Tensor:
+    pos_edge_label_index = sup_edge_index[:, index].to(device)
+    neg_edge_label_index = torch.stack([
+        pos_edge_label_index[0],
+        torch.randint(
+            num_src_nodes,
+            num_total_nodes,
+            (index.numel(), ),
+            device=device,
+        )
+    ], dim=0)
+    edge_label_index = torch.cat([
+        pos_edge_label_index,
+        neg_edge_label_index,
+    ], dim=1)
+    return edge_label_index
+
+
 def train(epoch: int) -> float:
     model.train()
     total_loss = total_examples = 0
@@ -157,20 +177,7 @@ def train(epoch: int) -> float:
             tqdm(train_loader, total=total_steps, desc="Train")):
         if i >= args.max_steps_per_epoch:
             break
-        pos_edge_label_index = train_edge_index[:, index].to(device)
-        neg_edge_label_index = torch.stack([
-            pos_edge_label_index[0],
-            torch.randint(
-                num_src_nodes,
-                num_src_nodes + num_dst_nodes,
-                (index.numel(), ),
-                device=device,
-            )
-        ], dim=0)
-        edge_label_index = torch.cat([
-            pos_edge_label_index,
-            neg_edge_label_index,
-        ], dim=1)
+        edge_label_index = get_edge_label_index(train_edge_index, index)
         optimizer.zero_grad()
         pos_rank, neg_rank = model(
             train_mp_edge_index,
@@ -219,21 +226,9 @@ def test(
     total_loss = total_examples = 0
     for start in tqdm(range(0, val_edge_index.size(1), args.batch_size),
                       desc=desc):
-        end = start + args.batch_size
-        pos_edge_label_index = val_edge_index[:, start:end].to(device)
-        neg_edge_label_index = torch.stack([
-            pos_edge_label_index[0],
-            torch.randint(
-                num_src_nodes,
-                num_src_nodes + num_dst_nodes,
-                (pos_edge_label_index.size(1), ),
-                device=device,
-            )
-        ], dim=0)
-        edge_label_index = torch.cat([
-            pos_edge_label_index,
-            neg_edge_label_index,
-        ], dim=1)
+        end = min(start + args.batch_size, val_edge_index.size(1))
+        index = torch.arange(start, end)
+        edge_label_index = get_edge_label_index(val_edge_index, index)
         pos_rank, neg_rank = model(
             mp_edge_index,
             edge_label_index,
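
Note (not part of the patch): the standalone sketch below illustrates the node-ID
convention the fix relies on, with toy sizes assumed purely for illustration. LightGCN
keeps a single embedding table of num_src_nodes + num_dst_nodes rows, so source nodes
keep IDs [0, num_src_nodes) and destination nodes are shifted into
[num_src_nodes, num_total_nodes); that is why the bipartite offset is num_src_nodes and
why get_edge_label_index samples negative destinations from that shifted range.

import torch

# Toy sizes (assumed for illustration only).
num_src_nodes, num_dst_nodes = 4, 3
num_total_nodes = num_src_nodes + num_dst_nodes

# (src, dst) pairs with dst still in the local range [0, num_dst_nodes).
edge_index = torch.tensor([[0, 1, 2, 3],
                           [0, 2, 1, 0]])

# Shift destination IDs into the shared ID space, matching the corrected
# `edge_index[1, :] += num_src_nodes` line in the patch.
edge_index[1, :] += num_src_nodes

# Sample one random negative destination per positive edge from the shifted
# destination range, mirroring the torch.randint call in get_edge_label_index.
neg_dst = torch.randint(num_src_nodes, num_total_nodes,
                        (edge_index.size(1), ))
neg_edge_index = torch.stack([edge_index[0], neg_dst], dim=0)

print(edge_index)      # destinations now lie in [4, 7)
print(neg_edge_index)  # negatives drawn from the same destination range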