
Merge pull request #33 from kumo-ai/zecheng_fix_lightgcn
Fix LightGCN
zechengz authored Sep 29, 2024
2 parents 385b38a + da064ca commit 2408e78
Showing 1 changed file: examples/light_gcn.py (73 changes: 34 additions & 39 deletions)
@@ -1,6 +1,7 @@
"""Example script to run the models in this repository.
python3 light_gcn.py --dataset rel-hm --task user-item-purchase --val_loss
python3 light_gcn.py --dataset rel-avito --task user-ad-visit --val_loss
"""

from __future__ import annotations
@@ -37,7 +38,7 @@
parser.add_argument("--dataset", type=str, default="rel-trial")
parser.add_argument("--task", type=str, default="site-sponsor-run")
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--eval_epochs_interval", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=1024)
parser.add_argument("--channels", type=int, default=64)
@@ -87,7 +88,7 @@
num_total_nodes = num_src_nodes + num_dst_nodes

split_edge_index_dict: Dict[str, Tensor] = {}
split_edge_attr_dict: Dict[str, Tensor] = {}
split_edge_weight_dict: Dict[str, Tensor] = {}
n_id_dict: Dict[str, Tensor] = {}
for split in ["train", "val", "test"]:
table = task.get_table(split)
@@ -108,31 +109,32 @@
# Get edge_index using src and column indices
edge_index = torch.stack([src, dst_csr.col_indices()], dim=0)
# Convert to bipartite graph
edge_index[1, :] += num_dst_nodes
edge_index[1, :] += num_src_nodes
# Remove duplicated edges but use edge weight for message passing
edge_attr = torch.ones(edge_index.size(1)).to(edge_index.device)
edge_index, edge_attr = coalesce(edge_index, edge_attr=edge_attr,
num_nodes=num_total_nodes)
edge_weight = torch.ones(edge_index.size(1)).to(edge_index.device)
edge_index, edge_weight = coalesce(edge_index, edge_attr=edge_weight,
num_nodes=num_total_nodes)
split_edge_index_dict[split] = edge_index
split_edge_attr_dict[split] = edge_attr
split_edge_weight_dict[split] = edge_weight

model = LightGCN(num_total_nodes, embedding_dim=args.channels,
num_layers=args.num_layers).to(device)

train_edge_index = split_edge_index_dict["train"].to("cpu")
train_edge_weight = split_edge_attr_dict["train"].to("cpu")
# Shuffle train edges to avoid only using earlier edges
train_edge_weight = split_edge_weight_dict["train"].to("cpu")
# Shuffle train edges to avoid only using same edges for supervision each time
perm = torch.randperm(train_edge_index.size(1), device="cpu")
train_edge_index = train_edge_index[:, perm][:, :args.max_num_train_edges].to(
device)
train_edge_weight = train_edge_weight[perm][:args.max_num_train_edges].to(
device)
# Convert to undirected graph
train_mp_edge_index, train_mp_edge_weight = to_undirected(
train_edge_index, train_edge_weight)
train_mp_edge_index = train_mp_edge_index.to(device)
train_mp_edge_weight = train_mp_edge_weight.to(device)
val_edge_index = split_edge_index_dict["val"].to(device)
val_edge_weight = split_edge_attr_dict["val"].to(device)
val_edge_weight = split_edge_weight_dict["val"].to(device)
val_mp_edge_index, val_mp_edge_weight = to_undirected(val_edge_index,
val_edge_weight)
val_mp_edge_index = val_mp_edge_index.to(device)
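
Note on the core fix above: LightGCN uses one shared embedding table, so source (user) nodes occupy IDs [0, num_src_nodes) and destination (item) nodes must be shifted to [num_src_nodes, num_src_nodes + num_dst_nodes). Offsetting by num_dst_nodes, as the old line did, either collides item IDs with user IDs or runs past the num_total_nodes embedding table. Below is a minimal standalone sketch (not part of the commit, toy sizes only) of this convention and of the coalesce-into-edge-weights step:

import torch
from torch_geometric.utils import coalesce

num_src_nodes, num_dst_nodes = 3, 2               # toy sizes, for illustration only
num_total_nodes = num_src_nodes + num_dst_nodes

# (user, item) interactions in their local index spaces; the pair (1, 0) appears twice.
src = torch.tensor([0, 1, 1, 2])
dst = torch.tensor([1, 0, 0, 1])

edge_index = torch.stack([src, dst], dim=0)
edge_index[1, :] += num_src_nodes                 # the fix: shift items past the user block

# Duplicate interactions are merged; their multiplicity survives as an edge weight
# that can later be used for weighted message passing.
edge_weight = torch.ones(edge_index.size(1))
edge_index, edge_weight = coalesce(edge_index, edge_attr=edge_weight,
                                   num_nodes=num_total_nodes)
# edge_index -> [[0, 1, 2], [4, 3, 4]], edge_weight -> [1., 2., 1.]
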
@@ -149,6 +151,24 @@
writer = SummaryWriter()


def get_edge_label_index(sup_edge_index: Tensor, index: Tensor) -> Tensor:
pos_edge_label_index = sup_edge_index[:, index].to(device)
neg_edge_label_index = torch.stack([
pos_edge_label_index[0],
torch.randint(
num_src_nodes,
num_total_nodes,
(index.numel(), ),
device=device,
)
], dim=0)
edge_label_index = torch.cat([
pos_edge_label_index,
neg_edge_label_index,
], dim=1)
return edge_label_index


def train(epoch: int) -> float:
model.train()
total_loss = total_examples = 0
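
For orientation, the new get_edge_label_index helper (now shared by the train and test loops below) keeps each sampled positive (user, item) pair and adds a negative that pairs the same user with a node drawn uniformly from the item block [num_src_nodes, num_total_nodes). A small sketch with made-up tensors, not part of the commit:

import torch

num_src_nodes, num_total_nodes = 4, 7
device = "cpu"

# Supervision edges in the bipartite ID space (items already offset by num_src_nodes).
sup_edge_index = torch.tensor([[0, 1, 2, 3],
                               [4, 6, 5, 4]])
index = torch.tensor([0, 2])                       # mini-batch of positive edge positions

pos = sup_edge_index[:, index].to(device)          # positive (user, item) pairs
neg = torch.stack([
    pos[0],                                        # same users ...
    torch.randint(num_src_nodes, num_total_nodes,  # ... paired with random items as negatives
                  (index.numel(), ), device=device),
], dim=0)
edge_label_index = torch.cat([pos, neg], dim=1)    # [2, 2 * batch]: positives first, negatives second

The positives-first layout is what lets the callers unpack the model's scores into pos_rank and neg_rank; the actual split happens outside the shown hunks.
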
@@ -157,20 +177,7 @@ def train(epoch: int) -> float:
tqdm(train_loader, total=total_steps, desc="Train")):
if i >= args.max_steps_per_epoch:
break
pos_edge_label_index = train_edge_index[:, index].to(device)
neg_edge_label_index = torch.stack([
pos_edge_label_index[0],
torch.randint(
num_src_nodes,
num_src_nodes + num_dst_nodes,
(index.numel(), ),
device=device,
)
], dim=0)
edge_label_index = torch.cat([
pos_edge_label_index,
neg_edge_label_index,
], dim=1)
edge_label_index = get_edge_label_index(train_edge_index, index)
optimizer.zero_grad()
pos_rank, neg_rank = model(
train_mp_edge_index,
@@ -219,21 +226,9 @@ def test(
total_loss = total_examples = 0
for start in tqdm(range(0, val_edge_index.size(1), args.batch_size),
desc=desc):
end = start + args.batch_size
pos_edge_label_index = val_edge_index[:, start:end].to(device)
neg_edge_label_index = torch.stack([
pos_edge_label_index[0],
torch.randint(
num_src_nodes,
num_src_nodes + num_dst_nodes,
(pos_edge_label_index.size(1), ),
device=device,
)
], dim=0)
edge_label_index = torch.cat([
pos_edge_label_index,
neg_edge_label_index,
], dim=1)
end = min(start + args.batch_size, val_edge_index.size(1))
index = torch.arange(start, end)
edge_label_index = get_edge_label_index(val_edge_index, index)
pos_rank, neg_rank = model(
mp_edge_index,
edge_label_index,
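
For context, the pos_rank/neg_rank pattern in both loops follows PyG's LightGCN API: the message-passing edge_index (plus optional edge weights in recent PyG releases) drives the embedding propagation, edge_label_index supplies the pairs to score, and recommendation_loss applies a BPR-style objective with regularization restricted to the nodes in the batch. A hedged, self-contained sketch with toy numbers, not the commit's code:

import torch
from torch_geometric.nn import LightGCN

num_total_nodes = 7                              # users + items in one shared ID space
model = LightGCN(num_total_nodes, embedding_dim=64, num_layers=2)

mp_edge_index = torch.tensor([[0, 4, 1, 5],      # undirected message-passing edges
                              [4, 0, 5, 1]])
mp_edge_weight = torch.ones(mp_edge_index.size(1))

edge_label_index = torch.tensor([[0, 1, 0, 1],   # first half positives, second half negatives
                                 [4, 5, 6, 6]])

rank = model(mp_edge_index, edge_label_index,
             edge_weight=mp_edge_weight)         # drop edge_weight on older PyG versions
pos_rank, neg_rank = rank.chunk(2)               # recover the pos/neg halves
loss = model.recommendation_loss(pos_rank, neg_rank,
                                 node_id=edge_label_index.unique())
loss.backward()
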
