Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
julianspaeth committed Sep 4, 2019
1 parent 9ae90f6 commit 553a7ab
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 2 additions & 0 deletions random_survival_forest/RandomSurvivalForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def fit(self, x, y):
def create_tree(self, x, y, i):
"""
Grows a survival tree for the bootstrap samples.
:param y: label data frame y with survival time as the first column and event as second
:param x: feature data frame x
:param i: Indices
:return: SurvivalTree
"""
Expand Down
12 changes: 6 additions & 6 deletions random_survival_forest/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def logrank_statistics(x, y, feature, min_leaf):
feature2 = list(x_feature[x_feature > split_val].index)
if len(feature1) < min_leaf or len(feature2) < min_leaf:
continue
durations_A = y.iloc[feature1, 0]
event_observed_A = y.iloc[feature1, 1]
durations_B = y.iloc[feature2, 0]
event_observed_B = y.iloc[feature2, 1]
results = logrank_test(durations_A=durations_A, durations_B=durations_B,
event_observed_A=event_observed_A, event_observed_B=event_observed_B)
durations_a = y.iloc[feature1, 0]
event_observed_a = y.iloc[feature1, 1]
durations_b = y.iloc[feature2, 0]
event_observed_b = y.iloc[feature2, 1]
results = logrank_test(durations_A=durations_a, durations_B=durations_b,
event_observed_A=event_observed_a, event_observed_B=event_observed_b)
score = results.test_statistic

if score > score_opt:
Expand Down

0 comments on commit 553a7ab

Please sign in to comment.