diff --git a/opencv_transforms/functional.py b/opencv_transforms/functional.py index 7dcab3a..bdfc3f6 100644 --- a/opencv_transforms/functional.py +++ b/opencv_transforms/functional.py @@ -109,7 +109,7 @@ def resize(img, size, interpolation=cv2.INTER_LINEAR): output = cv2.resize(img, dsize=(ow, oh), interpolation=interpolation) else: output = cv2.resize(img, dsize=size[::-1], interpolation=interpolation) - if img.shape[2]==1: + if img.ndim==2: return(output[:,:,np.newaxis]) else: return(output) @@ -555,4 +555,4 @@ def to_grayscale(img, num_output_channels=1): elif num_output_channels==3: # much faster than doing cvtColor to go back to gray img = np.broadcast_to(cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)[:,:,np.newaxis], img.shape) - return img \ No newline at end of file + return img