Merge branch 'deep_classiflie_feat' into master
Daniel Dale committed Sep 11, 2020
2 parents c2581ba + 433c701 commit fa6dada
Showing 42 changed files with 19,554 additions and 182 deletions.
28 changes: 24 additions & 4 deletions README.md
@@ -26,7 +26,11 @@
### What is Deep Classiflie?
- Deep Classiflie is a framework for developing ML models that bolster fact-checking efficiency. It is predominantly a research project<sup id="ae">[e](#ce)</sup> that I plan to extend and maintain in pursuit of my own research interests; I'm sharing it in case it's of any utility to the broader community.
- As a POC, the initial alpha release of Deep Classiflie generates/analyzes a model that continuously classifies a single individual's statements (Donald Trump)<sup id="a1">[1](#f1)</sup> using a single ground truth labeling source (The Washington Post).
- The Deep Classiflie POC model's predictions and performance on the most recent test set can be [explored](#model-exploration) and better understood using the [prediction explorer](pred_explorer.html):
- The Deep Classiflie POC model's [current predictions](current_explorer.html) and performance on the most recent test set can be [explored](#model-exploration) and better understood using
the [current prediction explorer](current_explorer.html):

<img src="docs/assets/current_explorer.gif" alt="current prediction explorer" />
- the [prediction explorer](pred_explorer.html):
<img src="docs/assets/pred_exp.gif" alt="prediction explorer" />
- and the [performance explorer](perf_explorer.html):

@@ -73,6 +77,17 @@ The best way to start understanding/exploring the current model is to use the ex
<img src="docs/assets/conf_bucket_confusion_matrices.gif" alt="confidence bucket performance explorer" />
</details>

<details><summary markdown="span"><strong>[Current Predictions Explorer](current_explorer.html)</strong>
</summary>

Explore the current (unlabeled) predictions generated by the latest model incarnation. All statements yet to be labeled by current fact-checking sources (currently, only the [Washington Post Factchecker](https://www.washingtonpost.com/graphics/politics/trump-claims-database)) are available.
Live predictions are continuously added via [ipfs](https://ipfs.io). Twitter statements are delayed by ~15 minutes to allow thread-based scoring, and [Factba.se](https://factba.se) is polled for new statements every 10 minutes (an illustrative polling sketch follows below).
This explorer provides fact-checkers with a means (one of many possible) of using current model predictions and may also help those building fact-checking systems evaluate the potential utility of integrating similar models into their own systems.


<img src="docs/assets/current_explorer.gif" alt="current predictions explorer" />
</details>
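To make the polling/publish flow described above concrete, here is a minimal, illustrative sketch of the kind of loop such an inference service could run. It is not the repository's actual service: `fetch_new_statements`, `score_statements`, and the record fields are hypothetical, and IPFS publishing is reduced to shelling out to a local `ipfs add`.

```python
import json
import subprocess
import tempfile
import time
from typing import Callable, Dict, List

POLL_INTERVAL = 600    # Factba.se is polled for new statements every 10 minutes
THREAD_LATENCY = 900   # hold Twitter statements ~15 minutes so complete threads are scored together


def publish_to_ipfs(records: List[Dict]) -> str:
    """Write prediction records to a temp file and add it to a local IPFS node, returning the CID."""
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(records, f)
        path = f.name
    # requires a running local ipfs daemon; a pinning service (e.g. pinata) could then pin the CID
    result = subprocess.run(["ipfs", "add", "-Q", path], capture_output=True, text=True, check=True)
    return result.stdout.strip()


def poll_and_publish(fetch_new_statements: Callable[[], List[Dict]],
                     score_statements: Callable[[List[Dict]], List[float]]) -> None:
    """Toy polling loop: fetch new statements, defer fresh tweets, score, publish."""
    while True:
        statements = fetch_new_statements()
        now = time.time()
        # defer tweets younger than THREAD_LATENCY (assumed 'source'/'t_utc' fields)
        ready = [s for s in statements
                 if s["source"] != "twitter" or now - s["t_utc"] >= THREAD_LATENCY]
        if ready:
            probs = score_statements(ready)
            records = [{**s, "prob_falsehood": p} for s, p in zip(ready, probs)]
            print("published:", publish_to_ipfs(records))
        time.sleep(POLL_INTERVAL)
```

Deferring tweets by the thread latency lets reply threads be scored as a unit, matching the ~15-minute delay noted above.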

---

### Core Components
@@ -108,10 +123,13 @@ The entire initial Deep Classiflie system (raw dataset, model, analytics modules
</summary>

- Interpret statement-level predictions using [captum's](https://captum.ai/) implementation of integrated gradients to visualize attributions of statement predictions to tokens in each statement (an illustrative attribution sketch follows this list).
- Prediction and model performance exploration dashboards were built using [bokeh](https://docs.bokeh.org/en/latest/index.html) and [Jekyll](https://github.com/jekyll/jekyll)
- Test set prediction and model performance exploration dashboards were built using [bokeh](https://docs.bokeh.org/en/latest/index.html) and [Jekyll](https://github.com/jekyll/jekyll)
- The [current prediction explorer](current_explorer.html) was built using [datatables](https://datatables.net/) and [ipfs](https://ipfs.io) with pinning provided by [pinata](https://pinata.cloud/)
- Two inference daemons poll, analyze and classify new statements:
1. (still in development) A daemon that publishes via IPFS pubsub, all new statement classifications and inference output.
2. (currently available) Automated false statement reports for predictions meeting the desired [PPV](https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values) confidence threshold can be published on twitter via a twitter bot, which leverages [Tweepy](https://www.tweepy.org/). The bot <sup id="ah">[h](#ch)</sup> tweets out a statement analysis and model interpretation "report" such as the one below for statements the model deems most likely to be labeled falsehoods (see [current performance](#current-performance) for more detail):
1. A daemon that publishes via [IPFS](https://ipfs.io) all new statement classifications and inference output.

<img src="docs/assets/current_explorer.gif" alt="current predictions explorer" />
2. Automated false statement reports for predictions meeting the desired [PPV](https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values) confidence threshold can be published on Twitter via a Twitter bot, which leverages [Tweepy](https://www.tweepy.org/). The bot <sup id="ah">[h](#ch)</sup> tweets out a statement analysis and model interpretation "report", such as the one below, for statements the model deems most likely to be labeled falsehoods (see [current performance](#current-performance) for more detail):

<img src="docs/assets/example_twitter_report.png" alt="Example tweet report" />
- XKCD fans may notice that the dashboard explorers and statement reports are XKCD-inspired, using the Humor Sans font created by [@ch00ftech](https://twitter.com/ch00ftech). Thanks to him (and [@xkcd](https://twitter.com/xkcd) of course!)
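As a rough, non-authoritative illustration of the token-attribution and confidence-thresholding steps mentioned in this list (the specific model, tokenizer, class index, and threshold below are assumptions, not the project's actual configuration), integrated gradients can be applied with captum along these lines:

```python
import torch
from captum.attr import LayerIntegratedGradients
from transformers import AlbertForSequenceClassification, AlbertTokenizer

# stand-ins for the project's fine-tuned model/tokenizer
tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
model = AlbertForSequenceClassification.from_pretrained("albert-base-v2")
model.eval()


def falsehood_prob(input_ids, attention_mask):
    # probability of the "falsehood" class (class index 1 is an assumption)
    logits = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).logits
    return torch.softmax(logits, dim=-1)[:, 1]


statement = "An example statement to attribute."
enc = tokenizer(statement, return_tensors="pt")

# attribute the predicted probability back to the embedding layer, then sum per token
lig = LayerIntegratedGradients(falsehood_prob, model.albert.embeddings)
attributions = lig.attribute(inputs=enc["input_ids"],
                             additional_forward_args=(enc["attention_mask"],),
                             n_steps=50)
token_scores = attributions.sum(dim=-1).squeeze(0)

with torch.no_grad():
    prob = falsehood_prob(enc["input_ids"], enc["attention_mask"]).item()

PPV_THRESHOLD = 0.9  # assumed stand-in for the PPV-derived confidence threshold
if prob >= PPV_THRESHOLD:
    tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    for tok, score in zip(tokens, token_scores.tolist()):
        print(f"{tok:>12} {score:+.4f}")
```

Per-token attribution scores like these are what the statement "reports" visualize; only statements whose predicted probability clears the chosen threshold would be published.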
@@ -168,6 +186,7 @@ To minimize false positives and maximize the model's utility, the following appr
</summary>

- Extensive suite of reporting views for analyzing model performance at global and local levels
- A [current prediction explorer](current_explorer.html) that provides fact-checkers a means (one of many possible) of using current model predictions. This dashboard may also help those building fact-checking systems evaluate the potential utility of integrating similar models into their systems.
- Statement and performance exploration dashboards for interpreting model predictions and understanding its performance
- xkcd-themed visualization of UMAP-transformed statement embeddings (a minimal sketch follows this list)
</details>
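A minimal sketch of the UMAP + xkcd-styled embedding plot mentioned above, assuming a stand-in embedding matrix rather than the project's actual statement embeddings:

```python
import numpy as np
import umap                       # umap-learn
import matplotlib.pyplot as plt

# stand-ins: (n_statements, embed_dim) statement embeddings and truth/falsehood labels
stmt_embeddings = np.random.rand(500, 128)
labels = np.random.randint(0, 2, size=500)

# reduce to 2-D for plotting
emb_2d = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2).fit_transform(stmt_embeddings)

with plt.xkcd():                  # matplotlib's built-in xkcd sketch style
    plt.scatter(emb_2d[:, 0], emb_2d[:, 1], c=labels, s=8, cmap="coolwarm")
    plt.title("UMAP projection of statement embeddings")
    plt.savefig("umap_statements.png", dpi=150)
```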
@@ -218,6 +237,7 @@ The parameters used in all Deep Classiflie job executions related to the develop
| **gen_dashboards.yaml** | parameters used to generate model analysis dashboards |
| **cust_predict.yaml** | parameters used to perform model inference on arbitrary input statements |
| **tweetbot.yaml** | parameters used to run the tweetbot behind @DeepClassiflie |
| **infsvc.yaml** | parameters used to run the inference service behind the current prediction explorer (an illustrative override sketch follows the table) |

</div>
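For illustration only, a job-level override enabling the inference service could be layered onto the defaults roughly as follows. The key names mirror the `infsvc` block added to `configs/config_defaults.yaml` later in this commit, but the merge helper is an assumption, not the project's actual config loader:

```python
import yaml  # PyYAML


def merge(defaults: dict, override: dict) -> dict:
    """Recursively overlay `override` onto `defaults` (illustrative only)."""
    out = dict(defaults)
    for k, v in override.items():
        out[k] = merge(out[k], v) if isinstance(v, dict) and isinstance(out.get(k), dict) else v
    return out


with open("configs/config_defaults.yaml") as f:
    defaults = yaml.safe_load(f)

# hypothetical infsvc.yaml contents enabling the service behind the current prediction explorer
override = yaml.safe_load("""
experiment:
  infsvc:
    enabled: True
    publish: True
    batch_size: 16
""")

config = merge(defaults, override)
print(config["experiment"]["infsvc"]["enabled"])  # True
```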

38 changes: 27 additions & 11 deletions analysis/inference.py
@@ -7,20 +7,16 @@

import utils.constants as constants
from models.deep_classiflie_module import DeepClassiflie
from analysis.inference_utils import tokens_to_sentence, gen_embed_mappings, prep_mapping_tups, \
prep_base_mapping_tups, pred_inputs_from_config, pred_inputs_from_test, prep_rpt_tups, prep_pred_exp_tups
from analysis.inference_utils import tokens_to_sentence, gen_embed_mappings, prep_mapping_tups, prep_base_mapping_tups,\
pred_inputs_from_config, pred_inputs_from_test, prep_rpt_tups, prep_pred_exp_tups, prep_batchpred_tups
from utils.core_utils import log_config
from torch.utils.data import DataLoader
from training.training_utils import load_ckpt
from analysis.inference_utils import InferenceSession
from analysis.interpretation import InterpretTransformer

logger = logging.getLogger(constants.APP_NAME)

try:
from apex import amp
except ImportError as error:
logger.debug(f"{error.__class__.__name__}: No apex module found, fp16 will not be available.")


class Inference(object):
def __init__(self, config: MutableMapping, mapping_set: List[Tuple] = None,
@@ -30,7 +26,7 @@ def __init__(self, config: MutableMapping, mapping_set: List[Tuple] = None,
self.inf_session = InferenceSession(config, mapping_set, analysis_set, pred_exp_set, rpt_type, base_mode)

def init_predict(self, model: torch.nn.Module = None, ckpt: str = None, tokenizer: PreTrainedTokenizer = None,
eval_tuple: Tuple = None) -> Union[Tuple[List[Tuple], Optional[Dict]], Dict, List[Tuple]]:
eval_tuple: Tuple = None) -> Union[Tuple[List[Tuple], Optional[Dict]], Dict, List[Tuple], List]:
ckpt = self.init_predict_model(model, ckpt)
self.init_predict_tokenizer(tokenizer, ckpt)
self.config_interpretation()
@@ -41,6 +37,9 @@ def init_predict(self, model: torch.nn.Module = None, ckpt: str = None, tokenize
return self.pred_exp_viz(pred_inputs)
elif self.inf_session.mapping_set:
return gen_embed_mappings(self.inf_session, pred_inputs)
elif self.inf_session.config.experiment.infsvc.enabled:
inf_outputs = self.batch_predict(pred_inputs)
return inf_outputs
elif self.inf_session.config.inference.interpret_preds and self.inf_session.config.experiment.tweetbot.enabled:
unpublished_reports = self.predict_viz(pred_inputs)
return unpublished_reports
@@ -86,7 +85,7 @@ def config_interpretation(self) -> None:
self.inf_session.model, self.inf_session.tokenizer,
self.inf_session.device, pred_report_path)

def prep_pred_inputs(self, eval_tuple: Tuple) -> List[Dict]:
def prep_pred_inputs(self, eval_tuple: Tuple) -> Union[List[Dict], DataLoader]:
if not (eval_tuple or self.inf_session.config.inference.pred_inputs or self.inf_session.analysis_set
or self.inf_session.mapping_set or self.inf_session.pred_exp_set):
raise ValueError("init_predict must be provided inputs via either test set samples,"
@@ -99,6 +98,8 @@ def prep_pred_inputs(self, eval_tuple: Tuple) -> List[Dict]:
pred_inputs = prep_base_mapping_tups(self.inf_session)
elif self.inf_session.mapping_set:
pred_inputs = prep_mapping_tups(self.inf_session)
elif self.inf_session.config.experiment.infsvc.enabled:
pred_inputs = prep_batchpred_tups(self.inf_session)
elif eval_tuple:
num_samples = self.inf_session.config.inference.sample_predictions
pred_inputs = pred_inputs_from_test(self.inf_session, eval_tuple, num_samples)
@@ -127,6 +128,23 @@ def predict(self, pred_inputs: List[Dict]) -> None:
f"PREDICTION: {prob} ({round(prob)}), actual label: {round(label)}"
f" INPUT: {parsed_sent} ")

def batch_predict(self, pred_inputs: DataLoader) -> List:
self.inf_session.model.set_interpret_mode()
batch_inf_outputs = []
pred_batch_iterator = tqdm.tqdm(pred_inputs, desc="Batch")
for i, batch in enumerate(pred_batch_iterator):
batch = tuple(t.to(self.inf_session.device) for t in batch)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2],
'position_ids': batch[3],
'ctxt_type': batch[4],
'labels': None}
probs = (self.inf_session.model(**inputs))
batch_inf_outputs.extend([round(p.squeeze(0).item(), 4) for p in probs])
return batch_inf_outputs

def predict_viz(self, pred_inputs: List[Dict]) -> List[Tuple]:
for sample in tqdm.tqdm(pred_inputs, desc=f'Interpreting {len(pred_inputs)} '
f'predictions and generating per-prediction reports'):
@@ -179,8 +197,6 @@ def gen_model_rpt(self, pred_inputs: List[Dict]) -> Tuple[List[Tuple], Dict]:
for sample in tqdm.tqdm(pred_inputs, desc=f"Generating report using {len(pred_inputs)} samples"):
input_embedding, inputs, probs, token_list, prob = self.pass_interpretable_inputs(sample)
token_list = list(filter(lambda l: l not in self.inf_session.special_token_mask, token_list))
# all records should have a label ("True" unless explicitly labeled false by wapo) unless
# using "gt, ground truth" version of scoring (model_rpt_all_tweet_data_gt)
label = sample['labels'].item() if sample['labels'] in [0, 1] else None
parsed_sent = tokens_to_sentence(self.inf_session, token_list)
# include only training data in the statement embedding
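To situate the new inference-service path added above (`init_predict` dispatching to `batch_predict` when `experiment.infsvc.enabled` is set), a rough usage sketch follows. The config helper below is hypothetical and a trained checkpoint resolvable by `Inference` is assumed; only the `Inference(...).init_predict()` call and the input tuple shape mirror the diff:

```python
# Illustrative driver for the new inference-service path; not part of this commit.
from analysis.inference import Inference

# hypothetical helper returning the merged, attribute-addressable job config
config = load_job_config("configs/config_defaults.yaml")
config.experiment.infsvc.enabled = True      # routes init_predict() through batch_predict()

# tuples follow the (parentid, childid, statement, ctxt_type, _) shape unpacked by prep_batchpred_tups()
config.inference.pred_inputs = [
    (1, 0, "An example statement to score.", 0, None),
]

probs = Inference(config).init_predict()     # one rounded probability per input statement
for (_, _, stmt, _, _), p in zip(config.inference.pred_inputs, probs):
    print(f"{p:.4f}  {stmt}")
```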
19 changes: 19 additions & 0 deletions analysis/inference_utils.py
@@ -6,6 +6,7 @@
import torch
from transformers import AlbertTokenizer
import tqdm
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

import utils.constants as constants
from analysis.interpretation import InterpretTransformer
@@ -155,6 +156,24 @@ def pred_inputs_from_config(inf_session: InferenceSession) -> List[Dict]:
return pred_inputs


def prep_batchpred_tups(inf_session: InferenceSession) -> DataLoader:
pred_inputs = {'all_input_ids': [], 'all_attention_masks': [], 'all_token_type_ids': [],
'all_position_ids': [], 'all_ctxt_types': []}
for i, (parentid, childid, ex, ctxt_type, _) in enumerate(inf_session.config.inference.pred_inputs):
input_ids, attention_mask, token_type_ids, position_ids = prep_model_inputs(inf_session, ex)
for k, v in zip(pred_inputs.keys(), [input_ids, attention_mask, token_type_ids, position_ids, ctxt_type]):
# noinspection PyUnresolvedReferences
pred_inputs[k].append(v)
for k, v in pred_inputs.items():
pred_inputs[k] = torch.tensor([f for f in v], dtype=torch.float) if k == 'all_ctxt_types' else \
torch.tensor([f for f in v], dtype=torch.long)
pred_dataset = TensorDataset(*list(pred_inputs.values()))
pred_sampler = SequentialSampler(pred_dataset)
pred_dataloader = DataLoader(pred_dataset, sampler=pred_sampler,
batch_size=inf_session.config.experiment.infsvc.batch_size)
return pred_dataloader


def prep_base_mapping_tups(inf_session: InferenceSession) -> List[Dict]:
pred_inputs = []
for i, (tup_id, ex, _) in tqdm.tqdm(enumerate(inf_session.mapping_set),
20 changes: 14 additions & 6 deletions analysis/model_analysis_rpt.py
@@ -10,7 +10,7 @@
import constants as db_constants
from db_ingest import get_cnxp_handle
from analysis.inference import Inference
from db_utils import fetchallwrapper, batch_execute_many
from db_utils import fetchallwrapper, batch_execute_many, single_execute
from analysis.interpretation_utils import load_cache
from analysis.gen_pred_explorer import build_pred_exp_doc
from analysis.gen_perf_explorer import build_perf_exp_doc
@@ -102,17 +102,25 @@ def gen_report(self, rpt_type: str) -> None:
self.config.data_source.train_start_date = datetime.datetime.combine(ds_meta[1], datetime.time())
self.config.data_source.train_end_date = datetime.datetime.combine(ds_meta[2], datetime.time())
rpt_tups, stmt_embed_dict = Inference(self.config, analysis_set=analysis_set, rpt_type=rpt_type).init_predict()
inserted_rowcnt, error_rows = batch_execute_many(self.cnxp.get_connection(),
self.config.inference.sql.save_model_sql, rpt_tups)
logger.info(f"Generated {inserted_rowcnt} inference records for analysis of "
f"model version {constants.APP_INSTANCE}")
self.persist_rpt_data(rpt_tups)
self.maybe_build_cache(stmt_embed_dict)

def persist_rpt_data(self, rpt_tups):
inserted_model_rowcnt, _ = batch_execute_many(self.cnxp.get_connection(),
self.config.inference.sql.save_model_rpt_sql, rpt_tups)
logger.info(f"Generated {inserted_model_rowcnt} inference records for analysis of "
f"model version {constants.APP_INSTANCE}")
inserted_model_rowcnt, _ = single_execute(self.cnxp.get_connection(), self.config.inference.sql.save_model_sql)
logger.info(f"Generated {inserted_model_rowcnt} global model performance summary for "
f"model version {constants.APP_INSTANCE}")
inserted_perf_rowcnt, _ = single_execute(self.cnxp.get_connection(), self.config.inference.sql.save_perf_sql)
logger.info(f"Generated {inserted_perf_rowcnt} local performance summary records for "
f"model version {constants.APP_INSTANCE}")

def gen_analysis_set(self) -> List[Tuple]:
# current use case involves relatively small analysis set that fits in memory and should only be used once
# so wasteful to persist. if later use cases necessitate, will pickle or persist for larger datasets
report_sql = f"select * from {self.report_view}"
# TODO: remove this unnecessary transformation? should be able to directly return report_sql tuple list now...
analysis_set = ModelAnalysisRpt.prep_model_analysis_ds(fetchallwrapper(self.cnxp.get_connection(), report_sql))
return analysis_set

Expand Down
12 changes: 11 additions & 1 deletion configs/config_defaults.yaml
@@ -10,6 +10,16 @@ experiment:
purge_intermediate_reports: False
non_twitter_update_freq_multiple: 5
dcbot_poll_interval: 180
infsvc:
enabled: False
batch_mode: True
batch_size: 16
thread_latency: 900
publish: False
skip_db_refresh: False
# purge_intermediate_reports: False
non_twitter_update_freq_multiple: 5
dcbot_poll_interval: 180
debug:
debug_enabled: False
use_debug_dataset: False
@@ -112,7 +122,7 @@ trainer:
max_grad_norm: 1.0
amsgrad: False
swa_mode: "best"
last_swa_snaps: 5
last_swa_snaps: 10
warmup_epochs: 1
inference:
report_mode: False
