Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
- fix spelling of "download"
- add push_to_hub option
- fix Optional type hinting
- apply single loop for DepthProImageProcessor.preprocess
  • Loading branch information
geetu040 committed Dec 21, 2024
1 parent 8f4c61f commit 8960535
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,5 @@ def __init__(
self.scaled_images_overlap_ratios = scaled_images_overlap_ratios
self.scaled_images_feature_dims = scaled_images_feature_dims


__all__ = ["DepthProConfig"]
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def write_model(
# Convert weights
# ------------------------------------------------------------

# downlaod and load state_dict from hf repo
# download and load state_dict from hf repo
file_path = hf_hub_download(hf_repo_id, "depth_pro.pt")
# file_path = "/home/geetu/work/hf/depth_pro/depth_pro.pt" # when you already have the files locally
loaded = torch.load(file_path, weights_only=True)
Expand Down Expand Up @@ -214,8 +214,9 @@ def write_model(
# Safety check: reload the converted model
gc.collect()
print("Reloading the model to check if it's saved correctly.")
DepthProForDepthEstimation.from_pretrained(output_dir, torch_dtype=torch.bfloat16, device_map="auto")
model = DepthProForDepthEstimation.from_pretrained(output_dir, device_map="auto")
print("Model reloaded successfully.")
return model


def write_image_processor(output_dir: str):
Expand All @@ -231,6 +232,7 @@ def write_image_processor(output_dir: str):
image_std=0.5,
)
image_processor.save_pretrained(output_dir)
return image_processor


def main():
Expand All @@ -243,23 +245,38 @@ def main():
parser.add_argument(
"--output_dir",
default="apple_DepthPro",
help="Location to write HF model and processor",
help="Location to write the converted model and processor",
)
parser.add_argument(
"--safe_serialization", default=True, type=bool, help="Whether or not to save using `safetensors`."
)
parser.add_argument(
"--push_to_hub",
default=True,
type=bool,
help="Whether or not to push the converted model to the huggingface hub.",
)
parser.add_argument(
"--hub_repo_id",
default="geetu040/DepthPro",
help="Huggingface hub repo to write the converted model and processor",
)
args = parser.parse_args()

write_model(
model = write_model(
hf_repo_id=args.hf_repo_id,
output_dir=args.output_dir,
safe_serialization=args.safe_serialization,
)

write_image_processor(
image_processor = write_image_processor(
output_dir=args.output_dir,
)

if args.push_to_hub:
model.push_to_hub(args.hub_repo_id)
image_processor.push_to_hub(args.hub_repo_id)


if __name__ == "__main__":
main()
61 changes: 28 additions & 33 deletions src/transformers/models/depth_pro/image_processing_depth_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ def resize(
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])

# We use torch interpolation instead of image.resize because DepthProImageProcessor
# rescales and then normalizes the image before resizing it, which may cause some values to become negative.
# image.resize expects all values to be in the range [0, 1] or [0, 255] and throws an exception otherwise,
# whereas pytorch interpolation works with negative values.
# Relevant issue: https://github.com/huggingface/transformers/issues/34920
return (
torch.nn.functional.interpolate(
# input should be (B, C, H, W)
Expand All @@ -182,9 +187,6 @@ def _validate_input_arguments(
image_std: Union[float, List[float]],
data_format: Union[str, ChannelDimension],
):
if data_format != ChannelDimension.FIRST:
raise ValueError("Only channel first data format is currently supported.")

if do_resize and None in (size, resample, antialias):
raise ValueError("Size, resample and antialias must be specified if do_resize is True.")

Expand All @@ -199,8 +201,8 @@ def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
antialias: Optional[bool] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
Expand Down Expand Up @@ -302,36 +304,28 @@ def preprocess(
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])

if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]

if do_normalize:
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]

images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]

# depth-pro scales the image before resizing it
# uses torch interpolation which requires ChannelDimension.FIRST
if do_resize:
images = [
self.resize(
image=image,
size=size,
resample=resample,
antialias=antialias,
all_images = []
for image in images:
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

if do_normalize:
image = self.normalize(
image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
)
for image in images
]

data = {"pixel_values": images}
# depth-pro rescales and normalizes the image before resizing it;
# the resize uses torch interpolation, which requires ChannelDimension.FIRST
if do_resize:
image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
image = self.resize(image=image, size=size, resample=resample, antialias=antialias)
image = to_channel_dimension_format(image, data_format, input_channel_dim=ChannelDimension.FIRST)
else:
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)

all_images.append(image)

data = {"pixel_values": all_images}
return BatchFeature(data=data, tensor_type=return_tensors)

def post_process_depth_estimation(
Expand Down Expand Up @@ -408,4 +402,5 @@ def post_process_depth_estimation(

return outputs


__all__ = ["DepthProImageProcessor"]
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
size: Optional[Dict[str, int]] = None,
resample: Optional[PILImageResampling] = None,
antialias: Optional[bool] = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
Expand Down Expand Up @@ -381,4 +381,5 @@ def post_process_depth_estimation(

return outputs


__all__ = ["DepthProImageProcessorFast"]
1 change: 1 addition & 0 deletions src/transformers/models/depth_pro/modeling_depth_pro.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,4 +1675,5 @@ def forward(
attentions=depth_pro_outputs.attentions,
)


__all__ = ["DepthProPreTrainedModel", "DepthProModel", "DepthProForDepthEstimation"]

0 comments on commit 8960535

Please sign in to comment.