Changes to support TPUs #1
base: master
Conversation
- Initial changes to support training on TPUs
- Changed TPU configuration to use training.device
- Replaced parallelLoader with mpLoader to solve the loader exhaustion issue
- Removed debug message
- Updated the comment
- Added comments for the drop_last change
- Removed pdb lines
- Removed redundant device config
- Added comments for pending changes
- Default init not applicable for XLA device type
- Moved wrapping of the dataloader to build
- Added line-debug function metsumm
- Removed some .item() calls from reporting
- XLA equivalents in the distributed module; eval was previously failing at the metrics all-reduce step
- Implemented broadcast in terms of all_to_all
- Changes for checkpoint saving
- Change to make execution even across cores
- Corrected the is_master logic
- One more fix for is_master
- Cleaned up debug messages
@@ -60,8 +60,8 @@ def update(self, update_dict, batch_size):
        if isinstance(v, torch.Tensor):
            if v.dim() != 0:
                v = v.mean()
            v = v.item()
        assert isinstance(v, (float, int))
            #v = v.item()
Instead of commenting out here, let's have a util function like

def item(self, v):
    if torch.is_tensor(v) and v.device.type == 'xla':
        return v
    return v.item()

and use v = self.item(v), and then change the assert to assert isinstance(v, (float, int)) or v.device.type == 'xla'
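A minimal sketch of how that helper could slot into the meter's update, assuming the class shape implied by the diff above; the Meter name and the surrounding loop are illustrative, not taken from the PR:

import torch


class Meter:
    def item(self, v):
        # Leave XLA tensors lazy; calling .item() here would force a device sync.
        # Only called on tensors, so .item() is safe in the fallback branch.
        if torch.is_tensor(v) and v.device.type == "xla":
            return v
        return v.item()

    def update(self, update_dict, batch_size):
        for k, v in update_dict.items():
            if isinstance(v, torch.Tensor):
                if v.dim() != 0:
                    v = v.mean()
                v = self.item(v)
            # Plain scalars stay floats/ints; XLA tensors pass through untouched.
            assert isinstance(v, (float, int)) or v.device.type == "xla"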
        # since other device types such as xla can be passed
        # falling back to cpu should only happen when device_type
        # is set to cude but cuda is not available.
        if not torch.cuda.is_available() and device == "cuda":
Reorder as device == 'cuda' and not torch.cuda.is_available(); the cheap string comparison then short-circuits and saves you the CUDA availability check.
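For illustration, the reordered check inside a small resolver; the resolve_device name is hypothetical, only the condition ordering comes from the review:

import torch


def resolve_device(device: str) -> str:
    # The string comparison runs first, so torch.cuda.is_available() is only
    # queried when "cuda" was actually requested.
    if device == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return device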
@@ -186,9 +186,15 @@ def _infer_dataset_probabilities(self):
    def __len__(self):
        # Since, this is iterator, we need to return total length == number of batches
        batch_size = get_batch_size()
        # Changed the length to accomadate drop_last == True
        # drop_last is required if the batch is split intor multiple cores
s/intor/into/
        # Changed the length to accomadate drop_last == True
        # drop_last is required if the batch is split intor multiple cores
        # some of the cores may not have enough examples.
        if is_xla():
Can you use the bool drop_last here instead of is_xla?
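A rough sketch of what that suggestion would look like, assuming the iterator exposes a drop_last flag; get_batch_size() comes from the diff above, the other names are illustrative:

import math


def num_batches(total_examples: int, batch_size: int, drop_last: bool) -> int:
    # With drop_last, the incomplete final batch is discarded so every core
    # always receives a full batch; otherwise round up as usual.
    if drop_last:
        return total_examples // batch_size
    return math.ceil(total_examples / batch_size)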
            self.device = xm.xla_device()
            self.distributed = True
            self.local_rank = xm.get_local_ordinal()
            self.tpu = True
I think using self.xla to denote xla usage is better than using self.tpu
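For illustration, the same block with the suggested flag name, assuming torch_xla is importable; everything outside the four lines in the diff is a guess:

import torch_xla.core.xla_model as xm


class DeviceContext:
    # Hypothetical holder for the attributes set in the diff above.
    def init_xla(self):
        self.device = xm.xla_device()
        self.distributed = True
        self.local_rank = xm.get_local_ordinal()
        # "xla" names the runtime rather than the hardware, so the flag also
        # covers non-TPU XLA backends.
        self.xla = True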
@@ -46,7 +46,7 @@ def __call__(self, update, iteration, meter):
        Returns:
            bool -- Tells whether early stopping occurred or not
        """
        if not is_master():
        if not is_master() and not is_xla():
I don't understand this. why always False if not the master ordinal?
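One hedged guess at how this could be handled without the extra is_xla() clause: make is_master itself XLA-aware, so the early return stays meaningful on TPU. xm.is_master_ordinal() is a real torch_xla call; the surrounding names are illustrative, not the PR's actual change:

import torch_xla.core.xla_model as xm


def is_master(is_xla_run: bool, node_rank: int) -> bool:
    # On XLA, treat ordinal 0 as the master; otherwise use the usual rank check.
    if is_xla_run:
        return xm.is_master_ordinal()
    return node_rank == 0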
@@ -32,6 +32,7 @@ def main(configuration, init_distributed=False, predict=False):
    if init_distributed:
        distributed_init(config)

blank line, remove.
@@ -96,6 +97,7 @@ def run(opts: typing.Optional[typing.List[str]] = None, predict: bool = False):
    if config.distributed.init_method is None:
        infer_init_method(config)

remove blank line
    )
    if is_xla():
        import torch_xla.distributed.xla_multiprocessing as xmp
        torch.multiprocessing.set_sharing_strategy("file_system")
why this?
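For context on the question, a common TPU launch pattern this block appears to be building toward; the _mp_fn and nprocs values are illustrative, not taken from the PR. The "file_system" sharing strategy is typically chosen to avoid exhausting file descriptors when many tensors are passed between DataLoader worker processes.

import torch
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index, config):
    # index is the per-process ordinal assigned by xmp.spawn.
    pass


if __name__ == "__main__":
    torch.multiprocessing.set_sharing_strategy("file_system")
    xmp.spawn(_mp_fn, args=(None,), nprocs=8, start_method="fork")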
For Internal review.