Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
GentleZhu committed Mar 20, 2024
1 parent ae6b91f commit 597b9a8
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 27 deletions.
41 changes: 14 additions & 27 deletions examples/knn_retriever/build_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,15 @@ def main(config_args):
"""
config = GSConfig(config_args)
embs = load_gsgnn_embeddings(config.save_embed_path)
if False:
index_dimension = embs[config.target_ntype].size(1)
# Number of clusters (higher values lead to better recall but slower search)
#nlist = 750
#quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization
#index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist, faiss.METRIC_INNER_PRODUCT)
#index.train(embs[config.target_ntype])
index = faiss.IndexFlatIP(index_dimension)
index.add(embs[config.target_ntype])
else:
scores = embs[config.target_ntype] @ embs[config.target_ntype].T
#scores.fill_diagonal_(-10)

index_dimension = embs[config.target_ntype].size(1)
# Number of clusters (higher values lead to better recall but slower search)
#nlist = 750
#quantizer = faiss.IndexFlatL2(index_dimension) # Use Flat index for quantization
#index = faiss.IndexIVFFlat(quantizer, index_dimension, nlist, faiss.METRIC_INNER_PRODUCT)
#index.train(embs[config.target_ntype])
index = faiss.IndexFlatIP(index_dimension)
index.add(embs[config.target_ntype])

#print(scores.abs().mean())

Expand All @@ -68,7 +65,7 @@ def main(config_args):
# TODO: devise a dataloader that can exclude targets and add train_mask like LP Loader
test_dataloader = GSgnnNodeDataLoader(
train_data,
train_data.test_idxs,
train_data.train_idxs,
fanout=[-1],
batch_size=config.eval_batch_size,
device=device,
Expand Down Expand Up @@ -106,27 +103,17 @@ def main(config_args):
query_idx = list(ground_truth.keys())
#print(ground_truth)
#breakpoint()
#ddd,lll = index.search(embs[config.target_ntype][query_idx],100 + 1)
ddd,lll = index.search(embs[config.target_ntype][query_idx],100 + 1)
#knn_result = lll.tolist()

for idx,query in enumerate(query_idx):
#if len(ground_truth[query]) > 10:
rank_list = scores[query,:].argsort(descending=True).tolist()
#for ii in rank_list[:10]:
#print(ii, query, train_data.g.ndata['bert_h'][query] @train_data.g.ndata['bert_h'][ii].T)
# print(ii, query, scores[query, ii])
#print(ground_truth[query])
#breakpoint()
#recall.append(calculate_recall(lll[idx, 1:], ground_truth[query]))
recall.append(calculate_recall(rank_list[:100], ground_truth[query]))
#print(ground_truth)
recall.append(calculate_recall(lll[idx, 1:], ground_truth[query]))
max_.append(query)
#print(recall)
if gs.get_rank() == 0:
#print(query_idx, lll)
print(max_num_batch, len(recall), np.mean(recall))
print(len(max_), len(set(max_)))
breakpoint()
#print(max_num_batch, len(recall), np.mean(recall))
print(f'recall@100: {np.mean(recall)}')

def generate_parser():
"""Generate an argument parser"""
Expand Down
46 changes: 46 additions & 0 deletions examples/knn_retriever/embedding_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# GraphStorm config for the KNN-retriever example: re-loads saved item
# embeddings (model_encoder_type: mlp over the precomputed bert_h feature)
# for index building / recall evaluation on the also_buy link-prediction task.
gsf:
  basic:
    backend: gloo
    verbose: false
    save_perf_results_path: null
  gnn:
    # mlp encoder: no message passing, just a projection of the node feature.
    model_encoder_type: mlp
    fanout: "5,5"
    node_feat_name:
    - item:bert_h
    num_layers: 2
    # 768 matches the BERT embedding width of bert_h — presumably required; verify.
    hidden_size: 768
    use_mini_batch_infer: true
  input:
    restore_model_path: null
  output:
    save_model_path: null
    # NOTE(review): points at the peft_llm_gnn Video_Games results — build_index.py
    # loads embeddings from here (load_gsgnn_embeddings); the launch script
    # overrides this per-domain via --save-embed-path. Confirm the default is safe.
    save_embed_path: /shared_data/graphstorm/examples/peft_llm_gnn/results/lp/Video_Games
  hyperparam:
    dropout: 0.
    lr: 0.001
    num_epochs: 1
    batch_size: 512
    eval_batch_size: 512
    wd_l2norm: 0.00001
    no_validation: false
  rgcn:
    num_bases: -1
    use_self_loop: true
    lp_decoder_type: dot_product
    sparse_optimizer_lr: 1e-2
    use_node_embeddings: false
  link_prediction:
    num_negative_edges: 1
    num_negative_edges_eval: 100
    contrastive_loss_temperature: 0.1
    lp_loss_func: contrastive
    # l2-normalized embeddings: dot-product scores become cosine similarity,
    # matching the IndexFlatIP (inner-product) index used in build_index.py.
    lp_embed_normalizer: l2_norm
    train_negative_sampler: inbatch_joint
    target_ntype: item
    eval_etype:
    - "item,also_buy,item"
    train_etype:
    - "item,also_buy,item"
    exclude_training_targets: true
    reverse_edge_types_map: ["item,also_buy,also_buy-rev,item"]
18 changes: 18 additions & 0 deletions examples/knn_retriever/run_knn.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Launch build_index.py (KNN retriever index build + recall eval) via the
# GraphStorm distributed launcher for one Amazon-review domain.
#
# Usage: run_knn.sh <domain>   e.g. run_knn.sh Video_Games
#
# Fail fast on errors, unset variables, and pipeline failures.
set -euo pipefail

WORKSPACE=/shared_data/graphstorm/examples/knn_retriever/
DATASPACE=/shared_data/graphstorm/examples/peft_llm_gnn/
dataset=amazon_review
# Require the domain argument: an empty $1 would otherwise yield broken
# paths such as datasets/amazon_review_/.
domain=${1:?usage: $0 <domain> (e.g. Video_Games)}

python -m graphstorm.run.launch \
    --workspace "$WORKSPACE" \
    --part-config "$DATASPACE"/datasets/"$dataset"_"$domain"/"$dataset".json \
    --ip-config "$DATASPACE"/ip_list.txt \
    --num-trainers 1 \
    --num-servers 1 \
    --num-samplers 0 \
    --ssh-port 22 \
    --do-nid-remap False \
    build_index.py \
    --cf "$WORKSPACE"/embedding_config.yaml \
    --save-model-path "$DATASPACE"/model/lp/"$domain"/ \
    --save-embed-path "$DATASPACE"/results/lp/"$domain"/

0 comments on commit 597b9a8

Please sign in to comment.