
Commit 041bc7c: reformat
alexbarghi-nv committed Jan 30, 2024
1 parent ad0fb96 · commit 041bc7c
Showing 6 changed files with 52 additions and 25 deletions.
File 1 of 6
@@ -207,7 +207,7 @@ def main(args):
    from pylibwholegraph.torch.initialize import (
        get_global_communicator,
        get_local_node_communicator,
-       init
+       init,
    )

    logger.info("initializing WG comms...")
File 2 of 6
@@ -270,7 +270,7 @@ def __load_x_torch(self) -> None:

    def __load_x_wg(self) -> None:
        import logging
-       logger = logging.getLogger('OGBNPapers100MDataset')
+
+       logger = logging.getLogger("OGBNPapers100MDataset")
        logger.info("Loading x into WG embedding...")

File 3 of 6
@@ -288,12 +288,22 @@ def get_input_files(self, path, epoch=0, stage="train"):
        np.random.seed(epoch)
        np.random.shuffle(splits)

-       ex = re.compile(r'batch=([0-9]+)\-([0-9]+).parquet')
-       num_batches = min([
-           sum([int(ex.match(fname.split('/')[-1])[2]) - int(ex.match(fname.split('/')[-1])[1]) for fname in s])
-           for s in splits
-       ])
+       ex = re.compile(r"batch=([0-9]+)\-([0-9]+).parquet")
+       num_batches = min(
+           [
+               sum(
+                   [
+                       int(ex.match(fname.split("/")[-1])[2])
+                       - int(ex.match(fname.split("/")[-1])[1])
+                       for fname in s
+                   ]
+               )
+               for s in splits
+           ]
+       )
        if num_batches == 0:
-           raise ValueError(f"Too few batches for training with world size {self.__world_size}")
+           raise ValueError(
+               f"Too few batches for training with world size {self.__world_size}"
+           )

        return splits[self.rank], num_batches
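
For readers skimming the diff, a compact illustration of what the reformatted block computes may help. In the sketch below the regex is copied from the commit, but the filenames, the two-rank split, and the resulting count are invented:

import re

ex = re.compile(r"batch=([0-9]+)\-([0-9]+).parquet")

# Hypothetical sampler output, already divided between two ranks.
splits = [
    ["batch=0-63.parquet", "batch=64-127.parquet"],  # rank 0: 63 + 63
    ["batch=128-190.parquet"],                       # rank 1: 62
]

# Each filename encodes a start and end batch index; the code sums
# end - start per rank, then takes the minimum across all ranks.
num_batches = min(
    sum(int(ex.match(f)[2]) - int(ex.match(f)[1]) for f in files)
    for files in splits
)
print(num_batches)  # 62

Taking the minimum over every rank's total, rather than using this rank's own total, keeps all ranks iterating the same number of batches, which matters once collective operations are involved.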
File 4 of 6
@@ -98,7 +98,16 @@ def log_batch(


def train_epoch(
-   model, optimizer, loader, feature_store, epoch, num_classes, time_d, logger, rank, max_num_batches
+   model,
+   optimizer,
+   loader,
+   feature_store,
+   epoch,
+   num_classes,
+   time_d,
+   logger,
+   rank,
+   max_num_batches,
):
    """
    Train the model for one epoch.
@@ -185,7 +194,7 @@ def train_epoch(
                epoch=epoch,
                rank=rank,
            )
-
+
        if max_num_batches is not None and iter_i >= max_num_batches:
            break

@@ -319,7 +328,7 @@ def train(self):
            loader=loader,
            feature_store=self.data,
            num_classes=self.dataset.num_labels,
-           max_num_batches = max_num_batches,
+           max_num_batches=max_num_batches,
        )
        print(f"Accuracy: {test_acc:.4f}%")
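
The max_num_batches value threaded through train_epoch in this file is the counterpart of that minimum: once every rank knows the shared batch count, each one stops its loop at the same iteration. A minimal sketch of the guard, with placeholder names:

max_num_batches = 62  # the minimum computed across ranks in get_input_files

for iter_i, batch in enumerate(loader):  # `loader` is a placeholder
    train_step(batch)  # hypothetical per-batch work
    # Same guard as in train_epoch above: keeps ranks in lockstep.
    if max_num_batches is not None and iter_i >= max_num_batches:
        break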

File 5 of 6
@@ -258,19 +258,24 @@ def get_input_files(self, path, epoch=0, stage="train"):
        np.random.shuffle(file_list)

        splits = np.array_split(file_list, self.__world_size)

        import logging
-       logger = logging.getLogger('PyGCuGraphTrainer')
+
+       logger = logging.getLogger("PyGCuGraphTrainer")
        logger.info(f"rank {self.rank} input files: {str(splits[self.rank])}")

        split = splits[self.rank]

-       ex = re.compile(r'batch=([0-9]+)\-([0-9]+).parquet')
-       num_batches = min([
-           sum([int(ex.match(fname)[2]) - int(ex.match(fname)[1]) for fname in s])
-           for s in splits
-       ])
+       ex = re.compile(r"batch=([0-9]+)\-([0-9]+).parquet")
+       num_batches = min(
+           [
+               sum([int(ex.match(fname)[2]) - int(ex.match(fname)[1]) for fname in s])
+               for s in splits
+           ]
+       )
        if num_batches == 0:
-           raise ValueError(f"Too few batches for training with world size {self.__world_size}")
+           raise ValueError(
+               f"Too few batches for training with world size {self.__world_size}"
+           )

        return split, num_batches
File 6 of 6
@@ -51,7 +51,10 @@ def pyg_num_workers(world_size: int) -> int:


def calc_accuracy(
-   loader: NeighborLoader, max_num_batches: int, model: torch.nn.Module, num_classes: int
+   loader: NeighborLoader,
+   max_num_batches: int,
+   model: torch.nn.Module,
+   num_classes: int,
) -> float:
    """
    Evaluates the accuracy of a model given a loader over evaluation samples.
@@ -131,16 +134,14 @@ def train(self):
        end_time_backward = start_time

        num_layers = len(self.model.module.convs)

        for epoch in range(self.num_epochs):
            with td.algorithms.join.Join(
                [self.model, self.optimizer], divide_by_initial_world_size=False
            ):
                self.model.train()
                loader, max_num_batches = self.get_loader(epoch=epoch, stage="train")
-               for iter_i, data in enumerate(
-                   loader
-               ):
+               for iter_i, data in enumerate(loader):
                    loader_time_iter = time.perf_counter() - end_time_backward
                    time_loader += loader_time_iter
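
The Join context manager in the loop above is PyTorch's standard mechanism for uneven inputs across ranks: a rank that exhausts its loader early "joins" and shadows the remaining collectives instead of hanging. A self-contained sketch of the same pattern, assuming an initialized process group (the td alias in the diff is presumably torch.distributed, so the direct import below is equivalent):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.join import Join
from torch.distributed.optim import ZeroRedundancyOptimizer

def train_one_epoch(model: DDP, optimizer: ZeroRedundancyOptimizer, loader):
    # Everything passed to Join must implement Joinable; DDP and
    # ZeroRedundancyOptimizer both do (a plain optimizer does not).
    model.train()
    with Join([model, optimizer], divide_by_initial_world_size=False):
        for batch in loader:
            optimizer.zero_grad()
            loss = model(batch).sum()  # placeholder loss
            loss.backward()
            optimizer.step()

Passing divide_by_initial_world_size=False, as the commit does, tells DDP's join hook to average gradients over the number of still-active ranks rather than the initial world size, which is the usual choice when ranks see different batch counts.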

@@ -247,7 +248,9 @@ def train(self):
            loader, max_num_batches = self.get_loader(epoch=epoch, stage="test")
            num_classes = self.dataset.num_labels

-           acc = calc_accuracy(loader, max_num_batches, self.model.module, num_classes)
+           acc = calc_accuracy(
+               loader, max_num_batches, self.model.module, num_classes
+           )

            if self.rank == 0:
                print(