Skip to content

Commit

Permalink
make code run faster
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiwen Yuan committed Oct 24, 2024
1 parent 7819215 commit 7911ceb
Showing 1 changed file with 37 additions and 10 deletions.
47 changes: 37 additions & 10 deletions examples/ijcai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,21 @@
data = HeteroData()


def sparse_matrix_to_sparse_coo(sci_sparse_matrix):
sci_sparse_coo = sci_sparse_matrix.tocoo()

# Get the data, row indices, and column indices
values = torch.tensor(sci_sparse_coo.data, dtype=torch.int64)
row_indices = torch.tensor(sci_sparse_coo.row, dtype=torch.int64)
col_indices = torch.tensor(sci_sparse_coo.col, dtype=torch.int64)

# Create a PyTorch sparse tensor
torch_sparse_tensor = torch.sparse_coo_tensor(
indices=torch.stack([row_indices, col_indices]), values=values,
size=sci_sparse_matrix.shape)
return torch_sparse_tensor


def calculate_hit_rate(pred: torch.Tensor, target: List[Optional[int]],
num_candidates=None):
r"""Calculates hit rate when pred is a tensor and target is a list.
Expand Down Expand Up @@ -329,6 +344,9 @@ def train() -> float:
return loss_accum / count_accum if count_accum > 0 else float("nan")


trnLabel = sparse_matrix_to_sparse_coo(trnLabel)


@torch.no_grad()
def test(loader: NeighborLoader, desc: str, target=None) -> np.ndarray:
model.eval()
Expand Down Expand Up @@ -367,16 +385,25 @@ def test(loader: NeighborLoader, desc: str, target=None) -> np.ndarray:
0, num_sampled_rhs, (batch_user, 100)).to(
scores.device) # Shape: (batch_user, 100)
for i in range(batch_size):
user_idx = batch[src_entity_table].n_id[i].cpu()
neg_item_per_user = np.reshape(
np.argwhere(trnLabel[user_idx].toarray().reshape(-1) == 0),
[-1])
neg_item_per_user_sampled = np.intersect1d(
all_sampled_rhs.cpu(), neg_item_per_user)
random_items[i, :] = torch.tensor(
np.random.choice(neg_item_per_user_sampled, size=100,
replace=False),
dtype=torch.long).to(random_items.device)
user_idx = batch[src_entity_table].n_id[i]
pos_item_per_user = trnLabel[user_idx].coalesce().indices(
).reshape(-1)
#neg_item_per_user = np.reshape(
# np.argwhere(trnLabel[user_idx].toarray().reshape(-1) == 0),
# [-1])

#neg_item_per_user_sampled = np.intersect1d(
# all_sampled_rhs.cpu(), neg_item_per_user)
#sampled_rhs = torch.tensor(
# np.random.choice(all_sampled_rhs, size=200,
# replace=False),
# dtype=torch.long).to(random_items.device)

indices = torch.randint(0, all_sampled_rhs.size(0), (1000, ))
sampled_items = all_sampled_rhs[indices]

random_items[i, :] = sampled_items[
~torch.isin(sampled_items, pos_item_per_user)][:100]

# include the target item if it is there
target_item = target[user_idx]
Expand Down

0 comments on commit 7911ceb

Please sign in to comment.