-
Notifications
You must be signed in to change notification settings - Fork 7
/
selection_utils.py
69 lines (58 loc) · 2.38 KB
/
selection_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import logging
import os
from typing import List, Optional, Union

import numpy as np
import pandas as pd
import tqdm
# Configure root logging at import time so this module's INFO-level progress
# messages are emitted with a timestamped, levelled format. Per the logging
# docs, basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Logger named after this module; used by the functions below.
logger = logging.getLogger(__name__)
def log_training_dynamics(output_dir: Union[str, os.PathLike],
                          epoch: int,
                          train_ids: List[int],
                          train_logits: List[List[float]],
                          train_golds: List[int]):
    """
    Save training dynamics (logits) from a given epoch as records of a `.jsonl` file.

    One JSON record is written per training instance, with the keys ``"guid"``,
    ``f"logits_epoch_{epoch}"`` and ``"gold"``. The output path is
    ``<output_dir>/training_dynamics/dynamics_epoch_<epoch>.jsonl``.

    Args:
        output_dir: Directory under which the ``training_dynamics`` folder is
            created if it does not exist yet.
        epoch: Epoch index; used both in the logits column name and the file name.
        train_ids: Identifier of each training instance.
        train_logits: Logits of each training instance, one row per instance.
            A 2-D ``numpy.ndarray`` is also accepted and converted to nested lists.
        train_golds: Gold label of each training instance.
    """
    # pandas refuses a 2-D array as a single column ("Per-column arrays must
    # each be 1-dimensional"), so store each row as a plain list instead.
    if isinstance(train_logits, np.ndarray):
        train_logits = train_logits.tolist()

    td_df = pd.DataFrame({"guid": train_ids,
                          f"logits_epoch_{epoch}": train_logits,
                          "gold": train_golds})

    logging_dir = os.path.join(output_dir, "training_dynamics")
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(logging_dir, exist_ok=True)

    epoch_file_name = os.path.join(logging_dir, f"dynamics_epoch_{epoch}.jsonl")
    td_df.to_json(epoch_file_name, lines=True, orient="records")
    logger.info(f"Training Dynamics logged to {epoch_file_name}")
def read_training_dynamics(model_dir: Union[str, os.PathLike],
                           strip_last: bool = False,
                           id_field: str = "guid",
                           burn_out: Optional[int] = None):
    """
    Given path to logged training dynamics, merge stats across epochs.

    Reads ``<model_dir>/training_dynamics/dynamics_epoch_<k>.jsonl`` for
    ``k = 0 .. num_epochs - 1``. ``num_epochs`` is the number of
    ``dynamics_epoch_*.jsonl`` files in that directory, or ``burn_out`` when
    ``burn_out`` is truthy.

    Args:
        model_dir: Directory containing the ``training_dynamics`` folder.
        strip_last: If True, drop the last character of each ID before using
            it as a key.
        id_field: Name of the record field that holds the instance ID.
        burn_out: If truthy, read exactly this many epochs, starting at epoch 0.

    Returns:
        Dict mapping each train instance ID to
        ``{"gold": <gold label>, "logits": [<epoch 0 logits>, <epoch 1 logits>, ...]}``.

    Raises:
        FileNotFoundError: If the ``training_dynamics`` directory or an expected
            epoch file is missing.
        ValueError: If an instance ID first appears in an epoch other than 0.
    """
    train_dynamics = {}
    td_dir = os.path.join(model_dir, "training_dynamics")

    # Count only epoch files, so unrelated files in the directory (notes,
    # OS metadata, ...) do not inflate the epoch count and make us look for
    # an epoch file that was never written.
    num_epochs = sum(
        1
        for name in os.listdir(td_dir)
        if name.startswith("dynamics_epoch_")
        and name.endswith(".jsonl")
        and os.path.isfile(os.path.join(td_dir, name))
    )
    if burn_out:
        num_epochs = burn_out
    logger.info(f"Reading {num_epochs} files from {td_dir} ...")

    for epoch_num in tqdm.tqdm(range(num_epochs)):
        epoch_file = os.path.join(td_dir, f"dynamics_epoch_{epoch_num}.jsonl")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(epoch_file):
            raise FileNotFoundError(f"Missing training dynamics file: {epoch_file}")
        logits_key = f"logits_epoch_{epoch_num}"  # same key for every record of this epoch
        with open(epoch_file, "r", encoding="utf-8") as infile:
            for line in infile:
                line = line.strip()
                if not line:
                    continue  # tolerate blank (e.g. trailing) lines
                record = json.loads(line)
                guid = record[id_field][:-1] if strip_last else record[id_field]
                if guid not in train_dynamics:
                    # Every instance must already be present in epoch 0, otherwise
                    # its logits list would be misaligned with the epoch index.
                    if epoch_num != 0:
                        raise ValueError(
                            f"Instance {guid!r} first appears in epoch {epoch_num}, "
                            f"expected it in epoch 0 ({epoch_file})"
                        )
                    train_dynamics[guid] = {"gold": record["gold"], "logits": []}
                train_dynamics[guid]["logits"].append(record[logits_key])
    logger.info(f"Read training dynamics for {len(train_dynamics)} train instances.")
    return train_dynamics