-
Notifications
You must be signed in to change notification settings - Fork 22
/
utils.py
91 lines (67 loc) · 2.56 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# Copyright 2020, 37.78 Tecnologia Ltda.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pickle
import pandas as pd
from tensorflow.keras.callbacks import LearningRateScheduler
from constants import SAVE_DIR, W2V_DIR, W2V_SIZE, MAX_LENGTH
import model_functions as fun
import models
def make_icds_histogram(df):
    """Count how often each value occurs in the ``ICD9_CODE`` column.

    List-like cells are first exploded into one row per element, so every
    element of every cell contributes one count.

    Args:
        df: DataFrame with an ``ICD9_CODE`` column.

    Returns:
        Series mapping each distinct code to its number of occurrences,
        as produced by ``value_counts`` (most frequent first).
    """
    # One row per individual code, regardless of how many codes a cell held.
    flattened_codes = df["ICD9_CODE"].explode()
    code_counts = flattened_codes.value_counts()
    return code_counts
def load_list_from_txt(filepath):
    """Read a text file and return its whitespace-separated tokens.

    Args:
        filepath: Path of the text file to read.

    Returns:
        List of strings obtained by splitting the whole file content on
        any run of whitespace (spaces, tabs, newlines).
    """
    with open(filepath, 'r') as handle:
        contents = handle.read()
    # split() with no argument also discards leading/trailing whitespace,
    # so an empty or blank file yields an empty list.
    return contents.split()
def preprocessor(text_series):
    """Clean and tokenize a pandas Series of raw text.

    Steps, applied in order:
      1. remove ``<...>`` markup tags;
      2. lowercase;
      3. replace every run of non-word characters with a single space;
      4. split on whitespace.

    Args:
        text_series: pandas Series of strings.

    Returns:
        Series whose elements are lists of lowercase tokens.
    """
    # regex=True is explicit on purpose: pandas 2.0 changed the default of
    # Series.str.replace to literal matching, which would silently turn
    # these pattern-based substitutions into no-ops.
    return (text_series
            .str.replace(r'<[^>]*>', '', regex=True)
            .str.lower()
            .str.replace(r'[\W]+', ' ', regex=True)
            .str.split())
def preprocessor_tfidf(text_series):
    """Normalize a pandas Series of raw text into cleaned strings.

    Steps, applied in order:
      1. remove ``[** ... **]`` bracketed spans (presumably
         de-identification placeholders -- confirm against the dataset);
      2. remove ``<...>`` markup tags;
      3. replace every run of non-word characters with a single space;
      4. lowercase;
      5. replace each space followed by digits with a single space, which
         strips the leading digits of any space-preceded token.

    Args:
        text_series: pandas Series of strings.

    Returns:
        Series of cleaned, lowercase strings (not tokenized).
    """
    # regex=True is explicit on purpose: pandas 2.0 changed the default of
    # Series.str.replace to literal matching, which would silently turn
    # these pattern-based substitutions into no-ops.
    return (text_series
            .str.replace(r'\[\*\*[^\]]*\*\*\]', '', regex=True)
            .str.replace(r'<[^>]*>', '', regex=True)
            .str.replace(r'[\W]+', ' ', regex=True)
            .str.lower()
            .str.replace(r' \d+', ' ', regex=True))
def preprocessor_word2vec(text_series):
    """Normalize and tokenize a pandas Series of raw text.

    Steps, applied in order:
      1. remove ``[** ... **]`` bracketed spans (presumably
         de-identification placeholders -- confirm against the dataset);
      2. remove ``<...>`` markup tags;
      3. replace every run of non-word characters with a single space;
      4. lowercase;
      5. replace each space followed by digits with a single space, which
         strips the leading digits of any space-preceded token;
      6. split on whitespace.

    Args:
        text_series: pandas Series of strings.

    Returns:
        Series whose elements are lists of lowercase tokens.
    """
    # regex=True is explicit on purpose: pandas 2.0 changed the default of
    # Series.str.replace to literal matching, which would silently turn
    # these pattern-based substitutions into no-ops.
    return (text_series
            .str.replace(r'\[\*\*[^\]]*\*\*\]', '', regex=True)
            .str.replace(r'<[^>]*>', '', regex=True)
            .str.replace(r'[\W]+', ' ', regex=True)
            .str.lower()
            .str.replace(r' \d+', ' ', regex=True)
            .str.split())
def convert_data_to_index(string_data, row_dict):
    """Map each word to its index, using the ``'_unknown_'`` entry as fallback.

    Args:
        string_data: Iterable of words to look up.
        row_dict: Mapping from word to index; it must contain the key
            ``'_unknown_'`` whenever ``string_data`` is non-empty.

    Returns:
        List with one index per input word, in input order.

    Raises:
        KeyError: If ``string_data`` is non-empty and ``row_dict`` has no
            ``'_unknown_'`` key.
    """
    indices = []
    for token in string_data:
        # The fallback is looked up for every token, found or not, so a
        # missing '_unknown_' key surfaces on the first token.
        indices.append(row_dict.get(token, row_dict['_unknown_']))
    return indices
def lr_schedule_callback(args):
    """Build a Keras callback applying a two-level learning-rate schedule.

    Args:
        args: Object exposing ``epoch_drop``, ``initial_lr`` and
            ``final_lr`` attributes; they are read each time the schedule
            is evaluated.

    Returns:
        ``LearningRateScheduler`` (verbose=1) that yields ``args.initial_lr``
        for epochs below ``args.epoch_drop`` and ``args.final_lr`` from
        that epoch onwards.
    """
    def step_schedule(epoch):
        # Single step: switch rates once the drop epoch is reached.
        return args.initial_lr if epoch < args.epoch_drop else args.final_lr

    return LearningRateScheduler(step_schedule, verbose=1)
def get_model(args=None, load_path=None):
    """Instantiate the model wrapper selected by ``args.MODEL_NAME``.

    Supported names: ``'cte'``, ``'lr'``, ``'cnn'``, ``'gru'``, ``'cnn_att'``.

    Args:
        args: Object exposing a ``MODEL_NAME`` attribute; it is forwarded
            unchanged to the model constructor. Although the signature
            defaults to ``None``, a value is required: ``None`` raises
            ``AttributeError`` on the ``MODEL_NAME`` lookup.
        load_path: Forwarded unchanged as the second constructor argument.

    Returns:
        The result of calling the matching class from ``models`` with
        ``(args, load_path)``.

    Raises:
        ValueError: If ``args.MODEL_NAME`` is not one of the supported
            names (previously this case silently returned ``None``).
    """
    # Class names are stored as strings and resolved on `models` only after
    # the requested name has been validated.
    class_names = {
        'cte': 'CTE_Model',
        'lr': 'LR_Model',
        'cnn': 'CNN_Model',
        'gru': 'GRU_Model',
        'cnn_att': 'CNNAtt_Model',
    }
    model_name = args.MODEL_NAME
    try:
        class_name = class_names[model_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which cannot be valid names.
        supported = ', '.join(sorted(class_names))
        raise ValueError(
            f"Unknown MODEL_NAME {model_name!r}; expected one of: {supported}"
        ) from None
    return getattr(models, class_name)(args, load_path)