Skip to content

Commit

Permalink
Fix serious prediction bug. C-indices were always to high. OOB score …
Browse files Browse the repository at this point in the history
…war correct
  • Loading branch information
julianspaeth committed Oct 16, 2019
1 parent 548a645 commit ffbdf0b
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
33 changes: 18 additions & 15 deletions random_survival_forest/RandomSurvivalForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@ class RandomSurvivalForest:
trees = []
random_states = []

def __init__(self, n_estimators=2, min_leaf=3, unique_deaths=3, timeline=None, n_jobs=None, random_state=None):
def __init__(self, n_estimators=2, min_leaf=3, unique_deaths=3, n_jobs=None, random_state=None):
"""
A Random Survival Forest is a prediction model especially designed for survival analysis.
:param n_estimators: The numbers of trees in the forest.
:param timeline: The timeline used for the prediction.
:param min_leaf: The minimum number of samples required to be at a leaf node. A split point at any depth will
only be considered if it leaves at least min_leaf training samples in each of the left and right branches.
:param unique_deaths: The minimum number of unique deaths required to be at a leaf node.
:param random_state: The random state to create reproducible results.
:param n_jobs: The number of jobs to run in parallel for fit. None means 1.
"""
self.n_estimators = n_estimators
self.min_leaf = min_leaf
self.timeline = timeline
self.unique_deaths = unique_deaths
self.n_jobs = n_jobs
self.random_state = random_state
Expand All @@ -49,7 +48,6 @@ def fit(self, x, y):

trees = Parallel(n_jobs=self.n_jobs)(delayed(self.create_tree)(x, y, i) for i in range(self.n_estimators))


for i in range(len(trees)):
if trees[i].prediction_possible is True:
self.trees.append(trees[i])
Expand All @@ -74,9 +72,8 @@ def create_tree(self, x, y, i):
f_idxs = np.random.RandomState(seed=self.random_states[i]).permutation(x.shape[1])[:n_features]

tree = SurvivalTree(x=x.iloc[self.bootstrap_idxs[i], :], y=y.iloc[self.bootstrap_idxs[i], :],
f_idxs=f_idxs, n_features=n_features, timeline=self.timeline,
unique_deaths=self.unique_deaths, min_leaf=self.min_leaf,
random_state=self.random_states[i])
f_idxs=f_idxs, n_features=n_features, unique_deaths=self.unique_deaths,
min_leaf=self.min_leaf, random_state=self.random_states[i])

return tree

Expand Down Expand Up @@ -115,16 +112,22 @@ def compute_oob_score(self, x, y):
def predict(self, xs):
"""
Predict survival for xs.
:param xs:The input samples
:param xs: The input samples
:return: List of the predicted cumulative hazard functions.
"""
preds = []
for x in xs.values:
chfs = []
for q in range(len(self.trees)):
chfs.append(self.trees[q].predict(x))
preds.append(pd.concat(chfs).groupby(level=0).mean())
return preds
ensemble_chfs = []
for sample_idx in range(xs.shape[0]):
denominator = 0
numerator = 0
for b in range(len(self.trees)):
sample = xs.iloc[sample_idx].to_list()
chf = self.trees[b].predict(sample)
denominator = denominator + 1
numerator = numerator + 1 * chf

ensemble_chf = numerator / denominator
ensemble_chfs.append(ensemble_chf)
return ensemble_chfs

def draw_bootstrap_samples(self, data):
"""
Expand Down
2 changes: 1 addition & 1 deletion random_survival_forest/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def concordance_index(y_time, y_pred, y_event):
:param y_event: Actual Survival Events.
:return: c-index.
"""
oob_predicted_outcome = [round(x.sum(), 1) for x in y_pred]
oob_predicted_outcome = [x.sum() for x in y_pred]
possible_pairs = list(combinations(range(len(y_pred)), 2))
concordance = 0
permissible = 0
Expand Down
Empty file added test.py
Empty file.

0 comments on commit ffbdf0b

Please sign in to comment.