Commit

fix tests.
Ubuntu committed Oct 10, 2023
1 parent d36485a commit 17a38aa
Showing 2 changed files with 35 additions and 5 deletions.
tests/unit-tests/data_utils.py (26 additions, 0 deletions)

@@ -17,6 +17,7 @@
 """

 import os
+import json
 import dgl
 import numpy as np
 import torch as th
@@ -421,6 +422,31 @@ def generate_dummy_dist_graph_multi_target_ntypes(dirname, size='tiny', graph_na
     return partion_and_load_distributed_graph(hetero_graph=hetero_graph, dirname=dirname,
                                               graph_name=graph_name)

+def load_lm_graph(part_config):
+    with open(part_config) as f:
+        part_metadata = json.load(f)
+    g = dgl.distributed.DistGraph(graph_name=part_metadata["graph_name"],
+                                  part_config=part_config)
+
+    bert_model_name = "bert-base-uncased"
+    max_seq_length = 8
+    lm_config = [{"lm_type": "bert",
+                  "model_name": bert_model_name,
+                  "gradient_checkpoint": True,
+                  "node_types": ["n0"]}]
+    feat_size = get_feat_size(g, {'n0' : ['feat']})
+    input_text = ["Hello world!"]
+    tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
+    input_ids, valid_len, attention_mask, _ = \
+        create_tokens(tokenizer=tokenizer,
+                      input_text=input_text,
+                      max_seq_length=max_seq_length,
+                      num_node=g.number_of_nodes('n0'))
+
+    g.nodes['n0'].data[TOKEN_IDX] = input_ids
+    g.nodes['n0'].data[VALID_LEN] = valid_len
+    return g, lm_config
+
 def create_lm_graph(tmpdirname):
     """ Create a graph with textual features
     Only n0 has a textual feature.
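For context, the tokenization step in load_lm_graph goes through create_tokens, a test helper imported from util.py. A minimal sketch of the Hugging Face call it presumably wraps, assuming only the transformers and torch packages; the tiling of the single sentence out to num_node rows is left out, so treat this as an illustration rather than the helper's actual code.

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    encoded = tokenizer(["Hello world!"],
                        max_length=8,
                        truncation=True,
                        padding="max_length",
                        return_tensors="pt")
    input_ids = encoded["input_ids"]            # token-id tensor, shape (1, 8)
    attention_mask = encoded["attention_mask"]  # 1 for real tokens, 0 for padding
    valid_len = attention_mask.sum(dim=1)       # non-padding tokens per row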
tests/unit-tests/test_embed.py (9 additions, 5 deletions)

@@ -14,6 +14,7 @@
 limitations under the License.
 """

+import multiprocessing as mp
 import pytest
 import torch as th
 from torch import nn
@@ -25,6 +26,7 @@

 import dgl
 from transformers import AutoTokenizer
+import graphstorm as gs
 from graphstorm import get_feat_size
 from graphstorm.model import GSNodeEncoderInputLayer, GSLMNodeEncoderInputLayer, GSPureLMNodeInputLayer
 from graphstorm.model.embed import compute_node_input_embeddings
@@ -33,7 +35,7 @@
 from graphstorm.model.lm_embed import LMModels, LMCache

 from data_utils import generate_dummy_dist_graph
-from data_utils import create_lm_graph, create_lm_graph2
+from data_utils import create_lm_graph, create_lm_graph2, load_lm_graph
 from util import create_tokens

 # In this case, we only use the node features to generate node embeddings.
@@ -238,7 +240,9 @@ def test_lm_cache():
     th.distributed.destroy_process_group()
     dgl.distributed.kvstore.close_kvstore()

-def run_dist_cache(g, lm_config, tmpdirname):
+def run_dist_cache(part_config, tmpdirname):
+    gs.initialize(ip_config=None, backend="gloo")
+    g, lm_config = load_lm_graph(part_config)
     lm_models = LMModels(g, lm_config, 0, 10)
     lm_cache = LMCache(g, lm_models, tmpdirname)
     lm_cache.update_cache(100)
@@ -262,12 +266,12 @@ def test_mp_lm_cache():
                                      rank=0,
                                      world_size=1)
     with tempfile.TemporaryDirectory() as tmpdirname:
-        lm_config, feat_size, input_ids, attention_mask, g, _ = \
+        lm_config, feat_size, input_ids, attention_mask, _, part_config = \
             create_lm_graph(tmpdirname)

         ctx = mp.get_context('spawn')
-        p0 = ctx.Process(target=run_dist_cache, args=(g, lm_config, tmpdirname))
-        p1 = ctx.Process(target=run_dist_cache, args=(g, lm_config, tmpdirname))
+        p0 = ctx.Process(target=run_dist_cache, args=(part_config, tmpdirname))
+        p1 = ctx.Process(target=run_dist_cache, args=(part_config, tmpdirname))

         p0.start()
         p1.start()
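The point of the signature change appears to be picklability: a dgl.distributed.DistGraph wraps live connections to the graph server and key-value store, so it generally cannot be pickled into a child created with the 'spawn' start method, while the partition-config path is a plain string each worker can use to rebuild the graph via load_lm_graph. Below is a minimal, self-contained sketch of the same hand-a-path-not-a-handle pattern, with a hypothetical worker standing in for run_dist_cache; it runs without DGL or GraphStorm installed.

    import multiprocessing as mp

    def worker(part_config, cache_dir):
        # Hypothetical stand-in for run_dist_cache(): the child receives only
        # picklable path strings and rebuilds heavyweight state from disk itself.
        print(f"{mp.current_process().name}: would load graph from {part_config}")

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        procs = [ctx.Process(target=worker, args=("part_config.json", "/tmp/lm_cache"))
                 for _ in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()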
