diff --git a/random_survival_forest/RandomSurvivalForest.py b/random_survival_forest/RandomSurvivalForest.py index b019682..ad2c7ed 100644 --- a/random_survival_forest/RandomSurvivalForest.py +++ b/random_survival_forest/RandomSurvivalForest.py @@ -59,6 +59,8 @@ def fit(self, x, y): def create_tree(self, x, y, i): """ Grows a survival tree for the bootstrap samples. + :param y: label data frame y with survival time as the first column and event as second + :param x: feature data frame x :param i: Indices :return: SurvivalTree """ diff --git a/random_survival_forest/splitting.py b/random_survival_forest/splitting.py index b8dfe87..6d4d3ea 100644 --- a/random_survival_forest/splitting.py +++ b/random_survival_forest/splitting.py @@ -58,12 +58,12 @@ def logrank_statistics(x, y, feature, min_leaf): feature2 = list(x_feature[x_feature > split_val].index) if len(feature1) < min_leaf or len(feature2) < min_leaf: continue - durations_A = y.iloc[feature1, 0] - event_observed_A = y.iloc[feature1, 1] - durations_B = y.iloc[feature2, 0] - event_observed_B = y.iloc[feature2, 1] - results = logrank_test(durations_A=durations_A, durations_B=durations_B, - event_observed_A=event_observed_A, event_observed_B=event_observed_B) + durations_a = y.iloc[feature1, 0] + event_observed_a = y.iloc[feature1, 1] + durations_b = y.iloc[feature2, 0] + event_observed_b = y.iloc[feature2, 1] + results = logrank_test(durations_A=durations_a, durations_B=durations_b, + event_observed_A=event_observed_a, event_observed_B=event_observed_b) score = results.test_statistic if score > score_opt: