Changes to support TPUs #1

Open · wants to merge 2 commits into base: master
Changes from 1 commit
4 changes: 2 additions & 2 deletions mmf/common/meter.py
@@ -60,8 +60,8 @@ def update(self, update_dict, batch_size):
             if isinstance(v, torch.Tensor):
                 if v.dim() != 0:
                     v = v.mean()
-                v = v.item()
-            assert isinstance(v, (float, int))
+                #v = v.item()

Review comment: instead of commenting out here, let's have a util function like

    def item(self, v):
        if torch.is_tensor(v) and v.device.type == 'xla':
            return v
        return v.item()

and use v = self.item(v), and then assert on isinstance(v, (float, int)) or v.device.type == 'xla'.

+            #assert isinstance(v, (float, int))
             self.meters[k].update(v, batch_size)

     def update_from_meter(self, meter):
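For reference, a minimal sketch of how the suggested item() helper could replace the commented-out conversions in Meter.update; the helper name, the surrounding loop, and the Meter scaffolding are assumptions reconstructed from this hunk, not code from the PR:

    import torch

    class Meter:
        # Sketch only: the real Meter tracks SmoothedValue-style sub-meters.
        def __init__(self):
            self.meters = {}

        def item(self, v):
            # Keep XLA tensors lazy: calling .item() would force a device sync.
            if torch.is_tensor(v) and v.device.type == "xla":
                return v
            return v.item() if torch.is_tensor(v) else v

        def update(self, update_dict, batch_size):
            for k, v in update_dict.items():
                if isinstance(v, torch.Tensor):
                    if v.dim() != 0:
                        v = v.mean()
                    v = self.item(v)
                # On XLA the value may still be a tensor at this point.
                assert isinstance(v, (float, int)) or (
                    torch.is_tensor(v) and v.device.type == "xla"
                )
                self.meters[k].update(v, batch_size)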
6 changes: 5 additions & 1 deletion mmf/common/sample.py
@@ -396,7 +396,11 @@ def to_device(
     if isinstance(device, str):
         device = torch.device(device)

-    if not torch.cuda.is_available():
+    # default value of device_type is cuda
+    # since other device types such as xla can be passed,
+    # falling back to cpu should only happen when device_type
+    # is set to cuda but cuda is not available.
+    if not torch.cuda.is_available() and device == "cuda":

Review comment: order it as device == "cuda" and not torch.cuda.is_available(); that saves the cuda-availability check when the device is not cuda.

         device = torch.device("cpu")
     # to_device is specifically for SampleList
     # if user is passing something custom built
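A small sketch of the reordering the reviewer asks for: checking the device type first short-circuits the cuda-availability probe for xla and cpu devices. The body is trimmed to the relevant guard, so everything else here is assumed:

    import torch

    def to_device(sample_list, device="cuda"):
        # Sketch of the fallback guard only; the real to_device also moves
        # every field of the SampleList.
        if isinstance(device, str):
            device = torch.device(device)
        # Short-circuit: only probe CUDA availability when cuda was requested.
        if device.type == "cuda" and not torch.cuda.is_available():
            device = torch.device("cpu")
        return sample_list.to(device) if hasattr(sample_list, "to") else sample_list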
10 changes: 8 additions & 2 deletions mmf/datasets/multi_dataset_loader.py
@@ -8,7 +8,7 @@
 import numpy as np
 from mmf.common.registry import registry
 from mmf.utils.build import build_dataloader_and_sampler, build_dataset
-from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master
+from mmf.utils.distributed import broadcast_scalar, is_dist_initialized, is_master, is_xla
 from mmf.utils.general import get_batch_size


@@ -186,9 +186,15 @@ def _infer_dataset_probabilities(self):
     def __len__(self):
         # Since, this is iterator, we need to return total length == number of batches
         batch_size = get_batch_size()
+        # Changed the length to accommodate drop_last == True
+        # drop_last is required if the batch is split intor multiple cores

Review comment: s/intor/into/

+        # some of the cores may not have enough examples.
+        if is_xla():

Review comment: can you use the bool drop_last here instead of is_xla?

+            return (self._total_length) // batch_size
+        else:
             # This assumes drop_last=False for all loaders. See also
             # build_dataloader_and_sampler().
-        return (self._total_length + batch_size - 1) // batch_size
+            return (self._total_length + batch_size - 1) // batch_size

     def __iter__(self):
         if self._num_datasets == 1:
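Following the drop_last suggestion, the length computation boils down to a small helper like the one below (num_batches and the drop_last flag are illustrative names, not part of the PR):

    def num_batches(total_length: int, batch_size: int, drop_last: bool) -> int:
        # Mirrors the __len__ logic above, keyed on drop_last instead of is_xla().
        if drop_last:
            # The trailing partial batch is dropped, which is what TPU training
            # needs when a batch is split across cores.
            return total_length // batch_size
        # Keep the trailing partial batch.
        return (total_length + batch_size - 1) // batch_size

    assert num_batches(10, 4, drop_last=True) == 2
    assert num_batches(10, 4, drop_last=False) == 3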
4 changes: 2 additions & 2 deletions mmf/trainers/callbacks/logistics.py
@@ -5,7 +5,7 @@
 import torch
 from mmf.trainers.callbacks.base import Callback
 from mmf.utils.configuration import get_mmf_env
-from mmf.utils.distributed import is_master
+from mmf.utils.distributed import is_master, is_xla
 from mmf.utils.logger import TensorboardLogger, log_progress, setup_output_folder
 from mmf.utils.timer import Timer

@@ -105,7 +105,7 @@ def on_test_end(self, **kwargs):
     def _summarize_report(self, meter, should_print=True, extra=None):
         if extra is None:
             extra = {}
-        if not is_master():
+        if not is_master() and not is_xla():
             return

         if self.training_config.tensorboard:
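The is_xla() helper imported in several of these files comes from mmf/utils/distributed.py, which is not part of this diff. A plausible shape for it, stated purely as an assumption:

    def is_xla() -> bool:
        # Assumed implementation: true only when torch_xla is importable and the
        # training config asks for the xla device.
        try:
            import torch_xla.core.xla_model as xm  # noqa: F401
        except ImportError:
            return False
        from mmf.common.registry import registry

        config = registry.get("config", no_warning=True)
        if config is None:
            return False
        return getattr(config.training, "device", "cuda") == "xla"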
18 changes: 13 additions & 5 deletions mmf/trainers/core/device.py
@@ -20,19 +20,27 @@ def configure_seed(self) -> None:
         torch.backends.cudnn.benchmark = False

     def configure_device(self) -> None:
-        self.local_rank = self.config.device_id
-        self.device = self.local_rank
-        self.distributed = False
+        if getattr(self.config.training, 'device', 'cuda') == 'xla':
+            import torch_xla.core.xla_model as xm
+            self.device = xm.xla_device()
+            self.distributed = True
+            self.local_rank = xm.get_local_ordinal()
+            self.tpu = True

Review comment: I think using self.xla to denote xla usage is better than using self.tpu.

+        else:
+            self.tpu = False
+            self.local_rank = self.config.device_id
+            self.device = self.local_rank
+            self.distributed = False

         # Will be updated later based on distributed setup
         registry.register("global_device", self.device)

         if self.config.distributed.init_method is not None:
             self.distributed = True
             self.device = torch.device("cuda", self.local_rank)
-        elif torch.cuda.is_available():
+        elif torch.cuda.is_available() and not self.tpu:
             self.device = torch.device("cuda")
-        else:
+        elif not self.tpu:
             self.device = torch.device("cpu")

Review comment: this probably is not your code, but the logic flow here seems wonky to me. self.device is specified first, then overridden?


         registry.register("current_device", self.device)
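To address both review points on this hunk (self.xla instead of self.tpu, and the device being assigned and then overridden), one possible restructuring is sketched below; it resolves the device exactly once per branch. Note the sketch also collapses the two registry registrations, which the original performs at different points, so treat it as an illustration rather than a drop-in replacement:

    def configure_device(self) -> None:
        # Decide the device once, in priority order, instead of overriding it.
        self.xla = getattr(self.config.training, "device", "cuda") == "xla"
        self.local_rank = self.config.device_id
        self.distributed = False

        if self.xla:
            import torch_xla.core.xla_model as xm

            self.device = xm.xla_device()
            self.local_rank = xm.get_local_ordinal()
            self.distributed = True
        elif self.config.distributed.init_method is not None:
            self.device = torch.device("cuda", self.local_rank)
            self.distributed = True
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        registry.register("global_device", self.device)
        registry.register("current_device", self.device)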
4 changes: 4 additions & 0 deletions mmf/trainers/core/evaluation_loop.py
@@ -10,6 +10,7 @@
 from mmf.common.report import Report
 from mmf.common.sample import to_device
 from mmf.utils.distributed import is_master
+from mmf.utils.metsumm import metsumm

Review comment: this probably should go away, it's a debugging artifact. should not live in mmf codebase.



logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ def evaluation_loop(
         self.model.eval()
         disable_tqdm = not use_tqdm or not is_master()
         combined_report = None
+        metsumm("Before Validation Start:")

Review comment: this should go away, it's a debugging artifact. should not live in mmf codebase.


         for batch in tqdm.tqdm(loader, disable=disable_tqdm):
             report = self._forward(batch)

@@ -44,6 +46,8 @@ def evaluation_loop(

         combined_report.metrics = self.metrics(combined_report, combined_report)
         self.update_meter(combined_report, meter, eval_mode=True)
+        logger.info("Validation Done")
+        metsumm("After Validation Complete")

Review comment: this should go away, it's a debugging artifact. should not live in mmf codebase.


         # enable train mode again
         self.model.train()
4 changes: 2 additions & 2 deletions mmf/trainers/core/reporting.py
@@ -51,8 +51,8 @@ def update_dict(self, meter_update_dict, values_dict):
            if val.dim() == 1:
                val = val.mean()

-           if hasattr(val, "item"):
-               val = val.item()
+           #if hasattr(val, "item"):

Review comment: let's utilize the util function item as described in a previous comment instead of commenting out.

+           # val = val.item()

            meter_update_dict.update({key: val})
            total_val += val
16 changes: 10 additions & 6 deletions mmf/trainers/core/training_loop.py
@@ -12,7 +12,7 @@
 from mmf.common.sample import to_device
 from mmf.utils.general import clip_gradients
 from torch import Tensor
-
+from mmf.utils.metsumm import metsumm

Review comment: this should go away, it's a debugging artifact. should not live in mmf codebase.


logger = logging.getLogger(__name__)

@@ -72,8 +72,8 @@ def run_training_epoch(self) -> None:

         combined_report = None
         num_batches_for_this_update = 1
-        for idx, batch in enumerate(self.train_loader):
+        for idx, batch in enumerate(self.train_loader):

Review comment: lol what's the change here? new line? you could get rid of that if so.

            if (idx + 1) % self.training_config.update_frequency == 0:
                combined_report = None
                num_batches_for_this_update = min(

@@ -84,7 +84,6 @@ def run_training_epoch(self) -> None:

            # batch execution starts here
            self.on_batch_start()
-           self.profile("Batch load time")

Review comment: why delete?


            report = self.run_training_batch(batch, num_batches_for_this_update)

@@ -129,7 +128,6 @@ def run_training_epoch(self) -> None:
                # Validation begin callbacks
                self.on_validation_start()

-               logger.info("Evaluation time. Running on full validation set...")

Review comment: why delete?

                # Validation and Early stopping
                # Create a new meter for this case
                report, meter = self.evaluation_loop(self.val_loader)

@@ -146,7 +144,6 @@ def run_training_epoch(self) -> None:
                torch.cuda.empty_cache()

                if stop is True:
-                   logger.info("Early stopping activated")

Review comment: why delete?

                    should_break = True
                if self.num_updates >= self.max_updates:
                    should_break = True

@@ -168,7 +165,7 @@ def run_training_batch(self, batch: Tensor, loss_divisor: int) -> None:
    def _forward(self, batch: Tensor) -> Dict[str, Any]:
        prepared_batch = self.dataset_loader.prepare_batch(batch)
        # Move the sample list to device if it isn't as of now.
-       prepared_batch = to_device(prepared_batch, torch.device("cuda"))
+       prepared_batch = to_device(prepared_batch, self.device)

Review comment: did you test the non-xla codepaths with this change? does it still work?

        self.profile("Batch prepare time")
        # Arguments should be a dict at this point

@@ -188,6 +185,7 @@ def _start_update(self):

    def _backward(self, loss: Tensor) -> None:
        self.scaler.scale(loss).backward()
+

Review comment: remove blank line.

self.profile("Backward time")

def _finish_update(self):
Expand All @@ -199,6 +197,12 @@ def _finish_update(self):
self.config,
scale=self.scaler.get_scale(),
)
+       if getattr(self.config.training, 'device', 'cuda') == 'xla' and self.config.distributed.world_size > 1:

Review comment: can you use the is_xla util function here? Also, no need for checking the world size; reduce_gradients is a no-op if the world size is 1.

+           import torch_xla.core.xla_model as xm
+           #gradients = xm._fetch_gradients(self.optimizer)
+           # Assumes no model parallel
+           #xm.all_reduce('sum', gradients, scale=1.0 / self.config.distributed.world_size)
+           xm.reduce_gradients(self.optimizer)

Review comment: let's clean this block and remove comments etc.


        self.scaler.step(self.optimizer)
        self.scaler.update()
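A sketch of the cleaned-up version of this block that the reviewers ask for: is_xla() instead of the inline config check, no world-size guard (xm.reduce_gradients is a no-op at world size 1), and the commented-out lines dropped. The earlier arguments to clip_gradients are assumed; only the last few appear in the hunk's context lines:

    def _finish_update(self):
        if self.training_config.clip_gradients:
            clip_gradients(
                self.model,
                self.num_updates,
                self.tb_writer,
                self.config,
                scale=self.scaler.get_scale(),
            )
        if is_xla():
            import torch_xla.core.xla_model as xm

            # All-reduce gradients across TPU cores before the optimizer step.
            xm.reduce_gradients(self.optimizer)

        self.scaler.step(self.optimizer)
        self.scaler.update()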
24 changes: 23 additions & 1 deletion mmf/utils/build.py
@@ -10,10 +10,16 @@
 from mmf.common.registry import registry
 from mmf.datasets.processors.processors import Processor
 from mmf.utils.configuration import Configuration
-from mmf.utils.distributed import is_dist_initialized
+from mmf.utils.distributed import is_dist_initialized, is_xla
 from mmf.utils.general import get_optimizer_parameters
 from omegaconf import DictConfig, OmegaConf

+try:
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.parallel_loader as pl
+except ImportError:
+    xm = None
+    pl = None

Review comment: no need to set pl if you're not using it independently of xm.


ProcessorType = Type[Processor]
ProcessorDict = Dict[str, ProcessorType]

@@ -152,6 +158,18 @@ def build_dataloader_and_sampler(
    if not isinstance(dataset_instance, torch.utils.data.IterableDataset):
        other_args = _add_extra_args_for_dataloader(dataset_instance, other_args)

+   if is_xla():
+       dataset_type = dataset_instance.dataset_type
+       shuffle=True
+       other_args["sampler"] = torch.utils.data.DistributedSampler(
+               dataset_instance,
+               num_replicas=xm.xrt_world_size(),
+               rank=xm.get_ordinal(),
+               shuffle=shuffle
+       )

Review comment (on lines +170 to +174): de-indent for pep8 compliance.

+       other_args.pop("shuffle")


Review comment: too many blank lines.

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,

@@ -163,6 +181,10 @@ def build_dataloader_and_sampler(
        **other_args,
    )

+   if is_xla():
+       device = xm.xla_device()
+       loader = pl.MpDeviceLoader(loader, device)

    if num_workers >= 0:
        # Suppress leaking semaphore warning
        os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning"
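A sketch of the same xla branch with the formatting nits addressed (PEP 8 indentation, no stray blank lines, parallel_loader imported only where it is used); the rest of build_dataloader_and_sampler is abbreviated, so the surrounding arguments are assumptions:

    if is_xla():
        # Shard the dataset across TPU cores and let the sampler shuffle,
        # which makes the plain shuffle flag redundant.
        other_args["sampler"] = torch.utils.data.DistributedSampler(
            dataset_instance,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True,
        )
        other_args.pop("shuffle", None)

    loader = torch.utils.data.DataLoader(
        dataset=dataset_instance,
        pin_memory=pin_memory,
        **other_args,
    )

    if is_xla():
        # Wrap the loader so batches are staged onto the TPU device in the
        # background; import lazily so non-TPU installs never need torch_xla.
        import torch_xla.distributed.parallel_loader as pl

        loader = pl.MpDeviceLoader(loader, xm.xla_device())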
36 changes: 29 additions & 7 deletions mmf/utils/checkpoint.py
@@ -10,7 +10,7 @@
 import torch
 from mmf.common.registry import registry
 from mmf.utils.configuration import get_mmf_env, load_yaml
-from mmf.utils.distributed import is_master, synchronize
+from mmf.utils.distributed import is_master, synchronize, is_xla
 from mmf.utils.download import download_pretrained_model
 from mmf.utils.file_io import PathManager
 from mmf.utils.general import updir

@@ -22,6 +22,11 @@
 except ImportError:
     git = None

+try:
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    xm = None

logger = logging.getLogger(__name__)


@@ -379,15 +384,26 @@ def _get_vcs_fields(self):
            "git/diff": self.git_repo.git.diff("--no-prefix"),
        }


Review comment: one blank line too many.


+   def save_func(self):
+       if is_xla():

Review comment: return xm.save if is_xla() else torch.save is cleaner imho.

+           return xm.save
+       else:
+           return torch.save

    def save(self, update, iteration=None, update_best=False):
        # Only save in main process
-       if not is_master():
+       if not is_master() and not is_xla():
            return

        logger.info("Checkpoint save operation started!")

        if not iteration:
            iteration = update

        ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)

Review comment: remove blank line

        best_ckpt_filepath = os.path.join(
            self.ckpt_foldername, self.ckpt_prefix + "best.ckpt"
        )

@@ -437,23 +453,29 @@ def save(self, update, iteration=None, update_best=False):
            git_metadata_dict = self._get_vcs_fields()
            ckpt.update(git_metadata_dict)

logger.info("Saving checkpoint")

Review comment: you just logged this in line 400. remove one of them.

        with PathManager.open(ckpt_filepath, "wb") as f:
-           torch.save(ckpt, f)
+           self.save_func()(ckpt, f)

Review comment: this looks ugly. I'd rewrite it:

    @property
    def save_func(self):
        ...

then you can call it like this:

    self.save_func(ckpt, f)


        if update_best:
            logger.info("Saving best checkpoint")
            with PathManager.open(best_ckpt_filepath, "wb") as f:
-               torch.save(ckpt, f)
+               self.save_func()(ckpt, f)

Review comment: same comment


        # Save current always

        logger.info("Saving Current checkpoint")
        with PathManager.open(current_ckpt_filepath, "wb") as f:
-           torch.save(ckpt, f)
+           self.save_func()(ckpt, f)

Review comment: same


        # Remove old checkpoints if max_to_keep is set
        if self.max_to_keep > 0:
            if len(self.saved_iterations) == self.max_to_keep:
                self.remove(self.saved_iterations.pop(0))
            self.saved_iterations.append(update)

        logger.info("Checkpoint save operation finished!")

Review comment: logger getting too excited :)


    def remove(self, update):
        ckpt_filepath = os.path.join(self.models_foldername, "model_%d.ckpt" % update)
        if PathManager.isfile(ckpt_filepath):

@@ -468,6 +490,6 @@ def restore(self):
            self._load(best_path, force=True)

    def finalize(self):
-       if is_master():
+       if is_master() or is_xla():
            with PathManager.open(self.pth_filepath, "wb") as f:
-               torch.save(self.trainer.model.state_dict(), f)
+               self.save_func()(self.trainer.model.state_dict(), f)

Review comment: same

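A sketch of the property-style save_func the reviewer proposes, so call sites read self.save_func(ckpt, f) instead of self.save_func()(ckpt, f); shown with finalize as one example call site (an illustration of the suggestion, not code from the PR):

    @property
    def save_func(self):
        # xm.save rendezvous-synchronizes and writes from the master ordinal,
        # so it is the safe writer whenever training runs on XLA devices.
        return xm.save if is_xla() else torch.save

    def finalize(self):
        if is_master() or is_xla():
            with PathManager.open(self.pth_filepath, "wb") as f:
                self.save_func(self.trainer.model.state_dict(), f)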
3 changes: 3 additions & 0 deletions mmf/utils/configuration.py
@@ -531,6 +531,9 @@ def _update_specific(self, config):
            lr = config.learning_rate
            config.optimizer.params.lr = lr

+       # TODO: Correct the following issue

Review comment: clean

+       # This check is triggered before the config override from commandline is effective
+       # even after setting training.device = 'xla', it gets triggered.
        if not torch.cuda.is_available() and "cuda" in config.training.device:
            warnings.warn(
                "Device specified is 'cuda' but cuda is not present. "