Skip to content

Commit

Permalink
Merge pull request #12 from julianspaeth/update-rsf
Browse files Browse the repository at this point in the history
Fix smaller bugs
  • Loading branch information
julianspaeth authored Oct 19, 2022
2 parents 5867007 + 10f576e commit 17f96cf
Show file tree
Hide file tree
Showing 21 changed files with 485 additions and 371 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 43 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/markdown.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions .idea/random-survival-forest.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ authors:
orcid: https://orcid.org/0000-0003-4562-5816
title: Random Survival Forest
doi: 10.5281/zenodo.5146376
version: v0.1.1-beta
version: v0.1.2-beta
25 changes: 13 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,24 @@ $ pip install random-survival-forest
## Getting Started

```python
>>> from random_survival_forest import RandomSurvivalForest, concordance_index
>>> from lifelines import datasets
>>> from sklearn.model_selection import train_test_split
from random_survival_forest.models import RandomSurvivalForest
from random_survival_forest.scoring import concordance_index
from lifelines import datasets
from sklearn.model_selection import train_test_split


>>> rossi = datasets.load_rossi()
rossi = datasets.load_rossi()
# Attention: duration column must be index 0, event column index 1 in y
>>> y = rossi.loc[:, ["week", "arrest"]]
>>> X = rossi.drop(["arrest", "week"], axis=1)
>>> X, X_test, y, y_test = train_test_split(X, y, test_size=0.25)
y = rossi.loc[:, ["week", "arrest"]]
X = rossi.drop(["arrest", "week"], axis=1)
X, X_test, y, y_test = train_test_split(X, y, test_size=0.25)


>>> rsf = RandomSurvivalForest(n_estimators=20, n_jobs=-1)
>>> rsf = rsf.fit(X, y)
>>> y_pred = rsf.predict(X_test)
>>> c_val = concordance_index(y_time=y_test["week"], y_pred=y_pred, y_event=y_test["arrest"])
>>> print("C-index", round(c_val, 3))
rsf = RandomSurvivalForest(n_estimators=20, n_jobs=-1)
rsf = rsf.fit(X, y)
y_pred = rsf.predict(X_test)
c_val = concordance_index(y_time=y_test["week"], y_pred=y_pred, y_event=y_test["arrest"])
print("C-index", round(c_val, 3))
```

## Feedback
Expand Down
9 changes: 9 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/sh

rm -r random_survival_forest.egg-info
rm -r build
rm -r dist
python -m pip install --user --upgrade setuptools wheel
python -m pip install --user --upgrade twine
python setup.py sdist bdist_wheel
python -m twine upload dist/*
17 changes: 10 additions & 7 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from random_survival_forest import RandomSurvivalForest, concordance_index
import time

from lifelines import datasets
from sklearn.model_selection import train_test_split
import time

from random_survival_forest.models import RandomSurvivalForest
from random_survival_forest.scoring import concordance_index

rossi = datasets.load_rossi()
# Attention: duration column must be index 0, event column index 1 in y
y = rossi.loc[:, ["arrest", "week"]]
X = rossi.drop(["arrest", "week"], axis=1)
X, X_test, y, y_test = train_test_split(X, y, test_size=0.25, random_state=10)
X, X_test, y, y_test = train_test_split(X, y, test_size=0.33, random_state=10)

print("RSF")
print("Start training...")
start_time = time.time()
rsf = RandomSurvivalForest(n_estimators=20, n_jobs=-1, min_leaf=10)
rsf = RandomSurvivalForest(n_estimators=10, n_jobs=-1, random_state=10)
rsf = rsf.fit(X, y)
print("--- %s seconds ---" % (time.time() - start_time))
print(f'--- {round(time.time() - start_time, 3)} seconds ---')
y_pred = rsf.predict(X_test)
c_val = concordance_index(y_time=y_test["week"], y_pred=y_pred, y_event=y_test["arrest"])
print("C-index", round(c_val, 3))
print(f'C-index {round(c_val, 3)}')
91 changes: 0 additions & 91 deletions random_survival_forest/Node.py

This file was deleted.

Loading

0 comments on commit 17f96cf

Please sign in to comment.