-
Notifications
You must be signed in to change notification settings - Fork 0
/
hubconf.py
65 lines (56 loc) · 1.92 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# NOTE(review): `dependencies` appears to be the torch.hub convention for declaring
# required packages in a hubconf.py — confirm this file is served via torch.hub.
dependencies = ['torch', 'pandas']
import torch
import pickle
import pandas as pd
import os
from os.path import join
# Directory containing this file; used below to resolve the bundled
# asset paths under the 'detectron/' subdirectory.
dirname = os.path.dirname(__file__)
def resnet18_cifar10(return_transform=True, map_location=None):
    """
    Loads a torchvision ResNet18 model trained on CIFAR10.

    Args:
        return_transform: if True (default), also return the default input
            transform for this model.
        map_location: optional device remapping forwarded to ``torch.load``
            (e.g. ``'cpu'`` to load a GPU-trained checkpoint on a CPU-only
            machine). ``None`` (default) keeps the original loading behavior.

    Returns:
        ``model`` when ``return_transform`` is False, otherwise the tuple
        ``(model, transform)``.

    ```
    model = resnet18_cifar10(return_transform=False)
    model, tf = resnet18_cifar10(return_transform=True)
    ```
    """
    # map_location makes checkpoints saved on GPU loadable on CPU-only hosts;
    # passing None is identical to the previous call signature.
    model = torch.load(join(dirname, 'detectron/cifar10.pt'),
                       map_location=map_location)
    if return_transform:
        transform = torch.load(join(dirname, 'detectron/cifar10_input_transform.pt'),
                               map_location=map_location)
        return model, transform
    return model
def uci_heart():
    """
    Loads a featurized version of the UCI heart dataset,
    as well as an XGBoost model trained on the Cleveland split.

    `model, data = uci_heart()`

    Data format:
    {
        'Cleveland': {'train': (data, labels), 'test': (data, labels), 'val': (data, labels)}
        'Hungary': (data, labels),
        'Switzerland': (data, labels),
        'VA Long Beach': (data, labels)
    }
    The model is trained using default parameters
    {
        'objective': 'binary:logistic',
        'eval_metric': 'auc',
        'eta': 0.1,
        'max_depth': 6,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'min_child_weight': 1,
        'nthread': 4,
        'tree_method': 'gpu_hist',
    } as well as num_boost_round=100.
    The test auc on cleveland is 0.809
    Data source:
    https://www.kaggle.com/datasets/redwankarimsony/heart-disease-data

    Returns:
        (model, data) tuple as described above.
    """
    # Context managers guarantee the file handles are closed even if
    # pickle.load raises (the originals leaked open handles).
    with open(join(dirname, 'detectron/uci_heart_features.pkl'), 'rb') as f:
        data = pickle.load(f)
    with open(join(dirname, 'detectron/uci_xgb_cleveland.pkl'), 'rb') as f:
        model = pickle.load(f)
    return model, data
def uci_heart_raw():
    """
    Loads raw data from https://www.kaggle.com/datasets/redwankarimsony/heart-disease-data

    Returns:
        A pandas DataFrame read from the bundled CSV file.
    """
    csv_path = join(dirname, 'detectron/heart_disease_uci.csv')
    return pd.read_csv(csv_path)