Skip to content

Commit

Permalink
added resampling for large n sources
Browse files Browse the repository at this point in the history
  • Loading branch information
aldengolab committed Mar 9, 2017
1 parent 804da27 commit 935f347
Showing 1 changed file with 41 additions and 15 deletions.
56 changes: 41 additions & 15 deletions pipeline/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,44 @@
import spacy
from transform_features import get_feature_transformer

def pipeline(args):
    '''
    Runs the model loop.

    Reads the CSV at args.filename, optionally de-duplicates articles on
    the 'content' column and caps the number of rows per source, then
    splits the data and trains/evaluates models via ModelLoop.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (filename, x_label, y_label,
        models, iterations, output_dir, thresholds, ks, dedupe, reduce).
    '''
    df = pd.read_csv(args.filename)
    if args.dedupe:
        # Duplicate articles share identical 'content'; keep the first.
        df = df.drop_duplicates(subset='content')
    if args.reduce:
        # BUG FIX: the original passed an undefined name `column` here.
        # restrict_sources samples per news source, and --reduce is parsed
        # with nargs=1, so args.reduce is a one-element list holding the
        # per-source row cap.
        df = restrict_sources(df, 'source', max_size=args.reduce[0])
    X = df[args.x_label]
    y = df[args.y_label]
    # NOTE(review): 'parser' is assigned but never used below — confirm
    # whether ModelLoop / get_feature_transformer expects it elsewhere.
    parser = spacy.load('en')
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    loop = ModelLoop(X_train, X_test, y_train, y_test, args.models,
                     args.iterations, args.output_dir,
                     thresholds=args.thresholds, ks=args.ks)
    loop.run()

def restrict_sources(df, column, max_size=500, random_state=1):
    '''
    Cap the number of rows per source by random downsampling.

    Any value of `column` appearing more than `max_size` times has a
    random subset of its rows dropped so that exactly `max_size` remain;
    smaller groups are left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data. Rows are dropped in place and the same object is
        returned.
    column : str
        Name of the column whose values identify a source. Must be a
        valid Python identifier (accessed via itertuples attributes).
    max_size : int, optional
        Maximum number of rows retained per source (default 500).
    random_state : int, optional
        Seed for the random row selection, for reproducible sampling.

    Returns
    -------
    pandas.DataFrame
        The input DataFrame with oversized sources downsampled.
    '''
    print("Resampling sources with frequency larger than {}".format(max_size))
    # BUG FIX: random_state was accepted but never used, making runs
    # non-reproducible; seed a local generator instead of global state.
    rng = np.random.default_rng(random_state)
    # Rows per source; groupby(...).size() avoids the fragile
    # count()-then-take-first-column trick of the original.
    counts = df.groupby(column).size()
    # Index labels of every row belonging to an oversized source.
    to_sample = defaultdict(set)
    for row in df.itertuples():
        # BUG FIX: the original hard-coded `row.source`, ignoring the
        # `column` parameter entirely.
        source = getattr(row, column)
        if counts.loc[source] > max_size:
            to_sample[source].add(row.Index)
    remove = []
    for source, idx_labels in to_sample.items():
        # Drop just enough randomly-chosen rows to bring this source
        # down to max_size. Sort labels so the seeded draw is stable
        # regardless of set iteration order.
        size = counts.loc[source] - max_size
        remove += list(rng.choice(sorted(idx_labels), size=size, replace=False))
    df.drop(remove, inplace=True)
    return df

if __name__=='__main__':
parser = argparse.ArgumentParser(description='Run a model loop')
parser.add_argument('filename', type=str,
Expand All @@ -25,21 +63,9 @@
action="store_true")
parser.add_argument('--ks', nargs='+', type=float, help='Metrics at k',
default = [])
parser.add_argument('--reduce', nargs=1, type=int, help='Restrict sample size from large sources',
default = False)

args = parser.parse_args()
print(args)

df = pd.read_csv(args.filename)
if args.dedupe:
df = df.drop_duplicates(subset='content')
# print(df.head())
X = df[args.x_label]
# print(X.head())
y = df[args.y_label]
# print(y.head())
parser = spacy.load('en')
X_train, X_test, y_train, y_test = train_test_split(X, y)
loop = ModelLoop(X_train, X_test, y_train, y_test, args.models,
args.iterations, args.output_dir,
thresholds = args.thresholds, ks = args.ks)
loop.run()
pipeline(args)

0 comments on commit 935f347

Please sign in to comment.