Skip to content

Commit

Permalink
Fix bug where the models were not initialized properly
Browse files Browse the repository at this point in the history
  • Loading branch information
julianspaeth committed Oct 25, 2019
1 parent 4c2616c commit ae8cfc1
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 18 deletions.
17 changes: 8 additions & 9 deletions random_survival_forest/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,6 @@

class Node:

score = 0
split_val = None
split_var = None
lhs = None
rhs = None
chf = None
chf_terminal = None
terminal = False

def __init__(self, x, y, tree, f_idxs, n_features, timeline, unique_deaths=3, min_leaf=3, random_state=None):
"""
A Node of the Survival Tree.
Expand All @@ -36,6 +27,14 @@ def __init__(self, x, y, tree, f_idxs, n_features, timeline, unique_deaths=3, mi
self.unique_deaths = unique_deaths
self.random_state = random_state
self.min_leaf = min_leaf
self.score = 0
self.split_val = None
self.split_var = None
self.lhs = None
self.rhs = None
self.chf = None
self.chf_terminal = None
self.terminal = False
self.grow_tree()

def grow_tree(self):
Expand Down
12 changes: 6 additions & 6 deletions random_survival_forest/RandomSurvivalForest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,6 @@


class RandomSurvivalForest:
bootstrap_idxs = None
bootstraps = []
oob_idxs = None
oob_score = None
trees = []
random_states = []

def __init__(self, timeline, n_estimators=100, min_leaf=3, unique_deaths=3, n_jobs=None, random_state=None):
"""
Expand All @@ -31,6 +25,12 @@ def __init__(self, timeline, n_estimators=100, min_leaf=3, unique_deaths=3, n_jo
self.unique_deaths = unique_deaths
self.n_jobs = n_jobs
self.random_state = random_state
self.bootstrap_idxs = None
self.bootstraps = []
self.oob_idxs = None
self.oob_score = None
self.trees = []
self.random_states = []

def fit(self, x, y):
"""
Expand Down
3 changes: 1 addition & 2 deletions random_survival_forest/SurvivalTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

class SurvivalTree:

prediction_possible = None

def __init__(self, x, y, f_idxs, n_features, timeline, unique_deaths=3, min_leaf=3, random_state=None):
"""
A Survival Tree to predict survival.
Expand Down Expand Up @@ -34,6 +32,7 @@ def __init__(self, x, y, f_idxs, n_features, timeline, unique_deaths=3, min_leaf
self.lhs = None
self.rhs = None
self.chf = None
self.prediction_possible = None
self.grow_tree()

def grow_tree(self):
Expand Down
2 changes: 2 additions & 0 deletions random_survival_forest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
"""

from random_survival_forest.RandomSurvivalForest import RandomSurvivalForest
from random_survival_forest.scoring import concordance_index
from random_survival_forest.SurvivalTree import SurvivalTree
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name='random_survival_forest', # How you named your package folder (MyLib)
packages=['random_survival_forest'], # Chose the same as "name"
version='0.6.2', # Start with a small number and increase it with every change you make
version='0.7', # Start with a small number and increase it with every change you make
license="MIT License", # Chose a license from here: https://help.github.com/articles/licensing-a-repository
long_description=readme,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit ae8cfc1

Please sign in to comment.