Skip to content

Commit

Permalink
chonky
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Aug 5, 2024
2 parents da03522 + b21c2a9 commit 596ee59
Show file tree
Hide file tree
Showing 31 changed files with 1,938 additions and 214 deletions.
2 changes: 1 addition & 1 deletion composer/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

"""The Composer Version."""

__version__ = '0.23.5'
__version__ = '0.24.0.dev0'
2 changes: 0 additions & 2 deletions composer/algorithms/selective_backprop/selective_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,6 @@ def apply(self, event: Event, state: State, logger: Optional[Logger] = None) ->
raise RuntimeError('Model must be of type ComposerModel')
self._loss_fn = state.model.loss
return

state.batch = state.device.batch_to_device(state.batch)
input, target = state.batch_get_item(key=self.input_key), state.batch_get_item(key=self.target_key)
assert isinstance(input, torch.Tensor) and isinstance(target, torch.Tensor), \
'Multiple tensors not supported for this method yet.'
Expand Down
2 changes: 0 additions & 2 deletions composer/algorithms/seq_length_warmup/seq_length_warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,6 @@ def _activate_model(self, state: State, logger: Logger) -> None:
while True:
model_inputs = {k: v[:state.device_train_microbatch_size] for k, v in batch_clone.items()}

model_inputs = state.device.batch_to_device(model_inputs)

found_cuda_oom = 0 # int since bool BOR not supported on all torch.distributed backends
try:
# Start by running a forward and backward pass
Expand Down
20 changes: 18 additions & 2 deletions composer/checkpoint/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def download_file(
if node_ranks is not None and node_rank not in node_ranks:
return

if source_uri.endswith('.symlink'):
source_path = download_and_extract_symlink(source_uri)

object_store = maybe_create_object_store_from_uri(source_uri)
_, _, source_path = parse_uri(source_uri)
if source_uri.endswith('.symlink'):
source_path = extract_path_from_symlink(source_path, object_store)
assert object_store is not None

@retry(num_attempts=num_attempts)
Expand All @@ -58,6 +59,21 @@ def _download():
log.debug(f'Finished downloading {source_path} to {destination_path}')


def download_and_extract_symlink(source_uri) -> str:
"""Downloads a symlink file from the specified URI and returns the path it points to.
Args:
source_uri (str): The URI to download the symlink from.
Returns:
str: The path the symlink points to.
"""
object_store = maybe_create_object_store_from_uri(source_uri)
_, _, source_path = parse_uri(source_uri)
source_path = extract_path_from_symlink(source_path, object_store)
return source_path


def download_monolithic_checkpoint(
source_uri: str,
destination_path: str,
Expand Down
Loading

0 comments on commit 596ee59

Please sign in to comment.