Skip to content

Commit

Permalink
style: format
Browse files Browse the repository at this point in the history
  • Loading branch information
amitkparekh committed Dec 1, 2023
1 parent 9137073 commit 17c1a14
Show file tree
Hide file tree
Showing 15 changed files with 17 additions and 20 deletions.
6 changes: 4 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ extend-ignore =
# Google Python style is not RST until after processed by Napoleon
# See https://github.com/peterjc/flake8-rst-docstrings/issues/17
RST201,RST203,RST301,
# These violations occur too often in the codebase to be worth fixing
C416, C419,
# This check is new, and updating the whole repo to satisfy it is not worth the effort
S113
extend-select =
# Should raise AssertionError instead of assert False
B011,
Expand All @@ -69,8 +73,6 @@ extend-select =
# Within an except clause, raise exceptions with `raise ... from err` or `raise ...
# from None` to distinguish them from errors in exception handling
B904,
# Alternative to E501 regarding line length
B950,
# Counterpart to W503, enforce having the operator at the start of a new line.
W504,

Expand Down
4 changes: 2 additions & 2 deletions src/emma_policy/api/clients/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def extract_single_image(self, image: Union[Image.Image, ArrayLike]) -> FeatureR
try:
response.raise_for_status()
except requests.exceptions.HTTPError as err:
raise SystemExit(err)
raise SystemExit(err) from err

