Skip to content

Commit

Permalink
revised a minor issue in comment and formatted coco.py with the Black…
Browse files Browse the repository at this point in the history
… formatter. (#361)
  • Loading branch information
dixiyao authored Oct 25, 2023
1 parent dae5776 commit 39ea831
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions plato/datasources/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
- annotations_trainval2017: captions for train/val
The data structure under the 'data/' is:
├── COCO2017 # root dir of Flickr30K Entities dataset
├── COCO2017 # root dir of COCO2017 Entities dataset
│ ├── COCO2017Raw # Raw images/annotations and the official splits
│ │ └── annotations
│ │ └── train2017
Expand All @@ -54,7 +54,7 @@


class DataSource(multimodal_base.MultiModalDataSource):
""" The COCO dataset."""
"""The COCO dataset."""

def __init__(self, **kwargs):
super().__init__()
Expand All @@ -64,7 +64,7 @@ def __init__(self, **kwargs):

self.modality_names = ["image", "text"]

_path = Config().params['data_path']
_path = Config().params["data_path"]
self._data_path_process(data_path=_path, base_data_name=self.data_name)

base_data_path = self.mm_data_info["data_path"]
Expand All @@ -81,7 +81,7 @@ def __init__(self, **kwargs):
splits_downalods = {
"train": download_train_url,
"test": download_test_url,
"val": download_val_url
"val": download_val_url,
}

# Download raw data and extract to different splits
Expand All @@ -91,27 +91,29 @@ def __init__(self, **kwargs):
split_file_name = self._download_arrange_data(
download_url_address=split_download_url,
data_path=raw_data_path,
extract_to_dir=split_path)
extract_to_dir=split_path,
)
# renaming the extracted file to "images"
extracted_path = os.path.join(split_path, split_file_name)
renamed_path = os.path.join(split_path, "images")
os.rename(src=extracted_path, dst=renamed_path)

# Download the annotation
self._download_arrange_data(
download_url_address=download_annotation_url,
data_path=raw_data_path)
download_url_address=download_annotation_url, data_path=raw_data_path
)
annotation_path = os.path.join(raw_data_path, "annotations")

# Move the annotation to each split
splits_caption_name = {
"train": "captions_train2017.json",
"val": "captions_val2017.json"
"val": "captions_val2017.json",
}
for split_name in list(splits_caption_name.keys()):
split_caption_name = splits_caption_name[split_name]
to_split_path = os.path.join(self.splits_info[split_name]["path"],
"captions.json")
shutil.copyfile(src=os.path.join(annotation_path,
split_caption_name),
dst=to_split_path)
to_split_path = os.path.join(
self.splits_info[split_name]["path"], "captions.json"
)
shutil.copyfile(
src=os.path.join(annotation_path, split_caption_name), dst=to_split_path
)

0 comments on commit 39ea831

Please sign in to comment.