Changes to support TPUs #1
base: master
Changes from 1 commit
@@ -396,7 +396,11 @@ def to_device(
     if isinstance(device, str):
         device = torch.device(device)

-    if not torch.cuda.is_available():
+    # default value of device_type is cuda
+    # since other device types such as xla can be passed,
+    # falling back to cpu should only happen when device_type
+    # is set to cuda but cuda is not available.
+    if not torch.cuda.is_available() and device == "cuda":
         device = torch.device("cpu")
     # to_device is specifically for SampleList
     # if user is passing something custom built

Review comment: ordering as …
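For reference, a minimal sketch of the intended fallback behaviour (not the PR's exact code; it checks the requested device type before converting the string, which also sidesteps comparing a torch.device against the raw string "cuda"):

import torch

def resolve_device(device):
    # device may be "cuda", "cpu", "xla", or an already-built torch.device
    if isinstance(device, str):
        if device == "cuda" and not torch.cuda.is_available():
            # only an unavailable CUDA request falls back to CPU;
            # "xla" and other device strings are passed through untouched
            device = "cpu"
        if device != "xla":
            # assumption: the xla string is resolved elsewhere (e.g. via xm.xla_device())
            device = torch.device(device)
    return device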
@@ -8,7 +8,7 @@
 import numpy as np
 from mmf.common.registry import registry
 from mmf.utils.build import build_dataloader_and_sampler, build_dataset
-from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master
+from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master, is_xla
 from mmf.utils.general import get_batch_size

@@ -186,9 +186,15 @@ def _infer_dataset_probabilities(self):
     def __len__(self):
         # Since, this is iterator, we need to return total length == number of batches
         batch_size = get_batch_size()
+        # Changed the length to accomadate drop_last == True
+        # drop_last is required if the batch is split intor multiple cores

Review comment: s/intor/into/

+        # some of the cores may not have enough examples.
+        if is_xla():

Review comment: can you use the bool …

+            return (self._total_length) // batch_size
+        else:
+            # This assumes drop_last=False for all loaders. See also
+            # build_dataloader_and_sampler().
+            return (self._total_length + batch_size - 1) // batch_size
-        return (self._total_length + batch_size - 1) // batch_size

     def __iter__(self):
         if self._num_datasets == 1:
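To make the difference concrete, a quick worked example of the two length computations (the numbers are illustrative, not from the PR):

total_length, batch_size = 1000, 64  # example values
len_drop_last = total_length // batch_size                     # 15 batches: the incomplete batch is dropped (XLA path)
len_keep_last = (total_length + batch_size - 1) // batch_size  # 16 batches: ceil division keeps the partial batch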
@@ -20,19 +20,27 @@ def configure_seed(self) -> None:
         torch.backends.cudnn.benchmark = False

     def configure_device(self) -> None:
-        self.local_rank = self.config.device_id
-        self.device = self.local_rank
-        self.distributed = False
+        if getattr(self.config.training, 'device', 'cuda') == 'xla':
+            import torch_xla.core.xla_model as xm
+            self.device = xm.xla_device()
+            self.distributed = True
+            self.local_rank = xm.get_local_ordinal()
+            self.tpu = True

Review comment: I think using self.xla to denote xla usage is better than using self.tpu.

+        else:
+            self.tpu = False
+            self.local_rank = self.config.device_id
+            self.device = self.local_rank
+            self.distributed = False

         # Will be updated later based on distributed setup
         registry.register("global_device", self.device)

         if self.config.distributed.init_method is not None:
             self.distributed = True
             self.device = torch.device("cuda", self.local_rank)
-        elif torch.cuda.is_available():
+        elif torch.cuda.is_available() and not self.tpu:
             self.device = torch.device("cuda")
-        else:
+        elif not self.tpu:
             self.device = torch.device("cpu")

Review comment: this probably is not your code, but the logic flow here seems wonky to me.

         registry.register("current_device", self.device)
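One way the flow could be untangled (a sketch only, not part of the PR; it collapses the device decision into a single helper and assumes the same config fields: training.device, device_id, distributed.init_method):

import torch

def resolve_device(training_device, device_id, init_method):
    """Return (device, local_rank, distributed) for the requested setup."""
    if training_device == "xla":
        import torch_xla.core.xla_model as xm
        return xm.xla_device(), xm.get_local_ordinal(), True
    if init_method is not None:
        # explicit distributed init always pins a CUDA device per rank
        return torch.device("cuda", device_id), device_id, True
    if torch.cuda.is_available():
        return torch.device("cuda"), device_id, False
    return torch.device("cpu"), device_id, False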
@@ -10,6 +10,7 @@
 from mmf.common.report import Report
 from mmf.common.sample import to_device
 from mmf.utils.distributed import is_master
+from mmf.utils.metsumm import metsumm

Review comment: this probably should go away, it's a debugging artifact. should not live in mmf codebase.

 logger = logging.getLogger(__name__)

@@ -25,6 +26,7 @@ def evaluation_loop(
         self.model.eval()
         disable_tqdm = not use_tqdm or not is_master()
         combined_report = None
+        metsumm("Before Validation Start:")

Review comment: this should go away, it's a debugging artifact. should not live in mmf codebase.

         for batch in tqdm.tqdm(loader, disable=disable_tqdm):
             report = self._forward(batch)

@@ -44,6 +46,8 @@ def evaluation_loop(
         combined_report.metrics = self.metrics(combined_report, combined_report)
         self.update_meter(combined_report, meter, eval_mode=True)
+        logger.info("Validation Done")
+        metsumm("After Validation Complete")

Review comment: this should go away, it's a debugging artifact. should not live in mmf codebase.

         # enable train mode again
         self.model.train()
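The metsumm helper is not shown in this diff; it presumably wraps torch_xla's metrics report for debugging device syncs and recompilations. A minimal sketch of what such a helper might look like (an assumption, not necessarily the PR's actual mmf/utils/metsumm):

import torch_xla.debug.metrics as met

def metsumm(tag: str = "") -> None:
    # Print a marker plus the XLA metrics report; useful while debugging
    # TPU runs to spot unexpected host-device transfers or recompilations.
    print(f"===== {tag} =====")
    print(met.metrics_report())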
@@ -51,8 +51,8 @@ def update_dict(self, meter_update_dict, values_dict):
         if val.dim() == 1:
             val = val.mean()

-        if hasattr(val, "item"):
-            val = val.item()
+        #if hasattr(val, "item"):
+        #    val = val.item()

Review comment: let's utilize the util function … (see the fuller suggestion at the end of this review)

         meter_update_dict.update({key: val})
         total_val += val
@@ -12,7 +12,7 @@
 from mmf.common.sample import to_device
 from mmf.utils.general import clip_gradients
 from torch import Tensor
+from mmf.utils.metsumm import metsumm

Review comment: this should go away, it's a debugging artifact. should not live in mmf codebase.

 logger = logging.getLogger(__name__)

@@ -72,8 +72,8 @@ def run_training_epoch(self) -> None:
         combined_report = None
         num_batches_for_this_update = 1
-        for idx, batch in enumerate(self.train_loader):
+        for idx, batch in enumerate(self.train_loader):

Review comment: lol what's the change here? new line? you could get rid of that if so.

             if (idx + 1) % self.training_config.update_frequency == 0:
                 combined_report = None
                 num_batches_for_this_update = min(

@@ -84,7 +84,6 @@ def run_training_epoch(self) -> None:
             # batch execution starts here
             self.on_batch_start()
-            self.profile("Batch load time")

Review comment: why delete?

             report = self.run_training_batch(batch, num_batches_for_this_update)

@@ -129,7 +128,6 @@ def run_training_epoch(self) -> None:
                 # Validation begin callbacks
                 self.on_validation_start()

-                logger.info("Evaluation time. Running on full validation set...")

Review comment: why delete?

                 # Validation and Early stopping
                 # Create a new meter for this case
                 report, meter = self.evaluation_loop(self.val_loader)

@@ -146,7 +144,6 @@ def run_training_epoch(self) -> None:
                     torch.cuda.empty_cache()

                 if stop is True:
-                    logger.info("Early stopping activated")

Review comment: why delete?

                     should_break = True
             if self.num_updates >= self.max_updates:
                 should_break = True

@@ -168,7 +165,7 @@ def run_training_batch(self, batch: Tensor, loss_divisor: int) -> None:
     def _forward(self, batch: Tensor) -> Dict[str, Any]:
         prepared_batch = self.dataset_loader.prepare_batch(batch)
         # Move the sample list to device if it isn't as of now.
-        prepared_batch = to_device(prepared_batch, torch.device("cuda"))
+        prepared_batch = to_device(prepared_batch, self.device)

Review comment: did u test the non-xla codepaths w/ this change? does it still work?

         self.profile("Batch prepare time")
         # Arguments should be a dict at this point

@@ -188,6 +185,7 @@ def _start_update(self):
     def _backward(self, loss: Tensor) -> None:
         self.scaler.scale(loss).backward()
+

Review comment: remove blank line.

         self.profile("Backward time")

     def _finish_update(self):

@@ -199,6 +197,12 @@ def _finish_update(self):
                 self.config,
                 scale=self.scaler.get_scale(),
             )
+        if getattr(self.config.training, 'device', 'cuda') == 'xla' and self.config.distributed.world_size > 1:

Review comment: Can you use the …

+            import torch_xla.core.xla_model as xm
+            #gradients = xm._fetch_gradients(self.optimizer)
+            # Assumes no model parallel
+            #xm.all_reduce('sum', gradients, scale=1.0 / self.config.distributed.world_size)
+            xm.reduce_gradients(self.optimizer)

Review comment: let's clean this block and remove comments etc.

         self.scaler.step(self.optimizer)
         self.scaler.update()
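Per that review comment, a cleaned-up version of the block might look like this (a sketch only; it assumes an is_xla() helper is available and relies on xm.reduce_gradients performing the cross-replica all-reduce of the optimizer's gradients, which is torch_xla's documented behaviour):

if is_xla() and self.config.distributed.world_size > 1:
    import torch_xla.core.xla_model as xm
    # All-reduce (average) gradients across TPU cores before stepping the optimizer.
    xm.reduce_gradients(self.optimizer)
self.scaler.step(self.optimizer)
self.scaler.update()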
@@ -10,10 +10,16 @@
 from mmf.common.registry import registry
 from mmf.datasets.processors.processors import Processor
 from mmf.utils.configuration import Configuration
-from mmf.utils.distributed import is_dist_initialized
+from mmf.utils.distributed import is_dist_initialized, is_xla
 from mmf.utils.general import get_optimizer_parameters
 from omegaconf import DictConfig, OmegaConf

+try:
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.parallel_loader as pl
+except ImportError:
+    xm = None
+    pl = None

Review comment: no need to set …

 ProcessorType = Type[Processor]
 ProcessorDict = Dict[str, ProcessorType]

@@ -152,6 +158,18 @@ def build_dataloader_and_sampler(
     if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
         other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)

+    if is_xla():
+        dataset_type = dataset_instance.dataset_type
+        shuffle = True
+        other_args["sampler"] = torch.utils.data.DistributedSampler(
+                dataset_instance,
+                num_replicas=xm.xrt_world_size(),
+                rank=xm.get_ordinal(),
+                shuffle=shuffle
+        )

Review comment (on lines +170 to +174): de-indent for pep8 compliance.

+        other_args.pop("shuffle")
+
+

Review comment: too many blank lines.

     loader = torch.utils.data.DataLoader(
         dataset=dataset_instance,
         pin_memory=pin_memory,

@@ -163,6 +181,10 @@ def build_dataloader_and_sampler(
         **other_args,
     )

+    if is_xla():
+        device = xm.xla_device()
+        loader = pl.MpDeviceLoader(loader, device)

     if num_workers >= 0:
         # Suppress leaking semaphore warning
         os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"
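For context, the standard torch_xla input-pipeline pattern this change follows, as a self-contained sketch (the function name and arguments are illustrative, not MMF's; xm.xrt_world_size, xm.get_ordinal, xm.xla_device and pl.MpDeviceLoader are the usual torch_xla calls):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def build_xla_loader(dataset, batch_size):
    # Each TPU core gets its own shard of the dataset via DistributedSampler.
    sampler = torch.utils.data.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, sampler=sampler, drop_last=True
    )
    # MpDeviceLoader prefetches batches and moves them onto the XLA device.
    return pl.MpDeviceLoader(loader, xm.xla_device())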
@@ -10,7 +10,7 @@
 import torch
 from mmf.common.registry import registry
 from mmf.utils.configuration import get_mmf_env, load_yaml
-from mmf.utils.distributed import is_master, synchronize
+from mmf.utils.distributed import is_master, synchronize, is_xla
 from mmf.utils.download import download_pretrained_model
 from mmf.utils.file_io import PathManager
 from mmf.utils.general import updir

@@ -22,6 +22,11 @@
 except ImportError:
     git = None

+try:
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    xm = None

 logger = logging.getLogger(__name__)

@@ -379,15 +384,26 @@ def _get_vcs_fields(self):
             "git/diff": self.git_repo.git.diff("--no-prefix"),
         }

+
+

Review comment: one blank line too many.

+    def save_func(self):
+        if is_xla():
+            return xm.save
+        else:
+            return torch.save

     def save(self, update, iteration=None, update_best=False):
         # Only save in main process
-        if not is_master():
+        if not is_master() and not is_xla():
             return

         logger.info("Checkpoint save operation started!")

         if not iteration:
             iteration = update

         ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
+

Review comment: remove blank line.

         best_ckpt_filepath = os.path.join(
             self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
         )

@@ -437,23 +453,29 @@ def save(self, update, iteration=None, update_best=False):
             git_metadata_dict = self._get_vcs_fields()
             ckpt.update(git_metadata_dict)

+        logger.info("Saving checkpoint")

Review comment: you just logged this in line 400. remove one of them.

         with PathManager.open(ckpt_filepath, "wb") as f:
-            torch.save(ckpt, f)
+            self.save_func()(ckpt, f)

Review comment: this looks ugly. I'd rewrite it: … then you can call it like this: …
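The truncated suggestion above presumably points at something like the following (a sketch under that assumption, not the reviewer's exact code):

def save_func(self, *args):
    # Dispatch to xm.save on XLA (it moves tensors to CPU and saves from the
    # appropriate process) and to torch.save everywhere else.
    return xm.save(*args) if is_xla() else torch.save(*args)

# call sites then simply read:
#     self.save_func(ckpt, f)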
         if update_best:
             logger.info("Saving best checkpoint")
             with PathManager.open(best_ckpt_filepath, "wb") as f:
-                torch.save(ckpt, f)
+                self.save_func()(ckpt, f)

Review comment: same comment.

         # Save current always

         logger.info("Saving Current checkpoint")
         with PathManager.open(current_ckpt_filepath, "wb") as f:
-            torch.save(ckpt, f)
+            self.save_func()(ckpt, f)

Review comment: same.

         # Remove old checkpoints if max_to_keep is set
         if self.max_to_keep > 0:
             if len(self.saved_iterations) == self.max_to_keep:
                 self.remove(self.saved_iterations.pop(0))
             self.saved_iterations.append(update)

+        logger.info("Checkpoint save operation finished!")

Review comment: logger getting too excited :)

     def remove(self, update):
         ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
         if PathManager.isfile(ckpt_filepath):

@@ -468,6 +490,6 @@ def restore(self):
             self._load(best_path, force=True)

     def finalize(self):
-        if is_master():
+        if is_master() or is_xla():
             with PathManager.open(self.pth_filepath, "wb") as f:
-                torch.save(self.trainer.model.state_dict(), f)
+                self.save_func()(self.trainer.model.state_dict(), f)

Review comment: same.
@@ -531,6 +531,9 @@ def _update_specific(self, config):
             lr = config.learning_rate
             config.optimizer.params.lr = lr

+        # TODO: Correct the following issue

Review comment: clean …

+        # This check is triggered before the config override from commandline is effective
+        # even after setting training.device = 'xla', it gets triggered.
         if not torch.cuda.is_available() and "cuda" in config.training.device:
             warnings.warn(
                 "Device specified is 'cuda' but cuda is not present. "
Review comment (on the commented-out val.item() change above): instead of commenting out here, let's have a util function like … and use

    v = self.item(v)

and then assert on

    assert isinstance(v, (float, int)) or v.device.type == 'xla'
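A sketch of what such a helper could look like (the name and placement are assumptions, since the reviewer's suggestion is truncated, and torch is assumed to be imported; the point is to avoid calling .item() on XLA tensors, which forces a device sync per metric):

def item(self, val):
    # Pull the scalar out immediately on CPU/CUDA tensors; on XLA keep the
    # tensor on the device so it can be read back later in a single sync.
    if torch.is_tensor(val) and val.device.type != "xla":
        val = val.item()
    assert isinstance(val, (float, int)) or val.device.type == "xla"
    return val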