data = response.json()
feature_response = FeatureResponse(
Expand Down Expand Up @@ -103,7 +103,7 @@ def extract_batch_images(
try:
response.raise_for_status()
except requests.exceptions.HTTPError as err:
raise SystemExit(err)
raise SystemExit(err) from err

data = response.json()

Expand Down
1 change: 0 additions & 1 deletion src/emma_policy/datamodules/coco_captioning_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def __init__(
merged_annotations: bool = True,
is_train: bool = True,
) -> None:

if not merged_annotations:
raise NotImplementedError(
"Expecting dbs where every instance is an image associated with all of its captions."
Expand Down
1 change: 0 additions & 1 deletion src/emma_policy/datamodules/nlvr2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
max_frames: int = 0,
use_task_prefix: bool = False,
) -> None:

super().__init__(
dataset_db_path=dataset_db_path, tokenizer=tokenizer, max_frames=max_frames
)
Expand Down
2 changes: 0 additions & 2 deletions src/emma_policy/datamodules/pretrain_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,6 @@ def _region_mapping(
width: int,
height: int,
) -> tuple[torch.Tensor, torch.Tensor]:

gt_bbox = []
for region in regions:
gt_bbox_coord = BoxMode.convert(
Expand Down Expand Up @@ -894,7 +893,6 @@ def _convert_trajectory_to_text(
trajectory_text.extend(split_action_name(action.api_action.action))
# Match the object to a predicted bounding box
if "bbox" in action.discrete_action.args:

bbox_coord = action.discrete_action.args["bbox"] # noqa: WPS529
gt_bbox = torch.tensor(
[
Expand Down
1 change: 0 additions & 1 deletion src/emma_policy/datamodules/refcoco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
shuffle_objects: bool = False,
train_with_golden_bbox_prob: float = 1.0,
) -> None:

super().__init__(
dataset_db_path=dataset_db_path,
tokenizer=tokenizer,
Expand Down
1 change: 0 additions & 1 deletion src/emma_policy/datamodules/vqa_v2_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
tokenizer: PreTrainedTokenizer,
max_frames: int = 0,
) -> None:

super().__init__(
dataset_db_path=dataset_db_path,
tokenizer=tokenizer,
Expand Down
1 change: 0 additions & 1 deletion src/emma_policy/models/decoder_emma.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def decoder_layer_outputs(
) -> tuple[torch.FloatTensor, ...]:
"""Get output from a single decoder layer."""
if self.gradient_checkpointing and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
Expand Down
3 changes: 2 additions & 1 deletion src/emma_policy/models/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ def tiny_value_of_dtype(dtype: torch.dtype) -> float:
This is used to avoid numerical issues such as division by zero. This is different from
`info_value_of_dtype(dtype).tiny` because using that value causes some NaN bugs. Only supports floating point
dtypes. Implementation from AllenNLP: https://github.com/allenai/allennlp/blob/39c40fe38cd2fd36b3465b0b3c031f54ec824160/allennlp/nn/util.py#L2010-L2024
dtypes. Implementation from AllenNLP:
https://github.com/allenai/allennlp/blob/39c40fe38cd2fd36b3465b0b3c031f54ec824160/allennlp/nn/util.py#L2010-L2024
"""
if not dtype.is_floating_point:
raise TypeError("Only supports floating point dtypes.")
Expand Down
4 changes: 3 additions & 1 deletion src/emma_policy/models/model_output_emma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

@dataclass
class EmmaSeq2SeqModelOutput(ModelOutput):
"""Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential decoding.
"""Base class for encoder outputs.
Also contains pre-computed hidden states that can speed up sequential decoding.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Expand Down
1 change: 0 additions & 1 deletion src/emma_policy/models/nlvr2_emma_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(
nlvr2_metrics: Optional[list[str]] = None,
**kwargs: Any,
) -> None:

self._tokenizer = AutoTokenizer.from_pretrained(model_name)
self._pred_gt: dict[str, list[str]] = {
"predictions": [],
Expand Down
4 changes: 2 additions & 2 deletions src/emma_policy/utils/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def _convert(self, box: torch.Tensor, from_mode: BoxMode, to_mode: BoxMode) -> t

try:
converted_box = convert_functions[from_mode][to_mode](box)
except KeyError:
except KeyError as err:
raise NotImplementedError(
f"Conversion from BoxMode {from_mode} to {to_mode} is not supported."
)
) from err

return converted_box

Expand Down
6 changes: 4 additions & 2 deletions src/emma_policy/utils/masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def decompress_simbot_mask(
) -> Union[torch.Tensor, typing.NDArray[np.float64]]:
"""Decompress a compressed mask array.
Adopted from https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/AlexaSimbotMLToolbox/browse/refs/heads/main/--/AlexaSimbotToolbox/arena_wrapper/util/__init__.py?region=us-east-1
Adopted from
https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/AlexaSimbotMLToolbox/browse/refs/heads/main/--/AlexaSimbotToolbox/arena_wrapper/util/__init__.py?region=us-east-1
"""
mask = np.zeros((image_width, image_height))
for start_idx, run_len in compressed_mask:
Expand All @@ -29,7 +30,8 @@ def compress_simbot_mask(
) -> list[list[int]]:
"""Compress a binary 2D array mask for the simbot arena.
Adopted from https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/AlexaSimbotMLToolbox/browse/refs/heads/main/--/AlexaSimbotToolbox/arena_wrapper/util/__init__.py?region=us-east-1
Adopted from
https://us-east-1.console.aws.amazon.com/codesuite/codecommit/repositories/AlexaSimbotMLToolbox/browse/refs/heads/main/--/AlexaSimbotToolbox/arena_wrapper/util/__init__.py?region=us-east-1
"""
# list of lists of run lengths for 1s, which are assumed to be less frequent.
run_len_compressed: list[list[int]] = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def __init__(
rank: Optional[int] = None,
replacement: bool = True,
) -> None:

if num_replicas is None:
if not torch.distributed.is_available():
raise RuntimeError("Requires distributed package to be available")
Expand Down
1 change: 0 additions & 1 deletion tests/datamodules/test_datamodule_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_simbot_target_tokens(
action_text: str,
emma_tokenizer: EmmaTokenizer,
) -> None:

target_encoding = emma_tokenizer.encode_plus(target_text, return_tensors="pt", truncation=True)
full_target_token_ids = target_encoding.input_ids.squeeze(0)
target_token_ids = mask_past_target_actions(
Expand Down

0 comments on commit 17c1a14

Please sign in to comment.