Changes to support TPUs #1

Open · wants to merge 2 commits into master
Conversation

ultrons (Owner) commented Nov 3, 2020

For internal review.

Files with review threads:
mmf/trainers/core/training_loop.py
mmf/trainers/mmf_trainer.py (outdated)
mmf/utils/build.py (outdated)
mmf/utils/build.py (outdated)
mmf_cli/run.py
initial changes to support training on tpus

changed tpu configuration to use training.device

replaced parallelLoader with mpLoader to solve the loader-exhaustion issue

removed debug message. updated the comment

added comments for drop_last change.

removed pdb lines

removed redundant device config

added comments for pending changes

default init not applicable for xla device type

moved wrapping of dataloader to build

added line-debug function metsumm

removed some .item calls from reporting

added xla equivalents in the distributed module; earlier, eval was failing at the metrics all-reduce step

implemented broadcast in terms of all_to_all (see the sketch after this commit list)

changes for checkpoint saving

change to make execution even across cores

corrected the is_master logic

one more fix for is_master

clean up of debug messages
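
The "implemented broadcast in terms of all_to_all" commit is not shown in the hunks below. As a rough illustration of building broadcast out of the collectives torch_xla does expose, here is a minimal sketch that emulates broadcast with a sum all_reduce (zeroing the value on non-source ordinals); the helper name xla_broadcast and the src argument are hypothetical, and the PR's actual all_to_all-based version may differ.

```python
import torch
import torch_xla.core.xla_model as xm


def xla_broadcast(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
    """Hypothetical sketch: emulate broadcast with a sum all_reduce.

    Every ordinal except `src` zeroes its copy, so the summed result equals
    the source tensor on all cores.
    """
    if xm.get_ordinal() != src:
        tensor = torch.zeros_like(tensor)
    return xm.all_reduce(xm.REDUCE_SUM, tensor)
```
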
```diff
@@ -60,8 +60,8 @@ def update(self, update_dict, batch_size):
             if isinstance(v, torch.Tensor):
                 if v.dim() != 0:
                     v = v.mean()
-                v = v.item()
+                #v = v.item()
             assert isinstance(v, (float, int))
```


Instead of commenting out here, let's have a util function like

```python
def item(self, v):
    if torch.is_tensor(v) and v.device.type == 'xla':
        return v
    return v.item()
```

and use `v = self.item(v)`, and then assert on `isinstance(v, (float, int)) or v.device.type == 'xla'`.
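
For concreteness, a minimal sketch of how `update` could look with that helper applied (the helper's placement on the meter class and the exact assert wording are one reading of the suggestion, not code from this PR):

```python
import torch


class Meter:
    # Sketch only; the rest of the class is omitted.
    def item(self, v):
        # Keep XLA tensors lazy: calling .item() would force a device-to-host sync.
        if torch.is_tensor(v) and v.device.type == "xla":
            return v
        return v.item()

    def update(self, update_dict, batch_size):
        for k, v in update_dict.items():
            if isinstance(v, torch.Tensor):
                if v.dim() != 0:
                    v = v.mean()
                v = self.item(v)
            assert isinstance(v, (float, int)) or (
                torch.is_tensor(v) and v.device.type == "xla"
            )
            # ... existing meter bookkeeping continues here.
```
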

```python
    # Since other device types such as xla can be passed,
    # falling back to cpu should only happen when device_type
    # is set to cuda but cuda is not available.
    if not torch.cuda.is_available() and device == "cuda":
```


Reordering as `device == "cuda" and not torch.cuda.is_available()` will save you the cuda-availability check whenever the device isn't cuda.
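
A minimal sketch of the reordered check (the surrounding warning/fallback body is illustrative, not the exact code in build.py):

```python
import warnings

import torch

device = config.training.device  # from the config, per the training.device commit above
# The cheap string compare short-circuits first, so CUDA availability is only
# queried when the user actually asked for cuda.
if device == "cuda" and not torch.cuda.is_available():
    warnings.warn("training.device is 'cuda' but CUDA is not available; falling back to cpu")
    device = "cpu"
```
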

```diff
@@ -186,9 +186,15 @@ def _infer_dataset_probabilities(self):
     def __len__(self):
         # Since, this is iterator, we need to return total length == number of batches
         batch_size = get_batch_size()
+        # Changed the length to accomadate drop_last == True
+        # drop_last is required if the batch is split intor multiple cores
```


s/intor/into/

```diff
+        # Changed the length to accomadate drop_last == True
+        # drop_last is required if the batch is split intor multiple cores
+        # some of the cores may not have enough examples.
+        if is_xla():
```


Can you use the bool drop_last here instead of is_xla()?
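
A sketch of what `__len__` might look like with an explicit drop_last flag (the `self.drop_last` and `self._total_length` attribute names are hypothetical):

```python
import math


def __len__(self):
    # Returns the number of batches, not the number of samples.
    batch_size = get_batch_size()  # same helper as in the hunk above
    total = self._total_length     # hypothetical: total number of samples
    if self.drop_last:
        # Trailing incomplete batches are dropped so that every core sees the
        # same number of full batches.
        return total // batch_size
    return math.ceil(total / batch_size)
```
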

```python
        self.device = xm.xla_device()
        self.distributed = True
        self.local_rank = xm.get_local_ordinal()
        self.tpu = True
```


I think using self.xla to denote xla usage is better than using self.tpu
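
For reference, a sketch of the same setup with that naming suggestion applied (the class and method names here are placeholders, not this PR's code):

```python
import torch_xla.core.xla_model as xm


class DeviceSetupSketch:
    def configure_xla_device(self):
        self.device = xm.xla_device()             # this process's XLA device (one TPU core)
        self.distributed = True
        self.local_rank = xm.get_local_ordinal()  # ordinal of the core on this host
        self.xla = True                           # "xla" rather than "tpu": XLA is not TPU-specific
```
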

```diff
@@ -46,7 +46,7 @@ def __call__(self, update, iteration, meter):
         Returns:
             bool -- Tells whether early stopping occurred or not
         """
-        if not is_master():
+        if not is_master() and not is_xla():
```


I don't understand this. Why always False if not the master ordinal?
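
One possible motivation, not stated in the thread: on XLA every core has to execute the same graph and the same collectives, so the early-stop decision cannot be taken on the master ordinal alone. A hypothetical way to keep the cores consistent is to compute the decision on master and mirror it with a collective, e.g.:

```python
import torch
import torch_xla.core.xla_model as xm


def sync_early_stop_decision(should_stop: bool) -> bool:
    # Hypothetical sketch: mirror the master's decision to every core with a
    # max all_reduce so that all ordinals take the same branch afterwards.
    flag = torch.tensor(
        float(should_stop) if xm.is_master_ordinal() else 0.0,
        device=xm.xla_device(),
    )
    flag = xm.all_reduce(xm.REDUCE_MAX, flag)
    return bool(flag.item())
```
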

```diff
@@ -32,6 +32,7 @@ def main(configuration, init_distributed=False, predict=False):
     if init_distributed:
         distributed_init(config)
+
```


blank line, remove.

```diff
@@ -96,6 +97,7 @@ def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False):
     if config.distributed.init_method is None:
         infer_init_method(config)
+
```


remove blank line

```diff
     )
+    if is_xla():
+        import torch_xla.distributed.xla_multiprocessing as xmp
+        torch.multiprocessing.set_sharing_strategy("file_system")
```


why this?
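
For context on what the import is presumably setting up (a sketch of a typical torch_xla launch path, not the exact code in this PR): xmp.spawn forks one process per TPU core, and set_sharing_strategy("file_system") switches torch.multiprocessing to file-backed tensor sharing, which avoids exhausting file descriptors when many dataloader/worker processes exchange tensors. The entry-point name and nprocs value below are assumptions.

```python
import torch
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_entry(index, config):
    # Hypothetical per-core entry point; `index` is the ordinal assigned by xmp.
    main(config, init_distributed=True)  # `main` as defined in mmf_cli/run.py


if is_xla():  # `is_xla` as used in the hunk above
    # file_system sharing avoids "too many open files" when worker processes
    # pass tensors through torch.multiprocessing.
    torch.multiprocessing.set_sharing_strategy("file_system")
    xmp.spawn(_mp_entry, args=(config,), nprocs=8, start_method="fork")  # e.g. 8 cores on a v3-8
```
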
