Skip to content

Commit

Permalink
Adjust QRF to enable single-output predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhilwoodruff committed Sep 23, 2024
1 parent 7024666 commit 54449a2
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion policyengine_us_data/utils/qrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ def predict(self, X, count_samples=10, mean_quantile=0.5):
random_generator.beta(a, 1, size=len(X)) * count_samples
)
input_quantiles = input_quantiles.astype(int)
predictions = pred[np.arange(len(X)), :, input_quantiles]
if len(pred.shape) == 2:
predictions = pred[:, input_quantiles]
else:
predictions = pred[:, :, input_quantiles]
return pd.DataFrame(predictions, columns=self.output_columns)

def save(self, path):
Expand Down

0 comments on commit 54449a2

Please sign in to comment.