diff --git a/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py index 9ab1fb9112..c31f90af0b 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/thresholding_rtl.py @@ -205,9 +205,9 @@ def prepare_codegen_rtl_values(self, model): num_channels = self.get_nodeattr("NumChannels") # number of channels # If a single threshold value is found, broadcast the value - expected_shape = (num_channels, expected_thresholds) - if t_packed.shape != expected_shape: - t_packed = np.broadcast_to(t_packed, expected_shape) + if t_packed.shape[0] == 1: + t_packed = np.broadcast_to(t_packed, (pe, expected_thresholds)) + num_channels = pe channel_fold = int(num_channels / pe) @@ -531,11 +531,11 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name): min_val = wdt.min() thresholds = np.insert(thresholds, 0, min_val, axis=1) n_thres_steps += 1 - expected_shape = (ch, expected_thresholds) # If a single threshold value is found, broadcast the value - if thresholds.shape != expected_shape: - thresholds = np.broadcast_to(thresholds, expected_shape) + if thresholds.shape[0] == 1: + thresholds = np.broadcast_to(thresholds, (pe, expected_thresholds)) + ch = pe width_padded = roundup_to_integer_multiple(thresholds.shape[1], 2**o_bitwidth) thresh_padded = np.zeros((thresholds.shape[0], width_padded))