Skip to content

Commit

Permalink
cuda docs
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Aug 30, 2024
1 parent 950bc9d commit 77a9f3c
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
13 changes: 12 additions & 1 deletion simba/data_processors/cuda/is_inside_rectangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,20 @@ def is_inside_rectangle(x: np.ndarray, y: np.ndarray) -> np.ndarray:
:return np.ndarray: 2d numeric boolean (N, 1) with 1s representing the point being inside the rectangle and 0 if the point is outside the rectangle.
.. csv-table:: Function Performance Table with Markup
:file: ../_tables/is_inside_rectangle.csv
:file: ../../_tables/is_inside_rectangle.csv
:widths: 30, 70
:header-rows: 1
.. csv-table:: Function Performance Table with Markup
:file: ../../../_tables/is_inside_rectangle.csv
:widths: 30, 70
:header-rows: 1
.. csv-table:: Function Performance Table with Markup
:file: /_tables/is_inside_rectangle.csv
:widths: 30, 70
:header-rows: 1
"""


Expand Down
3 changes: 2 additions & 1 deletion simba/mixins/train_model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,11 +676,12 @@ def create_x_importance_log(self,
self.f_importance_save_path = os.path.join(save_dir, f"{clf_name}_{save_file_no}_feature_importance_log.csv")
else:
self.f_importance_save_path = os.path.join(save_dir, f"{clf_name}_feature_importance_log.csv")
if cuRF is not None and isinstance(rf_clf, cuRF):
if cuRF is not None and isinstance(rf_clf, cuRF) and hasattr(rf_clf, 'get_json'):
cuml_tree_nodes = loads(rf_clf.get_json())
importances = list(self.cuml_rf_x_importances(nodes=cuml_tree_nodes, n_features=len(x_names)))
std_importances = [np.nan] * len(importances)
else:
print('s')
importances_per_tree = np.array([tree.feature_importances_ for tree in rf_clf.estimators_])
importances = list(np.mean(importances_per_tree, axis=0))
std_importances = list(np.std(importances_per_tree, axis=0))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
FEATURE,FEATURE_IMPORTANCE_MEAN,FEATURE_IMPORTANCE_STDEV
1,0.5053268795706305,0.1650058251884446
0,0.4946731204293695,0.16500582518844456

0 comments on commit 77a9f3c

Please sign in to comment.