From f0e702339c0e1eedff28f9879f89346236b12efd Mon Sep 17 00:00:00 2001 From: Jonathan Nelson Date: Fri, 4 Dec 2020 10:36:43 -0500 Subject: [PATCH] Release 0.6.4 (#76) * update contact email * filter values for flesch and sentence length * Release 0.6.3 (#73) (#74) * update contact email * use alternative version of textstat for flesch scores * filter values for flesch and sentence length * add argument to return single prediction for multiclass probability estimate (#75) * Update version 0.6.4 --- quantgov/__init__.py | 2 +- quantgov/__main__.py | 6 +++++- quantgov/ml/estimation.py | 34 ++++++++++++++++++++++++---------- tests/test_ml.py | 11 +++++++++++ 4 files changed, 41 insertions(+), 12 deletions(-) diff --git a/quantgov/__init__.py b/quantgov/__init__.py index 9a0f55c..ea6de81 100644 --- a/quantgov/__init__.py +++ b/quantgov/__init__.py @@ -4,4 +4,4 @@ from . import corpus, nlp, ml, utils from .utils import load_driver -__version__ = '0.6.3' +__version__ = '0.6.4' diff --git a/quantgov/__main__.py b/quantgov/__main__.py index 0d9a12b..932131d 100644 --- a/quantgov/__main__.py +++ b/quantgov/__main__.py @@ -128,6 +128,9 @@ def parse_args(): estimate.add_argument( '--precision', default=4, type=int, help='number of decimal places to round the probabilities') + estimate.add_argument( + '--oneclass', action='store_true', + help='only return predicted class for multiclass probabilty estimates') estimate.add_argument( '-o', '--outfile', type=lambda x: open(x, 'w', newline='', encoding='utf-8'), @@ -223,7 +226,8 @@ def run_estimator(args): args.estimator, args.corpus, args.probability, - args.precision) + args.precision, + args.oneclass) ) diff --git a/quantgov/ml/estimation.py b/quantgov/ml/estimation.py index 4fa9369..5381495 100644 --- a/quantgov/ml/estimation.py +++ b/quantgov/ml/estimation.py @@ -4,6 +4,7 @@ Functionality for making predictions with an estimator """ import logging +import numpy as np log = logging.getLogger(__name__) @@ -106,7 +107,7 @@ def estimate_probability_multilabel(estimator, streamer, precision): ) -def estimate_probability_multiclass(estimator, streamer, precision): +def estimate_probability_multiclass(estimator, streamer, precision, oneclass): """ Generate probabilities for a one-label, multiclass estimator @@ -119,12 +120,24 @@ def estimate_probability_multiclass(estimator, streamer, precision): """ texts = (doc.text for doc in streamer) - probs = estimator.pipeline.predict_proba(texts).round(precision) - yield from ( - (docidx, (class_, probability)) - for docidx, doc_probs in zip(streamer.index, probs) - for class_, probability in zip(estimator.pipeline.classes_, doc_probs) - ) + probs = estimator.pipeline.predict_proba(texts) + # If oneclass flag is true, only returns the predicted class + if oneclass: + class_indices = list(i[-1] for i in np.argsort(probs, axis=1)) + yield from ( + (docidx, (estimator.pipeline.classes_[class_index], + doc_probs[class_index].round(precision))) + for docidx, doc_probs, class_index in zip( + streamer.index, probs, class_indices) + ) + # Else returns probabilty values for all classes + else: + yield from ( + (docidx, (class_, probability.round(precision))) + for docidx, doc_probs in zip(streamer.index, probs) + for class_, probability in zip( + estimator.pipeline.classes_, doc_probs) + ) def estimate_probability_multilabel_multiclass(estimator, streamer, precision): @@ -140,7 +153,7 @@ def estimate_probability_multilabel_multiclass(estimator, streamer, precision): """ texts = (doc.text for doc in streamer) - probs = estimator.pipeline.predict_proba(texts) + probs = estimator.pipeline.predict_proba(texts).round(precision) yield from ( (docidx, (label_name, class_, prob)) for label_name, label_probs in zip(estimator.label_names, probs) @@ -149,7 +162,8 @@ def estimate_probability_multilabel_multiclass(estimator, streamer, precision): ) -def estimate(estimator, corpus, probability, precision=4, *args, **kwargs): +def estimate(estimator, corpus, probability, precision=4, oneclass=False, + *args, **kwargs): """ Estimate label values for documents in corpus @@ -171,7 +185,7 @@ def estimate(estimator, corpus, probability, precision=4, *args, **kwargs): estimator, streamer, precision) elif estimator.multiclass: # Multiclass probability yield from estimate_probability_multiclass( - estimator, streamer, precision) + estimator, streamer, precision, oneclass) else: # Simple probability yield from estimate_probability( estimator, streamer, precision) diff --git a/tests/test_ml.py b/tests/test_ml.py index 042ddbb..3dcc4c5 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -65,3 +65,14 @@ def test_multiclass_probability_estimator(): 'moby,money,0.1536\n' 'moby,science-and-technology,0.1671\n' 'moby,world,0.141\n') + + +def test_multiclass_probability_oneclass_estimator(): + output = check_output( + ['quantgov', 'ml', 'estimate', + str(PSEUDO_ESTIMATOR_PATH.joinpath('data', 'multiclass.qge')), + str(PSEUDO_CORPUS_PATH), '--probability', '--oneclass'] + ) + assert output == ('file,class,probability\n' + 'cfr,world,0.1997\n' + 'moby,health-and-public-welfare,0.205\n')