-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
Fix smaller bugs
- Loading branch information
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
#!/bin/sh | ||
|
||
rm -r random_survival_forest.egg-info | ||
rm -r build | ||
rm -r dist | ||
python -m pip install --user --upgrade setuptools wheel | ||
python -m pip install --user --upgrade twine | ||
python setup.py sdist bdist_wheel | ||
python -m twine upload dist/* |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,22 @@ | ||
from random_survival_forest import RandomSurvivalForest, concordance_index | ||
import time | ||
|
||
from lifelines import datasets | ||
from sklearn.model_selection import train_test_split | ||
import time | ||
|
||
from random_survival_forest.models import RandomSurvivalForest | ||
from random_survival_forest.scoring import concordance_index | ||
|
||
rossi = datasets.load_rossi() | ||
# Attention: duration column must be index 0, event column index 1 in y | ||
y = rossi.loc[:, ["arrest", "week"]] | ||
X = rossi.drop(["arrest", "week"], axis=1) | ||
X, X_test, y, y_test = train_test_split(X, y, test_size=0.25, random_state=10) | ||
X, X_test, y, y_test = train_test_split(X, y, test_size=0.33, random_state=10) | ||
|
||
print("RSF") | ||
print("Start training...") | ||
start_time = time.time() | ||
rsf = RandomSurvivalForest(n_estimators=20, n_jobs=-1, min_leaf=10) | ||
rsf = RandomSurvivalForest(n_estimators=10, n_jobs=-1, random_state=10) | ||
rsf = rsf.fit(X, y) | ||
print("--- %s seconds ---" % (time.time() - start_time)) | ||
print(f'--- {round(time.time() - start_time, 3)} seconds ---') | ||
y_pred = rsf.predict(X_test) | ||
c_val = concordance_index(y_time=y_test["week"], y_pred=y_pred, y_event=y_test["arrest"]) | ||
print("C-index", round(c_val, 3)) | ||
print(f'C-index {round(c_val, 3)}') |
This file was deleted.