From 6daa89e6fab8e706ea3b31235f9c70f9d2b9026d Mon Sep 17 00:00:00 2001 From: sronilsson Date: Mon, 9 Sep 2024 11:10:03 -0400 Subject: [PATCH] roi pre-set sizes --- simba/assets/icons/half_circle.png | Bin 0 -> 835 bytes simba/assets/icons/hexagon.png | Bin 0 -> 891 bytes simba/assets/icons/roi_green.png | Bin 0 -> 1217 bytes simba/assets/icons/size_black.png | Bin 0 -> 871 bytes simba/assets/icons/square_black.png | Bin 0 -> 759 bytes .../cuda/circular_statistics.py | 677 ++++++++++++++++++ simba/data_processors/cuda/geometry.py | 4 +- simba/data_processors/cuda/image.py | 1 - simba/data_processors/cuda/statistics.py | 14 +- simba/mixins/config_reader.py | 17 +- simba/mixins/feature_extraction_mixin.py | 18 +- simba/roi_tools/ROI_define.py | 25 +- simba/roi_tools/ROI_image.py | 21 +- simba/ui/pop_ups/roi_fixed_size_pop_up.py | 323 +++++++++ simba/utils/checks.py | 29 +- tests/test_roi_tools.py | 2 +- 16 files changed, 1069 insertions(+), 62 deletions(-) create mode 100644 simba/assets/icons/half_circle.png create mode 100644 simba/assets/icons/hexagon.png create mode 100644 simba/assets/icons/roi_green.png create mode 100644 simba/assets/icons/size_black.png create mode 100644 simba/assets/icons/square_black.png create mode 100644 simba/data_processors/cuda/circular_statistics.py create mode 100644 simba/ui/pop_ups/roi_fixed_size_pop_up.py diff --git a/simba/assets/icons/half_circle.png b/simba/assets/icons/half_circle.png new file mode 100644 index 0000000000000000000000000000000000000000..36059811b962b7d7108c7c4318fd849b7fd129b8 GIT binary patch literal 835 zcmeAS@N?(olHy`uVBq!ia0vp^{2aVqgWc85q16rQz%#Mh&PMCI*J~Oa>OHnkXO*0v1AfjnEKjFOT9D}DX)@^Za$W4-*MbbUihOG|wN zBYh(yU7!lx;>x^|#0uTKVr7USFmqf|i<65o3raHc^AtelCMM;Vme?vOfh>Xph&xL% z(-1c06+^uR^q@XSM&D4+Kp$>4P^%3{)XKjoGZknv$b36P8?Z_gF{nK@`XI}Z90Tzw zSQO}0J1!f2c(B=VNya^72ZrQlPZ!4!4bioe?D-rW1=|1ix3ma*#5gJ!GEDV*%H?;! zwnciy7bmBH9c?bFmj*ug{q}ZczOrMJ=f1f!@1HxfFE(hUyw2NdYA;US^5Tij zm4z(jKT=-pJfL{|QiG<9um7D^qnL#Z=`#+pq^!F)EBRoqtGDcO#w;6`NzsR!nJl@!c0CMOdFt`? nsS4G4U00q=)xAF7`OkT|Im?v{SVd1XfKrmDtDnm{r-UW|E7Sj| literal 0 HcmV?d00001 diff --git a/simba/assets/icons/hexagon.png b/simba/assets/icons/hexagon.png new file mode 100644 index 0000000000000000000000000000000000000000..ee0452911bf49502f41e6f9395329d74dfd0f687 GIT binary patch literal 891 zcmeAS@N?(olHy`uVBq!ia0vp^{2aVqgWc85q16rQz%#Mh&PMCI*J~Oa>OHnkXO*0v1AfjnEKjFOT9D}DX)@^Za$W4-*MbbUihOG|wN zBYh(yU7!lx;>x^|#0uTKVr7USFmqf|i<65o3raHc^AtelCMM;Vme?vOfh>Xph&xL% z(-1c06+^uR^q@XSM&D4+Kp$>4P^%3{)XKjoGZknv$b36P8?Z_gF{nK@`XI}Z90Tzw zSQO}0J1!f2c(B=VNya^72L`gXr;B5VhUnbMwqDMTB5vyT8A6WS8#j8SPc{k_V2W%u zYqH+3W&0_)JBX`)jxCQLLlnu^{ zM~lDWdm*7CA+cuSzSqxI#)vmpA3Oa+~8-L{SSgzR?7*HdagpOKHQVLdOf`}m$+scTlenVr6`G*G?g zzg6+7{@63m@9nt#bKjo*?mu47i`9R3^KbMY#(6I}`5sCw3;5R>D{y`xx7qr<#GR*w vABEp&w3AtSGoqg9={n(mC)VfxS^tYEXq~Ci;quZdPzv;P^>bP0l+XkKn!+4x literal 0 HcmV?d00001 diff --git a/simba/assets/icons/roi_green.png b/simba/assets/icons/roi_green.png new file mode 100644 index 0000000000000000000000000000000000000000..2a0c806e062b61450a53a2a2d716cf1e3fc4726c GIT binary patch literal 1217 zcmeAS@N?(olHy`uVBq!ia0vp^{2aVqgWc85q16rQz%#Mh&PMCI*J~Oa>OHnkXO*0v1AfjnEKjFOT9D}DX)@^Za$W4-*MbbUihOG|wN zBYh(yU7!lx;>x^|#0uTKVr7USFmqf|i<65o3raHc^AtelCMM;Vme?vOfh>Xph&xL% z(-1c06+^uR^q@XSM&D4+Kp$>4P^%3{)XKjoGZknv$b36P8?Z_gF{nK@`XI}Z90Tzw zSQO}0J1!f2c(B=VNya^7XJB9|^mK6y(GblIw#^7}6!@1`_FjGZ?(a=*Op08ZOjY*@|P034$Q{~a_$I^|Z=-=Y}%3k7rJP36pr4%KCL# zYTfz#)#W=JCk7gG{n24jxYPA}>oJBKOV3@*H2mUsF*)_FfT7u@r(cxi7O^_E%zmx8 zZo!f2#hmTGcs3uKp={6KZaDG(kptCF?VcgBga*J-|4Vs^Dl;emlxc=H~{|p5G z9`X6x{iR{z0+z`-6FdKTOckHAK92R3QS1u!O*&jnJvY~%(0wJggMDrI+cnhz4m!nR zn*zVY%n5G{_1wpuDa@XsUS@eJ^wt%(D=8&*IosZsO|V_Iy1nF>_~xWC|DuzT8#&JY zbmn9`F~{lhvQ;jZ&5DaJZKx^xDXSl~qG!XxO@}1?_vz#oJ({q8!JU~0Legd$hH)+F a{?D+^@`us7S3z?@`PkFd&t;ucLK6VHe7gYv literal 0 HcmV?d00001 diff --git a/simba/assets/icons/size_black.png b/simba/assets/icons/size_black.png new file mode 100644 index 0000000000000000000000000000000000000000..d69d2e657567d7ebc522178fdf4201495f0fb023 GIT binary patch literal 871 zcmeAS@N?(olHy`uVBq!ia0vp^{2aVqgWc85q16rQz%#Mh&PMCI*J~Oa>OHnkXO*0v1AfjnEKjFOT9D}DX)@^Za$W4-*MbbUihOG|wN zBYh(yU7!lx;>x^|#0uTKVr7USFmqf|i<65o3raHc^AtelCMM;Vme?vOfh>Xph&xL% z(-1c06+^uR^q@XSM&D4+Kp$>4P^%3{)XKjoGZknv$b36P8?Z_gF{nK@`XI}Z90Tzw zSQO}0J1!f2c(B=VNya^72L`gZr;B5VhUn5smVPXT0Q&-14qM9{(A{&$tfu#in{U;pGKP{skY!C;XSTD@gCv{#IeSAd2LHF% z#=myc6^pdJS4-C!uvmFMD48d`c=^A=dmn7;?^^cnD0^PN_uH5GpI2#WTm7Homi76K zi10i`-VI?=-`3=+JZ|HQ{>D}wcFExGtMs0`%y$xY9^BAWyQ7Ky#7j-{sNMgvZ!}*u Y?|G|vt<-(89VoSVy85}Sb4q9e0JVA&6#xJL literal 0 HcmV?d00001 diff --git a/simba/assets/icons/square_black.png b/simba/assets/icons/square_black.png new file mode 100644 index 0000000000000000000000000000000000000000..ee05703cbbd44a6fc0a8bdb08118ec8216acfaf3 GIT binary patch literal 759 zcmeAS@N?(olHy`uVBq!ia0vp^{2aVqgWc85q16rQz%#Mh&PMCI*J~Oa>OHnkXO*0v1AfjnEKjFOT9D}DX)@^Za$W4-*MbbUihOG|wN zBYh(yU7!lx;>x^|#0uTKVr7USFmqf|i<65o3raHc^AtelCMM;Vme?vOfh>Xph&xL% z(-1c06+^uR^q@XSM&D4+Kp$>4P^%3{)XKjoGZknv$b36P8?Z_gF{nK@`XI}Z90Tzw zSQO}0J1!f2c(B=VNya^72ZrQwPZ!4!4bi)k4{|jd2)JYm>bKUdTwd3qzCYx}(vGru zDUU97WUh~u*tE|0W{%A4({gX;g>9B}o1D>nZTAO}bt}uZ=xE(*@vOMOtMI0EdVp~G zqaLT%i!TW(S~qWQV7D({INQ2P33Xn#d@U-1u-TZ^S`$A4uzqA9 x.shape[0]: + return + else: + a = math.atan2(x[i][0] - y[i][0], y[i][1] - x[i][1]) * (180 / math.pi) + a = int32(a + 360 if a < 0 else a) + results[i] = a + + +def direction_from_two_bps(x: np.ndarray, y: np.ndarray) -> np.ndarray: + """ + Compute the directionality in degrees from two body-parts. E.g., ``nape`` and ``nose``, + or ``swim_bladder`` and ``tail`` with GPU acceleration. + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/direction_two_bps.csv + :widths: 10, 90 + :align: center + :header-rows: 1 + + .. seealso:: + For CPU function see :func:`~simba.mixins.circular_statistics.CircularStatisticsMixin.direction_two_bps`. + + :parameter np.ndarray x: Size len(frames) x 2 representing x and y coordinates for first body-part. + :parameter np.ndarray y: Size len(frames) x 2 representing x and y coordinates for second body-part. + :return: Frame-wise directionality in degrees. + :rtype: np.ndarray. + + """ + x = np.ascontiguousarray(x).astype(np.int32) + y = np.ascontiguousarray(y).astype(np.int32) + x_dev = cuda.to_device(x) + y_dev = cuda.to_device(y) + results = cuda.device_array((x.shape[0]), dtype=np.int32) + bpg = (x.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK + _cuda_direction_from_two_bps[bpg, THREADS_PER_BLOCK](x_dev, y_dev, results) + results = results.copy_to_host() + return results + + +def sliding_circular_hotspots(x: np.ndarray, + time_window: float, + sample_rate: float, + bins: np.ndarray, + batch_size: Optional[int] = int(3.5e+7)) -> np.ndarray: + """ + Calculate the proportion of data points falling within specified circular bins over a sliding time window using GPU + + This function processes time series data representing angles (in degrees) and calculates the proportion of data + points within specified angular bins over a sliding window. The calculations are performed in batches to + accommodate large datasets efficiently. + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_circular_hotspots.csv + :widths: 10, 45, 45 + :align: center + :header-rows: 1 + + .. seealso:: + For CPU function see :func:`~simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_circular_hotspots`. + + + :param np.ndarray x: The input time series data in degrees. Should be a 1D numpy array. + :param float time_window: The size of the sliding window in seconds. + :param float sample_rate: The sample rate of the time series data (i.e., hz, fps). + :param ndarray bins: 2D array of shape representing circular bins defining [start_degree, end_degree] inclusive. + :param Optional[int] batch_size: The size of each batch for processing the data. Default is 5e+7 (50m). + :return: A 2D numpy array where each row corresponds to a time point in `data`, and each column represents a circular bin. The values in the array represent the proportion of data points within each bin at each time point. The first column represents the first bin. + :rtype: np.ndarray + """ + + n = x.shape[0] + x = cp.asarray(x, dtype=cp.float16) + results = cp.full((x.shape[0], bins.shape[0]), dtype=cp.float16, fill_value=-1) + window_size = int(cp.ceil(time_window * sample_rate)) + for cnt, left in enumerate(range(0, n, batch_size)): + right = int(min(left + batch_size, n)) + if cnt > 0: + left = left - window_size + 1 + x_batch = x[left:right] + x_batch = cp.lib.stride_tricks.sliding_window_view(x_batch, window_size).astype(cp.float16) + batch_results = cp.full((x_batch.shape[0], bins.shape[0]), dtype=cp.float16, fill_value=-1) + for bin_cnt in range(bins.shape[0]): + if bins[bin_cnt][0] > bins[bin_cnt][1]: + mask = ((x_batch >= bins[bin_cnt][0]) & (x_batch <= 360)) | ((x_batch >= 0) & (x_batch <= bins[bin_cnt][1])) + else: + mask = (x_batch >= bins[bin_cnt][0]) & (x_batch <= bins[bin_cnt][1]) + count_per_row = cp.array(mask.sum(axis=1) / window_size).reshape(-1, ) + batch_results[:, bin_cnt] = count_per_row + results[left + window_size - 1:right, ] = batch_results + return results.get() + +def sliding_circular_mean(x: np.ndarray, + time_window: float, + sample_rate: int, + batch_size: Optional[int] = 3e+7) -> np.ndarray: + + """ + Calculate the sliding circular mean over a time window for a series of angles. + + This function computes the circular mean of angles in the input array `x` over a specified sliding window. + The circular mean is a measure of the average direction for angles, which is especially useful for angular data + where traditional averaging would not be meaningful due to the circular nature of angles (e.g., 359° and 1° should average to 0°). + + The calculation is performed using a sliding window approach, where the circular mean is computed for each window + of angles. The function leverages GPU acceleration via CuPy for efficiency when processing large datasets. + + The circular mean :math:`\\mu` for a set of angles is calculated using the following formula: + + .. math:: + + \\mu = \\text{atan2}\\left(\\frac{1}{N} \\sum_{i=1}^{N} \\sin(\\theta_i), \\frac{1}{N} \\sum_{i=1}^{N} \\cos(\\theta_i)\\right) + + - :math:`\\theta_i` are the angles in radians within the sliding window + - :math:`N` is the number of samples in the window + + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_circular_mean.csv + :widths: 10, 45, 45 + :align: center + :header-rows: 1 + + .. seealso:: + For CPU function see :func:`~simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_circular_mean`. + + :param np.ndarray x: Input array containing angle values in degrees. The array should be 1-dimensional. + :param float time_window: Time duration for the sliding window, in seconds. This determines the number of samples in each window based on the `sample_rate`. + :param int sample_rate: The number of samples per second (i.e., FPS). This is used to calculate the window size in terms of array indices. + :param Optional[int] batch_size: The maximum number of elements to process in each batch. This is used to handle large arrays by processing them in chunks to avoid memory overflow. Defaults to 3e+7 (30 million elements). + :return np.ndarray: A 1D numpy array of the same length as `x`, containing the circular mean for each sliding window. Values before the window is fully populated will be set to -1. + + :example: + >>> x = np.random.randint(0, 361, (i, )).astype(np.int32) + >>> results = sliding_circular_mean(x, 1, 10) + """ + + + window_size = np.ceil(time_window * sample_rate).astype(np.int64) + n = x.shape[0] + results = cp.full(x.shape[0], -1, dtype=np.int32) + for cnt, left in enumerate(range(0, int(n), int(batch_size))): + right = np.int32(min(left + batch_size, n)) + if cnt > 0: + left = left - window_size+1 + x_batch = cp.asarray(x[left:right]) + x_batch = cp.lib.stride_tricks.sliding_window_view(x_batch, window_size) + x_batch = np.deg2rad(x_batch) + cos, sin = cp.cos(x_batch).astype(np.float32), cp.sin(x_batch).astype(np.float32) + r = cp.rad2deg(cp.arctan2(cp.mean(sin, axis=1), cp.mean(cos, axis=1))) + r = cp.where(r < 0, r + 360, r) + results[left + window_size - 1:right] = r + return results.get() + + + +def sliding_circular_range(x: np.ndarray, + time_window: float, + sample_rate: float, + batch_size: Optional[int] = int(5e+7)) -> np.ndarray: + """ + Computes the sliding circular range of a time series data array using GPU. + + This function calculates the circular range of a time series data array using a sliding window approach. + The input data is assumed to be in degrees, and the function handles the circular nature of the data + by considering the circular distance between angles. + + .. math:: + + R = \\min \\left( \\text{max}(\\Delta \\theta) - \\text{min}(\\Delta \\theta), \\, 360 - \\text{max}(\\Delta \\theta) + \\text{min}(\\Delta \\theta) \\right) + + where: + + - :math:`\\Delta \\theta` is the difference between angles within the window, + - :math:`360` accounts for the circular nature of the data (i.e., wrap-around at 360 degrees). + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_circular_range.csv + :widths: 10, 45, 45 + :align: center + :header-rows: 1 + + .. seealso:: + For CPU function see :func:`~simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_circular_range`. + + :param np.ndarray x: The input time series data in degrees. Should be a 1D numpy array. + :param float time_window: The size of the sliding window in seconds. + :param float sample_rate: The sample rate of the time series data (i.e., hz, fps). + :param Optional[int] batch_size: The size of each batch for processing the data. Default is 5e+7 (50m). + :return: A numpy array containing the sliding circular range values. + :rtype: np.ndarray + + :example: + >>> x = np.random.randint(0, 361, (19, )).astype(np.int32) + >>> p = sliding_circular_range(x, 1, 10) + """ + + n = x.shape[0] + x = cp.asarray(x, dtype=cp.float16) + results = cp.zeros_like(x, dtype=cp.int16) + x = cp.deg2rad(x).astype(cp.float16) + window_size = int(cp.ceil(time_window * sample_rate)) + for cnt, left in enumerate(range(0, n, batch_size)): + right = int(min(left + batch_size, n)) + if cnt > 0: + left = left - window_size + 1 + x_batch = x[left:right] + x_batch = cp.lib.stride_tricks.sliding_window_view(x_batch, window_size).astype(cp.float16) + x_batch = cp.sort(x_batch) + results[left + window_size - 1:right] = cp.abs(cp.rint(cp.rad2deg(cp.amin(cp.vstack([x_batch[:, -1] - x_batch[:, 0], 2 * cp.pi - cp.max(cp.diff(x_batch), axis=1)]).T, axis=1)))) + return results.get() + + + + +def sliding_circular_std(x: np.ndarray, + time_window: float, + sample_rate: float, + batch_size: Optional[int] = int(5e+7)) -> np.ndarray: + + """ + Calculate the sliding circular standard deviation of a time series data on GPU. + + This function computes the circular standard deviation over a sliding window for a given time series array. + The time series data is assumed to be in degrees, and the function converts it to radians for computation. + The sliding window approach is used to handle large datasets efficiently, processing the data in batches. + + The circular standard deviation (σ) is computed using the formula: + + .. math:: + + \sigma = \sqrt{-2 \cdot \log \left|\text{mean}\left(\exp(i \cdot x_{\text{batch}})\right)\right|} + + where :math:`x_{\text{batch}}` is the data within the current sliding window, and :math:`\text{mean}` and + :math:`\log` are computed in the circular (complex plane) domain. + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_circular_std.csv + :widths: 10, 45, 45 + :align: center + :header-rows: 1 + + .. seealso:: + For CPU function see :func:`~simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_circular_std`. + + :param np.ndarray x: The input time series data in degrees. Should be a 1D numpy array. + :param float time_window: The size of the sliding window in seconds. + :param float sample_rate: The sample rate of the time series data (i.e., hz, fps). + :param Optional[int] batch_size: The size of each batch for processing the data. Default is 5e+7 (50m). + + :return: A numpy array containing the sliding circular standard deviation values. + :rtype: np.ndarray + """ + + + n = x.shape[0] + x = cp.asarray(x, dtype=cp.float16) + results = cp.zeros_like(x, dtype=cp.float16) + x = np.deg2rad(x).astype(cp.float16) + window_size = int(np.ceil(time_window * sample_rate)) + for cnt, left in enumerate(range(0, n, batch_size)): + right = int(min(left + batch_size, n)) + if cnt > 0: + left = left - window_size + 1 + x_batch = x[left:right] + x_batch = cp.lib.stride_tricks.sliding_window_view(x_batch, window_size).astype(cp.float16) + m = cp.log(cp.abs(cp.mean(cp.exp(1j * x_batch), axis=1))) + stdev = cp.rad2deg(cp.sqrt(-2 * m)) + results[left + window_size - 1:right] = stdev + + return results.get() + + +def sliding_rayleigh_z(x: np.ndarray, + time_window: float, + sample_rate: float, + batch_size: Optional[int] = int(5e+7)) -> Tuple[np.ndarray, np.ndarray]: + + """ + Computes the Rayleigh Z-statistic over a sliding window for a given time series of angles + + This function calculates the Rayleigh Z-statistic, which tests the null hypothesis that the population of angles + is uniformly distributed around the circle. The calculation is performed over a sliding window across the input + time series, and results are computed in batches for memory efficiency. + + Data is processed using GPU acceleration via CuPy, which allows for faster computation compared to a CPU-based approach. + + .. note:: + Adapted from ``pingouin.circular.circ_rayleigh`` and ``pycircstat.tests.rayleigh``. + + + **Rayleigh Z-statistic:** + + The Rayleigh Z-statistic is given by: + + .. math:: + + R = \frac{1}{n} \sqrt{\left(\sum_{i=1}^{n} \cos(\theta_i)\right)^2 + \left(\sum_{i=1}^{n} \sin(\theta_i)\right)^2} + + where: + - :math:`\theta_i` are the angles in the window. + - :math:`n` is the number of angles in the window. + + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_rayleigh_z.csv + :widths: 10, 45, 45 + :align: center + :header-rows: 1 + + .. seealso:: + For CPU function see :func:`~simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_rayleigh_z`. + + + :param np.ndarray x: Input array of angles in degrees. Should be a 1D numpy array. + :param float time_window: The size of the sliding window in time units (e.g., seconds). + :param float sample_rate: The sampling rate of the input time series in samples per time unit (e.g., Hz, fps). + :param Optional[int] batch_size: The number of samples to process in each batch. Default is 5e7 (50m). Reducing this value may save memory at the cost of longer computation time. + :return: + A tuple containing two numpy arrays: + - **z_results**: Rayleigh Z-statistics for each position in the input array where the window was fully applied. + - **p_results**: Corresponding p-values for the Rayleigh Z-statistics. + :rtype: Tuple[np.ndarray, np.ndarray] + """ + + n = x.shape[0] + x = cp.asarray(x, dtype=cp.float16) + z_results = cp.zeros_like(x, dtype=cp.float16) + p_results = cp.zeros_like(x, dtype=cp.float16) + x = np.deg2rad(x).astype(cp.float16) + window_size = int(np.ceil(time_window * sample_rate)) + for cnt, left in enumerate(range(0, n, batch_size)): + right = int(min(left + batch_size, n)) + if cnt > 0: + left = left - window_size + 1 + x_batch = x[left:right] + x_batch = cp.lib.stride_tricks.sliding_window_view(x_batch, window_size).astype(cp.float16) + cos_sums = cp.nansum(cp.cos(x_batch), axis=1) ** 2 + sin_sums = cp.nansum(cp.sin(x_batch), axis=1) ** 2 + R = cp.sqrt(cos_sums + sin_sums) / window_size + Z = window_size * (R**2) + P = cp.exp(np.sqrt(1 + 4 * window_size + 4 * (window_size ** 2 - R ** 2)) - (1 + 2 * window_size)) + z_results[left + window_size - 1:right] = Z + p_results[left + window_size - 1:right] = P + + return z_results.get(), p_results.get() + + +def sliding_resultant_vector_length(x: np.ndarray, + time_window: float, + sample_rate: int, + batch_size: Optional[int] = 3e+7) -> np.ndarray: + + """ + Calculate the sliding resultant vector length over a time window for a series of angles. + + This function computes the resultant vector length (R) for each window of angles in the input array `x`. + The resultant vector length is a measure of the concentration of angles, and it ranges from 0 to 1, where 1 + indicates all angles point in the same direction, and 0 indicates uniform distribution of angles. + + For a given sliding window of angles, the resultant vector length :math:`R` is calculated using the following formula: + + .. math:: + + R = \\frac{1}{N} \\sqrt{\\left(\\sum_{i=1}^{N} \\cos(\\theta_i)\\right)^2 + \\left(\\sum_{i=1}^{N} \\sin(\\theta_i)\\right)^2} + + where: + + - :math:`\\theta_i` are the angles in radians within the sliding window + - :math:`N` is the number of samples in the window + + The computation is performed in a sliding window manner over the entire array, utilizing GPU acceleration + with CuPy for efficiency, especially on large datasets. + + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_resultant_vector_length.csv + :widths: 10, 10, 80 + :align: center + :header-rows: 1 + + .. seealso:: + For CPU function see :func:`~simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_resultant_vector_length`. + + :param np.ndarray x: Input array containing angle values in degrees. The array should be 1-dimensional. + :param float time_window: Time duration for the sliding window, in seconds. This determines the number of samples in each window based on the `sample_rate`. + :param int sample_rate: The number of samples per second (i.e., FPS). This is used to calculate the window size in terms of array indices. + :param Optional[int] batch_size: The maximum number of elements to process in each batch. This is used to handle large arrays by processing them in chunks to avoid memory overflow. Defaults to 3e+7 (30 million elements). + :return np.ndarray: A 1D numpy array of the same length as `x`, containing the resultant vector length for each sliding window. Values before the window is fully populated will be set to -1. + + :example: + >>> x = np.random.randint(0, 361, (5000, )).astype(np.int32) + >>> results = sliding_resultant_vector_length(x, 1, 10) + """ + + window_size = np.ceil(time_window * sample_rate).astype(np.int64) + n = x.shape[0] + results = cp.full(x.shape[0], -1, dtype=np.float32) + for cnt, left in enumerate(range(0, int(n), int(batch_size))): + right = np.int32(min(left + batch_size, n)) + if cnt > 0: + left = left - window_size+1 + x_batch = cp.asarray(x[left:right]) + x_batch = cp.lib.stride_tricks.sliding_window_view(x_batch, window_size) + x_batch = np.deg2rad(x_batch) + cos, sin = cp.cos(x_batch).astype(np.float32), cp.sin(x_batch).astype(np.float32) + cos_sum, sin_sum = cp.sum(cos, axis=1), cp.sum(sin, axis=1) + r = np.sqrt(cos_sum ** 2 + sin_sum ** 2) / window_size + results[left+window_size-1:right] = r + return results.get() + + +def direction_from_three_bps(x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + batch_size: Optional[int] = int(1.5e+7)) -> np.ndarray: + + """ + Calculate the direction angle based on the coordinates of three body points using GPU acceleration. + + This function computes the mean direction angle (in degrees) for a batch of coordinates + provided in the form of NumPy arrays. The calculation is based on the arctangent of the + difference in x and y coordinates between pairs of points. The result is a value in + the range [0, 360) degrees. + + .. seealso:: + :func:`simba.mixins.circular_statistics.CircularStatisticsMixin.direction_three_bps` + + :param np.ndarray x: A 2D array of shape (N, 2) containing the x-coordinates of the first body part (nose) + :param np.ndarray y: A 2D array of shape (N, 2) containing the coordinates of the second body part (left ear). + :param np.ndarray z: A 2D array of shape (N, 2) containing the coordinates of the second body part (right ear). + :param Optional[int] batch_size: The size of the batch to be processed in each iteration. Default is 15 million. + :return: An array of shape (N,) containing the computed direction angles in degrees. + :rtype np.ndarray: + """ + + check_valid_array(data=x, source=direction_from_three_bps.__name__, accepted_ndims=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + check_valid_array(data=y, source=direction_from_three_bps.__name__, accepted_shapes=(x.shape,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + check_valid_array(data=z, source=direction_from_three_bps.__name__, accepted_shapes=(x.shape,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + check_int(value=batch_size, name=direction_from_three_bps.__name__, min_value=1) + results = cp.full((x.shape[0]), fill_value=-1, dtype=np.int16) + + for l in range(0, x.shape[0], batch_size): + r = l + batch_size + x_batch = cp.array(x[l:r]) + y_batch = cp.array(y[l:r]) + z_batch = cp.array(z[l:r]) + left_ear_to_nose = cp.arctan2(x_batch[:, 0] - y_batch[:, 0], y_batch[:, 1] - x_batch[:,1]) + right_ear_nose = cp.arctan2(x_batch[:, 0] - z_batch[:, 0], z_batch[:, 1] - x_batch[:, 1]) + mean_angle_rad = cp.arctan2(cp.sin(left_ear_to_nose) + cp.sin(right_ear_nose), cp.cos(left_ear_to_nose) + cp.cos(right_ear_nose)) + results[l:r] = (cp.degrees(mean_angle_rad) + 360) % 360 + + return results.get() + + +@cuda.jit() +def _instantaneous_angular_velocity(x, stride, results): + r = cuda.grid(1) + l = np.int32(r - (stride[0])) + if (r > results.shape[0]) or (l < 0): + results[r] = -1 + else: + d = math.pi - (abs(math.pi - abs(x[l] - x[r]))) + results[r] = d * (180 / math.pi) + + +def instantaneous_angular_velocity(x: np.ndarray, stride: Optional[int] = 1) -> np.ndarray: + """ + Calculate the instantaneous angular velocity between angles in a given array. + + This function uses CUDA to perform parallel computations on the GPU. + + The angular velocity is computed using the difference in angles between + the current and previous values (with a specified stride) in the array. + The result is returned in degrees per unit time. + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/instantaneous_angular_velocity.csv + :widths: 10, 90 + :align: center + :header-rows: 1 + + .. math:: + \omega = \frac{{\Delta \theta}}{{\Delta t}} = \frac{{180}}{{\pi}} \times \left( \pi - \left| \pi - \left| \theta_r - \theta_l \right| \right| \right) + + where: + - \( \theta_r \) is the current angle. + - \( \theta_l \) is the angle at the specified stride before the current angle. + - \( \Delta t \) is the time difference between the two angles. + + + .. seealso:: + :func:`simba.mixins.circular_statistics.CircularStatisticsMixin.instantaneous_angular_velocity` + + :param np.ndarray x: Array of angles in degrees, for which the instantaneous angular velocity will be calculated. + :param Optional[int] stride: The stride or lag (in frames) to use when calculating the difference in angles. Defaults to 1. + :return: Array of instantaneous angular velocities corresponding to the input angles. Velocities are in degrees per unit time. + :rtype: np.ndarray + """ + + x = np.deg2rad(x).astype(np.int16) + stride = np.array([stride]).astype(np.int64) + bpg = (x.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK + x_dev = cuda.to_device(x) + stride_dev = cuda.to_device(stride) + results = cuda.device_array(x.shape[0], dtype=np.float32) + _instantaneous_angular_velocity[bpg, THREADS_PER_BLOCK](x_dev, stride_dev, results) + return results.copy_to_host() + + +@cuda.jit(device=True) +def _rad2deg(x): + return x * (180/math.pi) + + +@cuda.jit() +def _sliding_bearing(x, stride, results): + r = cuda.grid(1) + l = np.int32(r - (stride[0])) + if (r > results.shape[0]-1) or (l < 0): + results[r] = -1 + else: + x1, y1 = x[l, 0], x[l, 1] + x2, y2 = x[r, 0], x[r, 1] + bearing = _rad2deg(math.atan2(x2 - x1, y2 - y1)) + results[r] = (bearing + 360) % 360 + + +def sliding_bearing(x: np.ndarray, + stride: Optional[float] = 1, + sample_rate: Optional[float] = 1) -> np.ndarray: + """ + Compute the bearing between consecutive points in a 2D coordinate array using a sliding window approach using GPU acceleration. + + This function calculates the angle (bearing) in degrees between each point and a point a certain number of + steps ahead (defined by `stride`) in the 2D coordinate array `x`. The bearing is calculated using the + arctangent of the difference in coordinates, converted from radians to degrees. + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_bearing.csv + :widths: 10, 90 + :align: center + :header-rows: 1 + + .. seealso:: + :func:`simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_bearing` + + :param np.ndarray x: A 2D array of shape `(n, 2)` where each row represents a point with `x` and `y` coordinates. The array must be numeric. + :param Optional[float] stride: The time (multiplied by `sample_rate`) to look ahead when computing the bearing in seconds. Defaults to 1. + :param Optional[float] sample_rate: A multiplier applied to the `stride` value to determine the actual step size for calculating the bearing. E.g., frames per second. Defaults to 1. If the resulting stride is less than 1, it is automatically set to 1. + :return:A 1D array of shape `(n,)` containing the calculated bearings in degrees. Values outside the valid range (i.e., where the stride exceeds array bounds) are set to -1. + :rtype: np.ndarray + """ + + check_valid_array(data=x, source=f'{sliding_bearing.__name__} x', accepted_ndims=(2,), accepted_axis_1_shape=(2,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) + check_float(name=f'{sliding_bearing.__name__} stride', value=stride, min_value=10e-6, max_value=x.shape[0]-1) + check_float(name=f'{sliding_bearing.__name__} sample_rate', value=sample_rate, min_value=10e-6, max_value=x.shape[0]-1) + stride = int(stride * sample_rate) + if stride < 1: + stride = 1 + stride = np.array([stride]).astype(np.int64) + bpg = (x.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK + x_dev = cuda.to_device(x) + stride_dev = cuda.to_device(stride) + results = cuda.device_array(x.shape[0], dtype=np.float32) + _sliding_bearing[bpg, THREADS_PER_BLOCK](x_dev, stride_dev, results) + return results.copy_to_host() + + +@cuda.jit(device=True) +def _rad2deg(x): + return x * (180 / math.pi) + + +@cuda.jit() +def _sliding_angular_diff(data, strides, results): + x, y = cuda.grid(2) + if (x > data.shape[0] - 1) or (y > strides.shape[0] - 1): + return + else: + stride = int(strides[y]) + if x - stride < 0: + return + a_2 = data[x] + a_1 = data[x - stride] + distance = math.pi - abs(math.pi - abs(a_1 - a_2)) + distance = abs(int(_rad2deg(distance)) + 1) + results[x][y] = distance + + +def sliding_angular_diff(x: np.ndarray, + time_windows: np.ndarray, + fps: float) -> np.ndarray: + """ + Calculate the sliding angular differences for a given time window using GPU acceleration. + + + This function computes the angular differences between each angle in `x` + and the corresponding angle located at a distance determined by the time window + and frame rate (fps). The results are returned as a 2D array where each row corresponds + to a position in `x`, and each column corresponds to a different time window. + + .. csv-table:: + :header: EXPECTED RUNTIMES + :file: ../../../docs/tables/sliding_angular_diff.csv + :widths: 10, 90 + :align: center + :header-rows: 1 + + + .. seealso:: + :func:`simba.mixins.circular_statistics.CircularStatisticsMixin.sliding_angular_diff` + + .. math:: + \text{difference} = \pi - |\pi - |a_1 - a_2|| + + Where: + - \( a_1 \) is the angle at position `x`. + - \( a_2 \) is the angle at position `x - \text{stride}`. + + :param np.ndarray x: 1D array of angles in degrees. + :param np.ndarray time_windows: 1D array of time windows in seconds to determine the stride (distance in frames) between angles. + :param float fps: Frame rate (frames per second) used to convert time windows to strides. + :return: 2D array of angular differences. Each row corresponds to an angle in `x`, and each column corresponds to a time window. + :rtype: np.ndarray + """ + + x = np.deg2rad(x) + strides = np.zeros(time_windows.shape[0]) + for i in range(time_windows.shape[0]): + strides[i] = np.ceil(time_windows[i] * fps).astype(np.int32) + x_dev = cuda.to_device(x) + stride_dev = cuda.to_device(strides) + results = cuda.device_array((x.shape[0], time_windows.shape[0])) + grid_x = (x.shape[0] + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK + grid_y = (strides.shape[0] + THREADS_PER_BLOCK - 1) + blocks_per_grid = (grid_x, grid_y) + _sliding_angular_diff[blocks_per_grid, THREADS_PER_BLOCK](x_dev, stride_dev, results) + results = results.copy_to_host().astype(np.int32) + return results + + diff --git a/simba/data_processors/cuda/geometry.py b/simba/data_processors/cuda/geometry.py index 38a8608fa..5ab83dffa 100644 --- a/simba/data_processors/cuda/geometry.py +++ b/simba/data_processors/cuda/geometry.py @@ -46,7 +46,8 @@ def is_inside_rectangle(x: np.ndarray, y: np.ndarray) -> np.ndarray: :param np.ndarray x: 2d numeric np.ndarray size (N, 2). :param np.ndarray y: 2d numeric np.ndarray size (2, 2) (top left[x, y], bottom right[x, y]) - :return np.ndarray: 2d numeric boolean (N, 1) with 1s representing the point being inside the rectangle and 0 if the point is outside the rectangle. + :return: 2d numeric boolean (N, 1) with 1s representing the point being inside the rectangle and 0 if the point is outside the rectangle. + :rtype: np.ndarray """ x = np.ascontiguousarray(x).astype(np.int32) @@ -298,6 +299,7 @@ def poly_area(data: np.ndarray, :param pixels_per_mm: Optional scaling factor to convert the area from pixels squared to square millimeters. Default is 1.0. :param batch_size: Optional batch size for processing the data in chunks to fit in memory. Default is 0.5e+7. :return: A 1D numpy array of shape (N,) containing the computed area of each polygon in square millimeters. + :rtype: np.ndarray """ check_valid_array(data=data, source=f'{poly_area} data', accepted_ndims=(3,), accepted_dtypes=Formats.NUMERIC_DTYPES.value) diff --git a/simba/data_processors/cuda/image.py b/simba/data_processors/cuda/image.py index df2a8d3c1..59b62ee41 100644 --- a/simba/data_processors/cuda/image.py +++ b/simba/data_processors/cuda/image.py @@ -763,7 +763,6 @@ def slice_imgs(video_path: Union[str, os.PathLike], """ Slice frames from a video based on given shape coordinates (rectangles or circles) and return the cropped regions using GPU acceleration. - .. video:: _static/img/slice_imgs_gpu.webm :width: 800 :autoplay: diff --git a/simba/data_processors/cuda/statistics.py b/simba/data_processors/cuda/statistics.py index d1e586740..bacc7ef1f 100644 --- a/simba/data_processors/cuda/statistics.py +++ b/simba/data_processors/cuda/statistics.py @@ -49,13 +49,14 @@ def get_3pt_angle(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> np.ndarray: :header-rows: 1 .. seealso:: - For CPU function see :func:`~simba.mixins.FeatureExtractionMixin.angle3pt` and - For CPU function see :func:`~simba.mixins.FeatureExtractionMixin.angle3pt_serialized`. + For CPU function see :func:`~simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.angle3pt` and + For CPU function see :func:`~simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.angle3pt_serialized`. :param x: A numpy array of shape (n, 2) representing the first point (e.g., nose) coordinates. :param y: A numpy array of shape (n, 2) representing the second point (e.g., center) coordinates, where the angle is computed. :param z: A numpy array of shape (n, 2) representing the second point (e.g., center) coordinates, where the angle is computed. :return: A numpy array of shape (n, 1) containing the calculated angles (in degrees) for each row. + :rtype: np.ndarray :example: >>> video_path = r"/mnt/c/troubleshooting/mitra/project_folder/videos/501_MA142_Gi_CNO_0514.mp4" @@ -110,11 +111,12 @@ def count_values_in_ranges(x: np.ndarray, r: np.ndarray) -> np.ndarray: :header-rows: 1 .. seealso:: - For CPU function see :func:`~simba.mixins.FeatureExtractionMixin.count_values_in_range`. + For CPU function see :func:`~simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.count_values_in_range`. :param np.ndarray x: 2d array with feature values. :param np.ndarray r: 2d array with lower and upper boundaries. - :return np.ndarray: 2d array of size len(x) x len(r) with the counts of values in each feature range (inclusive). + :return: 2d array of size len(x) x len(r) with the counts of values in each feature range (inclusive). + :rtype: np.ndarray :example: >>> x = np.random.randint(1, 11, (10, 10)).astype(np.int8) @@ -153,7 +155,7 @@ def get_euclidean_distance_cuda(x: np.ndarray, y: np.ndarray) -> np.ndarray: :header-rows: 1 .. seealso:: - For CPU function see :func:`~simba.mixins.FeatureExtractionMixin.framewise_euclidean_distance`. + For CPU function see :func:`~simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.framewise_euclidean_distance`. For CuPY function see :func:`~simba.data_processors.cuda.statistics.get_euclidean_distance_cupy`. @@ -193,7 +195,7 @@ def get_euclidean_distance_cupy(x: np.ndarray, .. seealso:: - For CPU function see :func:`~simba.mixins.FeatureExtractionMixin.framewise_euclidean_distance`. + For CPU function see :func:`~simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.framewise_euclidean_distance`. For CUDA JIT function see :func:`~simba.data_processors.cuda.statistics.get_euclidean_distance_cuda`. :param np.ndarray x: A 2D NumPy array with shape (n, 2), where each row represents a point in a 2D space. diff --git a/simba/mixins/config_reader.py b/simba/mixins/config_reader.py index f495c2741..e54c14b4a 100644 --- a/simba/mixins/config_reader.py +++ b/simba/mixins/config_reader.py @@ -710,16 +710,13 @@ def read_config_entry( source=self.__class__.__name__, ) - def read_video_info_csv(self, file_path: str) -> pd.DataFrame: + def read_video_info_csv(self, file_path: Union[str, os.PathLike]) -> pd.DataFrame: """ Helper to read the project_folder/logs/video_info.csv of the SimBA project in as a pd.DataFrame - Parameters - ---------- - file_path: str - Returns - ------- - pd.DataFrame + :param Union[str, os.PathLike] file_path: Path to the project_folder/logs/video_info.csv file. + :return: Dataframe representation of the file. + :rtype: pd.DataFrame """ if not os.path.isfile(file_path): @@ -764,8 +761,7 @@ def read_video_info_csv(self, file_path: str) -> pd.DataFrame: return info_df def read_video_info( - self, video_name: str, raise_error: Optional[bool] = True - ) -> (pd.DataFrame, float, float): + self, video_name: str, raise_error: Optional[bool] = True) -> Tuple[pd.DataFrame, float, float]: """ Helper to read the meta-data (pixels per mm, resolution, fps) from the video_info.csv for a single input file. @@ -773,7 +769,8 @@ def read_video_info( :param Optional[bool] raise_error: If True, raise error if video info for the video name cannot be found. Default: True. :raise ParametersFileError: If ``raise_error`` and video metadata info is not found :raise DuplicationError: If file contains multiple entries for the same video. - :return (pd.DataFrame, float, float) representing all video info, pixels per mm, and fps + :returns: Tuple representing all video info, pixels per mm, and fps + :rtype: Tuple[pd.DataFrame, float, float] """ video_settings = self.video_info_df.loc[ diff --git a/simba/mixins/feature_extraction_mixin.py b/simba/mixins/feature_extraction_mixin.py index f8dc94e41..00cfb76c3 100644 --- a/simba/mixins/feature_extraction_mixin.py +++ b/simba/mixins/feature_extraction_mixin.py @@ -82,7 +82,8 @@ def euclidean_distance( .. seealso:: Use :meth:`simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.framewise_euclidean_distance` - for imporved run-times. + for imporved run-times. Use :func:`simba.data_processors.cuda.statistics.get_euclidean_distance_cuda` + or :func:`simba.data_processors.cuda.statistics.get_euclidean_distance_cupy` for GPU acceleration. :param np.ndarray bp_1_x: 2D array of size len(frames) x 1 with bodypart 1 x-coordinates. :param np.ndarray bp_2_x: 2D array of size len(frames) x 1 with bodypart 2 x-coordinates. @@ -112,6 +113,11 @@ def angle3pt(ax: float, ay: float, bx: float, by: float, cx: float, cy: float) - :width: 300 :align: center + .. seealso:: + :func:`simba.mixins.feature_extraction_mixin.FeatureExtractionMixin.angle3pt_serialized`, + :func: + + :example: >>> FeatureExtractionMixin.angle3pt(ax=122.0, ay=198.0, bx=237.0, by=138.0, cx=191.0, cy=109) >>> 59.78156901181637 @@ -743,12 +749,10 @@ def minimum_bounding_rectangle(points: np.ndarray) -> np.ndarray: @staticmethod @jit(nopython=True) - def framewise_euclidean_distance( - location_1: np.ndarray, - location_2: np.ndarray, - px_per_mm: float, - centimeter: bool = False, - ) -> np.ndarray: + def framewise_euclidean_distance(location_1: np.ndarray, + location_2: np.ndarray, + px_per_mm: float, + centimeter: bool = False) -> np.ndarray: """ Jitted helper finding frame-wise distances between two moving locations in millimeter or centimeter. diff --git a/simba/roi_tools/ROI_define.py b/simba/roi_tools/ROI_define.py index e3ed87e86..f2b2d475e 100644 --- a/simba/roi_tools/ROI_define.py +++ b/simba/roi_tools/ROI_define.py @@ -1,7 +1,8 @@ import copy -import glob import os from tkinter import * +from PIL import ImageTk +import PIL.Image import cv2 import pandas as pd @@ -23,10 +24,11 @@ from simba.utils.printing import log_event, stdout_success from simba.utils.read_write import find_all_videos_in_directory, get_fn_ext from simba.utils.warnings import NoDataFoundWarning +from simba.utils.lookups import get_icons_paths +from simba.ui.pop_ups.roi_fixed_size_pop_up import DrawFixedROIPopUp WINDOW_SIZE = (800, 750) - class ROI_definitions(ConfigReader, PopUpMixin): """ Launch ROI user-interface for drawing user-defined shapes in a video. @@ -64,6 +66,11 @@ def __init__(self, config_path: str, video_path: str): self.other_video_file_names.append(os.path.basename(video)) self.video_info, self.curr_px_mm, self.curr_fps = self.read_video_info(video_name=self.file_name) + self.menu_icons = get_icons_paths() + + for k in self.menu_icons.keys(): + self.menu_icons[k]["img"] = ImageTk.PhotoImage(image=PIL.Image.open(os.path.join(os.path.dirname(__file__), self.menu_icons[k]["icon_path"]))) + self.roi_root = Toplevel() self.roi_root.minsize(WINDOW_SIZE[0], WINDOW_SIZE[1]) self.screen_width = self.roi_root.winfo_screenwidth() @@ -224,9 +231,7 @@ def apply_rois_from_other_video(self): for shape_type in ["rectangles", "circleDf", "polygons"]: c_df = pd.read_hdf(self.roi_coordinates_path, key=shape_type) if len(c_df) > 0: - c_df = c_df[c_df["Video"] == target_video].reset_index( - drop=True - ) + c_df = c_df[c_df["Video"] == target_video].reset_index(drop=True) c_df["Video"] = self.file_name c_df = c_df.to_dict("records") if shape_type == "rectangles": @@ -848,9 +853,9 @@ def window_menus(self): menu = Menu(self.roi_root) file_menu = Menu(menu) menu.add_cascade(label="File (ROI)", menu=file_menu) - file_menu.add_command( - label="Preferences...", command=lambda: PreferenceMenu(self.image_data) - ) + + file_menu.add_command(label="Preferences...", compound="left", image=self.menu_icons["settings"]["img"], command=lambda: PreferenceMenu(self.image_data)) + file_menu.add_command(label="Draw ROIs of pre-defined sizes...", compound="left", image=self.menu_icons["size_black"]["img"], command=lambda: DrawFixedROIPopUp(roi_image=self.image_data)) file_menu.add_separator() file_menu.add_command(label="Exit", command=self.Exit) self.roi_root.config(menu=menu) @@ -893,9 +898,7 @@ def __init__(self, image_data): line_type_dropdown = OptionMenu(pref_lbl_frame, self.line_type, *line_type_list) text_thickness_dropdown = OptionMenu(pref_lbl_frame, self.text_thickness, *text_thickness_list) text_size_dropdown = OptionMenu(pref_lbl_frame, self.text_size, *text_size_list) - click_sens_dropdown = OptionMenu( - pref_lbl_frame, self.click_sens, *click_sensitivity_list - ) + click_sens_dropdown = OptionMenu(pref_lbl_frame, self.click_sens, *click_sensitivity_list) duplicate_jump_size_lbl = Label(pref_lbl_frame, text="DUPLICATE SHAPE JUMP: ", font=Formats.FONT_REGULAR.value) duplicate_jump_size_list = list(range(1, 100, 5)) self.duplicate_jump_size = IntVar() diff --git a/simba/roi_tools/ROI_image.py b/simba/roi_tools/ROI_image.py index f3b58a07a..8bfe977eb 100644 --- a/simba/roi_tools/ROI_image.py +++ b/simba/roi_tools/ROI_image.py @@ -25,9 +25,7 @@ def __init__(self, config = read_config_file(config_path=config_path) self.roi_define = ROI_define_instance - self.project_path = config.get( - ConfigKey.GENERAL_SETTINGS.value, ConfigKey.PROJECT_PATH.value - ) + self.project_path = config.get(ConfigKey.GENERAL_SETTINGS.value, ConfigKey.PROJECT_PATH.value) _, self.curr_vid_name, ext = get_fn_ext(video_path) ( self.duplicate_jump_size, @@ -48,10 +46,7 @@ def __init__(self, self.colors = self.roi_define.named_shape_colors self.select_color = (128, 128, 128) _, self.orig_frame = self.cap.read() - self.frame_width, self.frame_height = ( - self.orig_frame.shape[0], - self.orig_frame.shape[1], - ) + self.frame_width, self.frame_height = (self.orig_frame.shape[0], self.orig_frame.shape[1]) self.frame_default_loc = ( int(self.roi_define.default_top_left_x - self.frame_width), 0, @@ -539,19 +534,13 @@ def check_if_click_is_tag(): def remove_ROI(self, roi_to_delete): if roi_to_delete.startswith("Rectangle"): rectangle_name = roi_to_delete.split("Rectangle: ")[1] - self.out_rectangles[:] = [ - d for d in self.out_rectangles if d.get("Name") != rectangle_name - ] + self.out_rectangles[:] = [d for d in self.out_rectangles if d.get("Name") != rectangle_name] if roi_to_delete.startswith("Circle"): circle_name = roi_to_delete.split("Circle: ")[1] - self.out_circles[:] = [ - d for d in self.out_circles if d.get("Name") != circle_name - ] + self.out_circles[:] = [d for d in self.out_circles if d.get("Name") != circle_name] if roi_to_delete.startswith("Polygon"): polygon_name = roi_to_delete.split("Polygon: ")[1] - self.out_polygon[:] = [ - d for d in self.out_polygon if d.get("Name") != polygon_name - ] + self.out_polygon[:] = [d for d in self.out_polygon if d.get("Name") != polygon_name] self.insert_all_ROIs_into_image() def insert_all_ROIs_into_image( diff --git a/simba/ui/pop_ups/roi_fixed_size_pop_up.py b/simba/ui/pop_ups/roi_fixed_size_pop_up.py new file mode 100644 index 000000000..b9d3eb889 --- /dev/null +++ b/simba/ui/pop_ups/roi_fixed_size_pop_up.py @@ -0,0 +1,323 @@ +from typing import Tuple, Dict, Optional +import numpy as np + +from simba.mixins.pop_up_mixin import PopUpMixin +from tkinter import LabelFrame, NW, Label +from simba.ui.tkinter_functions import Entry_Box, DropDownMenu, SimbaButton +from simba.utils.enums import Formats +from simba.utils.lookups import get_color_dict +from simba.utils.checks import check_str, check_int, check_valid_tuple +from simba.roi_tools.ROI_image import ROI_image_class +from simba.utils.errors import InvalidInputError +from simba.utils.printing import stdout_success + +THICKNESS_OPTIONS = list(range(1, 26, 1)) +EAR_TAG_SIZE_OPTIONS = list(range(1, 26, 1)) +THICKNESS_OPTIONS.insert(0, 'THICKNESS') +EAR_TAG_SIZE_OPTIONS.insert(0, 'EAR TAG SIZE') + + +def get_half_circle_vertices(center: Tuple[int, int], + radius: int, + direction: str, + n_points: Optional[int] = 50) -> Tuple[np.ndarray, Dict[str, Tuple[int, int]]]: + + check_valid_tuple(x=center, source=get_vertices_hexagon.__name__, accepted_lengths=(2,), valid_dtypes=Formats.NUMERIC_DTYPES.value) + check_int(name='radius', value=radius, min_value=1) + check_str(name='direction', options=['NORTH', 'SOUTH', 'WEST', 'EAST'], value=direction) + x_c, y_c = center + if direction == "WEST": + a = np.linspace(np.pi / 2, 3 * np.pi / 2, n_points) + elif direction == "EAST": + a = np.linspace(-np.pi / 2, np.pi / 2, n_points) + elif direction == "SOUTH": + a = np.linspace(0, np.pi, n_points) + else: + a = np.linspace(np.pi, 2 * np.pi, n_points) + x, y = x_c + radius * np.cos(a), y_c + radius * np.sin(a) + vertices = np.column_stack((x, y)).astype(np.int32) + vertices_dict = {"Center_tag": (center[0], center[1])} + for tag_id in range(vertices.shape[0]): + vertices_dict[f"Tag_{tag_id}"] = (vertices[tag_id][0], vertices[tag_id][1]) + return (np.array(vertices).astype("int32"), vertices_dict) + + +def get_vertices_hexagon(center: Tuple[int, int], + radius: int) -> Tuple[np.ndarray, Dict[str, Tuple[int, int]]]: + + + check_valid_tuple(x=center, source=get_vertices_hexagon.__name__, accepted_lengths=(2,), valid_dtypes=Formats.NUMERIC_DTYPES.value) + check_int(name='radius', value=radius, min_value=1) + vertices = [] + x_c, y_c = center + for i in range(6): + angle_rad = np.deg2rad(60 * i) + x_i = x_c + radius * np.cos(angle_rad) + y_i = y_c + radius * np.sin(angle_rad) + vertices.append((x_i, y_i)) + + vertices_dict = {"Center_tag": (center[0], center[1])} + for tag_id, tag in enumerate(vertices): + vertices_dict[f"Tag_{tag_id}"] = (int(tag[0]), int(tag[1])) + return (np.array(vertices).astype("int32"), vertices_dict) + +def get_ear_tags_for_rectangle(center: Tuple[int, int], width: int, height: int) -> Dict[str, int]: + """ + Knowing the center, width, and height of rectangle, return its vertices. + + :param Tuple[int, int] center: The center x and y coordinates of the rectangle + :param int width: The width of the rectangle in pixels. + :param Tuple[int, int] width: The width of the rectangle in pixels. + """ + + check_valid_tuple(x=center, source=get_ear_tags_for_rectangle.__name__, accepted_lengths=(2,), valid_dtypes=Formats.NUMERIC_DTYPES.value) + check_int(name='width', value=width, min_value=1) + check_int(name='height', value=height, min_value=1) + tags = {} + tags['top_left_x'] = int((center[1] - (width/2))) + tags['top_left_y'] = int(center[0] - (height/2)) + tags['bottom_right_x'] = int(center[1] + (width/2)) + tags['bottom_right_y'] = int(center[0] + (height/2)) + tags['top_right_tag'] = (int(center[1] + (width/2)), int(center[0] - (height/2))) + tags['bottom_left_tag'] = (int(center[1] - (width / 2)), int(center[0] + (height / 2))) + tags['top_tag'] = (int(center[1]), int(center[0] - (height / 2))) + tags['right_tag'] = (int(center[1] + (width / 2)), int(center[0])) + tags['left_tag'] = (int(center[1] - (width / 2)), int(center[0])) + tags['bottom_tag'] = (int(center[1]), int(center[0] + (height / 2))) + return tags + +class DrawFixedROIPopUp(PopUpMixin): + + """ + GUI for drawing specifying + """ + def __init__(self, + roi_image: ROI_image_class): + + PopUpMixin.__init__(self, title="DRAW ROI OF FIXED SIZE") + self.clrs_dict = get_color_dict() + self.clrs = list(self.clrs_dict.keys()) + self.shape_cnt = 0 + self.roi_image = roi_image + self.roi_define = roi_image.roi_define + self.jump_size = roi_image.roi_define.duplicate_jump_size + self.px_per_mm = roi_image.roi_define.curr_px_mm + self.w, self.h = self.roi_image.frame_height, self.roi_image.frame_width + self.img_center = (int(self.h/2), int(self.w/2)) + + self.settings_frm = LabelFrame(self.main_frm, text="SETTINGS", pady=10, font=Formats.FONT_HEADER.value, fg="black") + self.name_eb = Entry_Box(self.settings_frm, 'NAME', 10) + self.clr_drpdwn = DropDownMenu(self.settings_frm, 'COLOR:', self.clrs, 10) + self.clr_drpdwn.setChoices('Red') + self.thickness_drpdwn = DropDownMenu(self.settings_frm, 'THICKNESS:', THICKNESS_OPTIONS, 10) + self.thickness_drpdwn.setChoices(10) + self.eartag_size_drpdwn = DropDownMenu(self.settings_frm, 'EAR TAG SIZE', EAR_TAG_SIZE_OPTIONS, 10) + self.eartag_size_drpdwn.setChoices(5) + + self.settings_frm.grid(row=0, column=0, sticky=NW) + self.name_eb.grid(row=0, column=0, sticky=NW) + self.clr_drpdwn.grid(row=0, column=1, sticky=NW) + self.thickness_drpdwn.grid(row=0, column=2, sticky=NW) + self.eartag_size_drpdwn.grid(row=0, column=3, sticky=NW) + + self.rectangle_frm = LabelFrame(self.main_frm, text="ADD RECTANGLE", pady=10, font=Formats.FONT_HEADER.value, fg="black") + self.rectangle_width_eb = Entry_Box(self.rectangle_frm, '', 0, None, validation='numeric', entry_box_width='9') + self.rectangle_width_eb.entry_set('WIDTH (MM)') + self.rectangle_height_eb = Entry_Box(self.rectangle_frm, '', 0, None, validation='numeric', entry_box_width='9') + self.rectangle_height_eb.entry_set('HEIGHT (MM)') + add_rect_btn = SimbaButton(parent=self.rectangle_frm, txt='ADD RECTANGLE', img='square_black', cmd=self.add_rect, txt_clr='blue') + self.rectangle_frm.grid(row=1, column=0, sticky=NW) + self.rectangle_width_eb.grid(row=0, column=0, sticky=NW) + self.rectangle_height_eb.grid(row=0, column=1, sticky=NW) + add_rect_btn.grid(row=1, column=0, sticky=NW) + + self.circle_frm = LabelFrame(self.main_frm, text="ADD CIRCLE", pady=10, font=Formats.FONT_HEADER.value, fg="black") + self.circle_radius_eb = Entry_Box(self.circle_frm, '', 0, None, validation='numeric', entry_box_width='9') + self.circle_radius_eb.entry_set('RADIUS (MM)') + add_circle_btn = SimbaButton(parent=self.circle_frm, txt='ADD CIRCLE', img='circle_2', cmd=self.add_circle, txt_clr='blue') + self.circle_frm.grid(row=2, column=0, sticky=NW) + self.circle_radius_eb.grid(row=0, column=0, sticky=NW) + add_circle_btn.grid(row=1, column=0, sticky=NW) + + self.hexagon_frm = LabelFrame(self.main_frm, text="ADD HEXAGON", pady=10, font=Formats.FONT_HEADER.value, fg="black") + self.hexagon_radius_eb = Entry_Box(self.hexagon_frm, '', 0, None, validation='numeric', entry_box_width='9') + self.hexagon_radius_eb.entry_set('RADIUS (MM)') + add_hex_btn = SimbaButton(parent=self.hexagon_frm, txt='ADD HEXAGON', img='hexagon', cmd=self.add_hex, txt_clr='blue') + + self.hexagon_frm.grid(row=3, column=0, sticky=NW) + self.hexagon_radius_eb.grid(row=0, column=0, sticky=NW) + add_hex_btn.grid(row=1, column=0, sticky=NW) + + self.half_circle_frm = LabelFrame(self.main_frm, text="ADD HALF CIRCLE", pady=10, font=Formats.FONT_HEADER.value, fg="black") + self.half_circle_radius_eb = Entry_Box(self.half_circle_frm, '', 0, None, validation='numeric', entry_box_width='9') + self.half_circle_radius_eb.entry_set('RADIUS (MM)') + self.half_circle_direction_drpdwn = DropDownMenu(self.half_circle_frm, 'DIRECTION:', ['NORTH', 'SOUTH', 'WEST', 'EAST'], 10) + self.half_circle_direction_drpdwn.setChoices('NORTH') + add_half_circle_btn = SimbaButton(parent=self.half_circle_frm, txt='ADD HALF CIRCLE', img='half_circle', cmd=self.add_half_circle, txt_clr='blue') + + self.half_circle_frm.grid(row=4, column=0, sticky=NW) + self.half_circle_radius_eb.grid(row=0, column=0, sticky=NW) + self.half_circle_direction_drpdwn.grid(row=0, column=1, sticky=NW) + add_half_circle_btn.grid(row=1, column=0, sticky=NW) + + self.info_txt = Label(self.main_frm, text='', font=Formats.FONT_REGULAR.value) + self.info_txt.grid(row=5, column=0, sticky=NW) + self.main_frm.mainloop() + + + def _checks(self): + name = self.name_eb.entry_get.strip() + valid, error_msg = check_str(name='ROI NAME', value=name, invalid_options=['NAME'], allow_blank=False, raise_error=False) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + valid, error_msg = check_int(name='THICKNESS', value=self.thickness_drpdwn.getChoices(), min_value=1, raise_error=False) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + valid, error_msg = check_int(name='EAR TAG SIZE', value=self.eartag_size_drpdwn.getChoices(), min_value=1, raise_error=False) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + valid, error_msg = check_str(name='COLOR', value=self.clr_drpdwn.getChoices(), options=self.clrs, raise_error=False) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + names_of_existing_rois = [x['Name'] for x in self.roi_image.out_rectangles] + [x['Name'] for x in self.roi_image.out_circles] + [x['Name'] for x in self.roi_image.out_rectangles] + if name in names_of_existing_rois: + error_msg = f'An ROI named {name} already exist for video {self.roi_define.file_name}. PLease choose a different name' + self.info_txt['text'] = error_msg + raise InvalidInputError(error_msg, source=self.__class__.__name__) + + self.clr_name = self.clr_drpdwn.getChoices() + self.thickness = int(self.thickness_drpdwn.getChoices()) + self.ear_tag_size = int(self.eartag_size_drpdwn.getChoices()) + self.name = self.name_eb.entry_get.strip() + + def add_rect(self): + self._checks() + valid, error_msg = check_int(name='WIDTH', value=self.rectangle_width_eb.entry_get, min_value=1) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + valid, error_msg = check_int(name='HEIGHT', value=self.rectangle_height_eb.entry_get, min_value=1) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + mm_width, mm_height = int(self.rectangle_width_eb.entry_get), int(self.rectangle_height_eb.entry_get) + width, height = int(int(self.rectangle_width_eb.entry_get) * float(self.px_per_mm)), int(int(self.rectangle_height_eb.entry_get) * float(self.px_per_mm)) + shape_center = (int(self.img_center[0]) + (self.jump_size*self.shape_cnt), int(self.img_center[1] + (self.jump_size*self.shape_cnt))) + tags = get_ear_tags_for_rectangle(center=shape_center, width=width, height=height) + + results = {"Video": self.roi_define.file_name, + "Shape_type": 'Rectangle', + "Name": self.name, + "Color name": self.clr_name, + "Color BGR": self.clrs_dict[self.clr_name], + "Thickness": self.thickness, + "Center_X": shape_center[1], + "Center_Y": shape_center[0], + "topLeftX": tags['top_left_x'], + "topLeftY": tags['top_left_y'], + "Bottom_right_X": tags['bottom_right_x'], + "Bottom_right_Y": tags['bottom_right_y'], + 'width': width, + 'height': height, + "Tags": {"Center tag": (shape_center[1], shape_center[0]), + "Top left tag": (tags['top_left_x'], tags['top_left_y']), + "Bottom right tag": (tags['bottom_right_x'], tags['bottom_right_y']), + "Top right tag": tags['top_right_tag'], + "Bottom left tag": tags['bottom_left_tag'], + "Top tag": tags['top_tag'], + "Right tag": tags['right_tag'], + "Left tag": tags['left_tag'], + "Bottom tag": tags['bottom_tag']}, + "Ear_tag_size": self.ear_tag_size} + + self.roi_image.out_rectangles.append(results) + self.roi_define.get_all_ROI_names() + self.roi_define.update_delete_ROI_menu() + self.roi_image.insert_all_ROIs_into_image() + txt = f'New rectangle {self.name} (MM h: {mm_height}, w: {mm_width}; PIXELS h {height}, w: {width}) inserted using pixel per millimeter {self.px_per_mm} conversion factor.)' + self.info_txt['text'] = txt + stdout_success(msg=txt) + self.shape_cnt += 1 + + def add_circle(self): + self._checks() + valid, error_msg = check_int(name='RADIUS', value=self.circle_radius_eb.entry_get, min_value=1) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + mm_radius = int(self.circle_radius_eb.entry_get) + radius = int(int(self.circle_radius_eb.entry_get) * float(self.px_per_mm)) + shape_center = (int(self.img_center[0]) + (self.jump_size*self.shape_cnt), int(self.img_center[1] + (self.jump_size*self.shape_cnt))) + results = {'Video': self.roi_define.file_name, + 'Shape_type': "Circle", + 'Name': self.name, + 'Color name': self.clr_name, + "Color BGR": self.clrs_dict[self.clr_name], + "Thickness": self.thickness, + "centerX": shape_center[0], + "centerY": shape_center[1], + "radius": radius, + "Tags": { + "Center tag": (shape_center[0], shape_center[1]), + "Border tag": (shape_center[0], int(shape_center[1]-radius))}, + "Ear_tag_size": self.ear_tag_size, + } + + self.roi_image.out_circles.append(results) + self.roi_define.get_all_ROI_names() + self.roi_define.update_delete_ROI_menu() + self.roi_image.insert_all_ROIs_into_image() + txt = f'New circle {self.name} (MM radius: {mm_radius}, PIXELS radius: {radius}) inserted using pixel per millimeter {self.px_per_mm} conversion factor.)' + self.info_txt['text'] = txt + stdout_success(msg=txt) + self.shape_cnt += 1 + + def add_hex(self): + self._checks() + valid, error_msg = check_int(name='RADIUS', value=self.hexagon_radius_eb.entry_get, min_value=1) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + mm_radius = int(self.hexagon_radius_eb.entry_get) + radius = int(int(self.hexagon_radius_eb.entry_get) * float(self.px_per_mm)) + shape_center = (int(self.img_center[0]) + (self.jump_size*self.shape_cnt), int(self.img_center[1] + (self.jump_size*self.shape_cnt))) + vertices, vertices_dict = get_vertices_hexagon(center=shape_center, radius=radius) + results = {"Video": self.roi_define.file_name, + "Shape_type": "Polygon", + "Name": self.name, + "Color name": self.clr_name, + "Color BGR": self.clrs_dict[self.clr_name], + "Thickness": self.thickness, + "Center_X": shape_center[0], + "Center_Y": shape_center[1], + "vertices": vertices, + "Tags": vertices_dict, + "Ear_tag_size": self.ear_tag_size} + + self.roi_image.out_polygon.append(results) + self.roi_define.get_all_ROI_names() + self.roi_define.update_delete_ROI_menu() + self.roi_image.insert_all_ROIs_into_image() + txt = f'New HEXAGON {self.name} (MM radius: {mm_radius}, PIXELS radius: {radius}) inserted using pixel per millimeter {self.px_per_mm} conversion factor.)' + self.info_txt['text'] = txt + stdout_success(msg=txt) + self.shape_cnt += 1 + + def add_half_circle(self): + self._checks() + valid, error_msg = check_int(name='RADIUS', value=self.half_circle_radius_eb.entry_get, min_value=1) + if not valid: self.info_txt['text'] = error_msg; raise InvalidInputError(msg=error_msg, source=self.__class__.__name__) + mm_radius = int(self.half_circle_radius_eb.entry_get) + radius = int(int(self.half_circle_radius_eb.entry_get) * float(self.px_per_mm)) + shape_center = (int(self.img_center[0]) + (self.jump_size*self.shape_cnt), int(self.img_center[1] + (self.jump_size*self.shape_cnt))) + direction = self.half_circle_direction_drpdwn.getChoices() + vertices, vertices_dict = get_half_circle_vertices(center=shape_center, radius=radius, direction=direction) + + results = {"Video": self.roi_define.file_name, + "Shape_type": "Polygon", + "Name": self.name, + "Color name": self.clr_name, + "Color BGR": self.clrs_dict[self.clr_name], + "Thickness": self.thickness, + "Center_X": shape_center[0], + "Center_Y": shape_center[1], + "vertices": vertices, + "Tags": vertices_dict, + "Ear_tag_size": self.ear_tag_size} + + self.roi_image.out_polygon.append(results) + self.roi_define.get_all_ROI_names() + self.roi_define.update_delete_ROI_menu() + self.roi_image.insert_all_ROIs_into_image() + txt = f'New HEXAGON {self.name} (MM radius: {mm_radius}, PIXELS radius: {radius}) inserted using pixel per millimeter {self.px_per_mm} conversion factor.)' + self.info_txt['text'] = txt + stdout_success(msg=txt) + self.shape_cnt += 1 diff --git a/simba/utils/checks.py b/simba/utils/checks.py index 58ec91657..cab05995a 100644 --- a/simba/utils/checks.py +++ b/simba/utils/checks.py @@ -104,13 +104,13 @@ def check_int( return True, msg -def check_str( - name: str, - value: Any, - options: Optional[Tuple[Any]] = (), - allow_blank: bool = False, - raise_error: bool = True, -) -> (bool, str): +def check_str(name: str, + value: Any, + options: Optional[Tuple[Any]] = (), + allow_blank: bool = False, + invalid_options: Optional[List[str]] = None, + raise_error: bool = True) -> (bool, str): + """ Check if variable is a valid string. @@ -119,6 +119,7 @@ def check_str( :param Optional[Tuple[Any]] options: Tuple of allowed strings. If empty tuple, then any string allowed. Default: (). :param Optional[bool] allow_blank: If True, allow empty string. Default: False. :param Optional[bool] raise_error: If True, then raise error if invalid string. Default: True. + :param Optional[List[str]] invalid_options: If not None, then a list of strings that are invalid. :return bool: False if invalid. True if valid. :return str: If invalid, then error msg. Else empty str. @@ -138,7 +139,18 @@ def check_str( return False, msg if len(options) > 0: if value not in options: - msg = f"{name} is set to {str(value)} in SimBA, but this is not a valid option: {options}" + msg = f"{name} is set to {value} in SimBA, but this is not a valid option: {options}" + if raise_error: + raise StringError(msg=msg, source=check_str.__name__) + else: + return False, msg + else: + return True, msg + + if invalid_options is not None: + check_valid_lst(data=invalid_options, valid_dtypes=(str,), min_len=1) + if value in invalid_options: + msg = f"{name} is set to {value} in SimBA, but this is among invalid options: {invalid_options}" if raise_error: raise StringError(msg=msg, source=check_str.__name__) else: @@ -148,7 +160,6 @@ def check_str( else: return True, msg - def check_float( name: str, value: Any, diff --git a/tests/test_roi_tools.py b/tests/test_roi_tools.py index bf5b06820..3a66d1f0f 100644 --- a/tests/test_roi_tools.py +++ b/tests/test_roi_tools.py @@ -54,7 +54,7 @@ def test_circle_size_calc(circle_dict, px_mm, expected_area): results = circle_size_calc(circle_dict=circle_dict, px_mm=px_mm) assert results['area_cm'] == expected_area -@pytest.mark.parametrize("polygon_dict, px_mm, expected_area", [({'vertices': np.array([[0, 2], [200, 98], [100, 876], [10, 702]])}, 5, 45.29)]) +@pytest.mark.parametrize("polygon_dict, px_mm, expected_area", [({'vertices': np.array([[0, 2], [200, 98], [100, 876], [10, 702]])}, 5, 38.04)]) def test_polygon_size_calc(polygon_dict, px_mm, expected_area): results = polygon_size_calc(polygon_dict=polygon_dict, px_mm=px_mm) assert results['area_cm'] == expected_area