Resolve conflicts in main branch for unit tests
chang-l committed Feb 8, 2024
1 parent 8bad6bd commit 158e8c9
Showing 2 changed files with 8 additions and 55 deletions.
52 changes: 5 additions & 47 deletions tests/unit-tests/test_wg_sparse_emb.py
@@ -37,15 +37,12 @@
partition_graph,
)

from graphstorm.wholegraph import init_wholegraph, is_wholegraph_init
from graphstorm.wholegraph import init_wholegraph, is_wholegraph_init, create_wholememory_optimizer

from graphstorm.model import GSNodeEncoderInputLayer
from graphstorm.model.embed import compute_node_input_embeddings
from graphstorm.model.utils import save_sparse_embeds
from graphstorm.model.utils import load_sparse_embeds
from graphstorm.model.utils import _get_sparse_emb_range
from graphstorm.model.utils import pad_file_index
from graphstorm import get_feat_size
from graphstorm import get_node_feat_size

from data_utils import generate_dummy_dist_graph, generate_dummy_hetero_graph

@@ -155,46 +152,7 @@ def _start_trainer(
dist_graph, feat_size, 32, use_wholegraph_sparse_emb=True
)
for ntype in embed_layer.sparse_embeds.keys():
embed_layer.sparse_embeds[ntype].attach_wg_optimizer(None)


def get_wholegraph_sparse_emb(sparse_emb):
(local_tensor, _) = sparse_emb.get_local_tensor()
return local_tensor

saved_embs = \
{ntype: get_wholegraph_sparse_emb(sparse_emb) \
for ntype, sparse_emb in embed_layer.sparse_embeds.items()}
save_sparse_embeds(model_path, embed_layer)
load_sparse_embeds(model_path, embed_layer)
load_sparse_embs = \
{ntype: get_wholegraph_sparse_emb(sparse_emb) \
for ntype, sparse_emb in embed_layer.sparse_embeds.items()}

for ntype in embed_layer.sparse_embeds.keys():
assert_equal(saved_embs[ntype].numpy(), load_sparse_embs[ntype].numpy())
dgl.distributed.exit_client()
_finalize()

def _start_trainer(
rank,
world_size,
ip_config,
part_config,
num_server,
model_path,
):
os.environ["DGL_GROUP_ID"] = str(0)
dgl.distributed.initialize(ip_config)
dist_graph = DistGraph("test_wholegraph_sparseemb", part_config=part_config)
print('here world size is: ', world_size)
_initialize(rank, world_size, use_wholegraph=True)
feat_size = {"n0":0, "n1":0}
embed_layer = GSNodeEncoderInputLayer(
dist_graph, feat_size, 32, use_wholegraph_sparse_emb=True
)
for ntype in embed_layer.sparse_embeds.keys():
embed_layer.sparse_embeds[ntype].attach_wg_optimizer(None)
embed_layer.sparse_embeds[ntype].attach_wg_optimizer(create_wholememory_optimizer("adam", {}))


def get_wholegraph_sparse_emb(sparse_emb):
@@ -327,7 +285,7 @@ def test_wg_input_layer3(dev):
# get the test dummy distributed graph
g, _ = generate_dummy_dist_graph(tmpdirname)

feat_size = get_feat_size(g, {'n0' : ['feat']})
feat_size = get_node_feat_size(g, {'n0' : ['feat']})
layer = GSNodeEncoderInputLayer(g, feat_size, 2, use_wholegraph_sparse_emb=True)
assert len(layer.input_projs) == 1
assert list(layer.input_projs.keys())[0] == 'n0'
@@ -393,7 +351,7 @@ def test_wg_input_layer2():
# get the test dummy distributed graph
g, _ = generate_dummy_dist_graph(tmpdirname)

feat_size = get_feat_size(g, 'feat')
feat_size = get_node_feat_size(g, 'feat')
layer = GSNodeEncoderInputLayer(
g, feat_size, 2, use_node_embeddings=True, use_wholegraph_sparse_emb=True
)
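The updated test wires a real WholeMemory optimizer into the WholeGraph-backed sparse embeddings instead of attaching None: each sparse embedding table receives create_wholememory_optimizer("adam", {}) via attach_wg_optimizer before embeddings are computed, saved, and reloaded. A minimal sketch of that setup, assuming a distributed graph and the same 32-dimensional layer used in the test; all API names come from this diff, and the helper function name is illustrative:

from graphstorm.wholegraph import create_wholememory_optimizer
from graphstorm.model import GSNodeEncoderInputLayer

def build_wg_sparse_embed_layer(dist_graph, feat_size):
    # Input layer backed by WholeGraph sparse embeddings, as in the test.
    embed_layer = GSNodeEncoderInputLayer(
        dist_graph, feat_size, 32, use_wholegraph_sparse_emb=True
    )
    # Attach a WholeMemory "adam" optimizer to each sparse embedding table
    # (the test creates one optimizer per node type inside the loop).
    for ntype in embed_layer.sparse_embeds.keys():
        embed_layer.sparse_embeds[ntype].attach_wg_optimizer(
            create_wholememory_optimizer("adam", {})
        )
    return embed_layer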
11 changes: 3 additions & 8 deletions tests/unit-tests/test_wg_sparse_opt.py
@@ -40,14 +40,9 @@
from graphstorm.wholegraph import init_wholegraph, is_wholegraph_init, create_wholememory_optimizer

from graphstorm.model import GSNodeEncoderInputLayer
from graphstorm.model.embed import compute_node_input_embeddings
from graphstorm.model.utils import save_sparse_embeds
from graphstorm.model.utils import load_sparse_embeds
from graphstorm.model.utils import _get_sparse_emb_range
from graphstorm.model.utils import pad_file_index
from graphstorm import get_feat_size
from graphstorm import get_node_feat_size

from data_utils import generate_dummy_dist_graph, generate_dummy_hetero_graph
from data_utils import generate_dummy_hetero_graph


def generate_ip_config(file_name, num_machines, num_servers):
@@ -154,7 +149,7 @@ def _start_trainer(
lr = 0.01

dev = th.device('cuda:{}'.format(rank))
feat_size = get_feat_size(dist_graph, {'n0': ['feat']})
feat_size = get_node_feat_size(dist_graph, {'n0': ['feat']})
layer_wg = GSNodeEncoderInputLayer(
dist_graph, feat_size, embed_dim, use_wholegraph_sparse_emb=True
).to(dev)
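Both test files also move from the old get_feat_size helper to get_node_feat_size; the call sites keep the same arguments. A one-line illustration taken from the diff, using the dummy graph and feature name these tests rely on:

from graphstorm import get_node_feat_size

# Per-node-type feature sizes for node type 'n0' with feature 'feat'.
feat_size = get_node_feat_size(dist_graph, {'n0': ['feat']})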
