Merge pull request #6 from chorowski-lab/gcie/metrics
Add alignment metrics
janchorowski authored Dec 30, 2020
2 parents 3573075 + 220a660 commit 7a71cd5
Showing 4 changed files with 143 additions and 4 deletions.
6 changes: 4 additions & 2 deletions fairseq/data/handwriting/raw_handwriting_dataset.py
@@ -238,13 +238,15 @@ def collater(self, samples):
            assert self.pad
            collated_labels = torch.IntTensor(size=(len(collated_labels_nontensor), max([len(i) for i in collated_labels_nontensor]))).fill_(self.label_pad_idx)
            for i, label in enumerate(collated_labels_nontensor):
-                collated_labels[i][:len(label)] = torch.tensor(label)
+                collated_labels[i][:len(label)] = label

        # TODO: handle EOS here (?), perhaps as an option

        # zeros where None
        target_lengths = torch.LongTensor([len(t) if t is not None else 0 for t in collated_labels_nontensor])

+        input["alignments"] = collated_alignments
+
        # [!] fields with "_available" in the name tell whether the data is actually present in the tensors or only filled with defaults
        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
@@ -537,4 +539,4 @@ def __getitem__(self, index):
                "label_text": label_text
            }
        else:
-            return {"id": index, "source": feats}
\ No newline at end of file
+            return {"id": index, "source": feats}
38 changes: 36 additions & 2 deletions fairseq/models/wav2vec/wav2vec2_scribblelens.py
@@ -34,6 +34,7 @@
@dataclass
class Wav2Vec2SLConfig(Wav2Vec2Config):
    probe_defs: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "probes"})
+    compute_alignment_metrics: bool = field(default=False, metadata={"help": "compute mutual info and rand scores"})

@register_model("wav2vec2_scribblelens", dataclass=Wav2Vec2SLConfig)
class Wav2Vec2ModelSL(BaseFairseqModel, probed_model.ProbedModel):
@@ -261,7 +262,7 @@ def compute_preds(self, x, y, negatives):

        return logits

-    def forward(self, source, padding_mask=None, mask=True, features_only=False):
+    def forward(self, source, padding_mask=None, mask=True, features_only=False, alignments=None):
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
@@ -270,6 +271,8 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
            with torch.no_grad():
                features = self.feature_extractor(source)

+        compute_alignment_metrics = self.cfg.compute_alignment_metrics and alignments is not None
+
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
@@ -284,6 +287,8 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
            assert extra == 0
            padding_mask = padding_mask[:, ::scale]
            assert np.all(padding_mask.shape == features.shape[:-1])
+            if compute_alignment_metrics:
+                alignments = alignments[:, ::scale]

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)
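
The hunk above keeps the ground-truth alignments in step with the downsampled features: the same `::scale` stride that shrinks the padding mask is applied to the per-frame alignment ids. A toy illustration of that slicing (the scale of 4 is an assumed value; in the model it is the input-to-feature length ratio):

```python
import torch

alignments = torch.arange(16).unsqueeze(0)  # (batch=1, 16 input frames) toy ids
scale = 4                                   # assumed input_frames // feature_frames
print(alignments[:, ::scale])               # tensor([[ 0,  4,  8, 12]])
```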
Expand Down Expand Up @@ -314,6 +319,8 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
y = unmasked_features[mask_indices].view(
unmasked_features.size(0), -1, unmasked_features.size(-1)
)
if compute_alignment_metrics:
alignments = alignments[mask_indices]
else:
y = unmasked_features
else:
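
Since the quantization targets are taken at the masked time steps, the matching alignment ids are selected with the same boolean `mask_indices`. Note that boolean indexing flattens batch and time into one axis, which is fine here because the metrics treat frames as an unordered set. A toy example:

```python
import torch

alignments = torch.tensor([[1, 1, 2, 2],
                           [3, 3, 3, 0]])  # (batch, frames) ground-truth ids
mask_indices = torch.tensor([[False, True, True, False],
                             [True, False, False, True]])
print(alignments[mask_indices])  # tensor([1, 2, 3, 0]) -- flat, masked frames only
```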
@@ -327,7 +334,7 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):

            return {"x": x, "padding_mask": padding_mask}

        if self.quantizer:
-            q = self.quantizer(y, produce_targets=False)
+            q = self.quantizer(y, produce_targets=compute_alignment_metrics)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
Expand Down Expand Up @@ -379,8 +386,35 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
result["num_vars"] = num_vars
result["temp"] = curr_temp

if compute_alignment_metrics:
result = {
**result,
**self.get_alignment_metrics(q["targets"], alignments)
}

return result

def get_alignment_metrics(self, targets, ali_gt):
# We currently only quantize unmasked features
import sklearn.metrics

with torch.no_grad():
targets = targets.reshape(-1, self.quantizer.groups)
targetts = targets.detach().cpu().numpy()
ali_es = targets[:, 0] + self.quantizer.num_vars * targets[:, 1]
ali_es = ali_es.detach().cpu().numpy()
ali_gt = ali_gt.detach().cpu().numpy()

return {
"adjusted_mutual_info":
sklearn.metrics.adjusted_mutual_info_score(ali_gt, ali_es, average_method='arithmetic'),
"normalized_mutual_info":
sklearn.metrics.normalized_mutual_info_score(ali_gt, ali_es, average_method='arithmetic'),
"adjusted_rand_score":
sklearn.metrics.adjusted_rand_score(ali_gt, ali_es)
}


def quantize(self, x):
assert self.quantizer is not None
x = self.feature_extractor(x)
Expand Down
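
To see what `get_alignment_metrics` computes, here is the same calculation on made-up data: each frame's two codebook indices are merged into one cluster id (`g0 + num_vars * g1`), and that clustering is scored against the ground-truth alignment ids. With a perfect one-to-one correspondence all three scores are 1.0:

```python
import numpy as np
import sklearn.metrics

num_vars = 4  # toy codebook size; the config below uses 320
targets = np.array([[0, 1], [0, 1], [2, 3], [2, 3]])  # (frames, groups) indices
ali_es = targets[:, 0] + num_vars * targets[:, 1]     # -> [ 4  4 14 14]
ali_gt = np.array([7, 7, 9, 9])                       # toy ground-truth ids

print(sklearn.metrics.adjusted_mutual_info_score(ali_gt, ali_es))    # 1.0
print(sklearn.metrics.normalized_mutual_info_score(ali_gt, ali_es))  # 1.0
print(sklearn.metrics.adjusted_rand_score(ali_gt, ali_es))           # 1.0
```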
80 changes: 80 additions & 0 deletions uwr_related/configs/scribblelens_base_metrics.yaml
@@ -0,0 +1,80 @@
# @package _group_

common:
  fp16: false
  log_format: json
  log_interval: 20
  tensorboard_logdir: tensorboard

checkpoint:
  keep_last_epochs: 3

task:
  _name: scribblelens
  data: /pio/scratch/2/jch/wav2vec/data/scribblelens
  vocab_path: '${env:PWD}/fairseq/data/handwriting/tasman.alphabet.plus.space.mode5.json'
  enable_padding: True
  pad_to_multiples_of: 4
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: false
  labels: True

dataset:
  num_workers: 0
  max_tokens: 10000
  skip_invalid_size_inputs_valid_test: true
  valid_subset: test

distributed_training:
  distributed_world_size: 1
  ddp_backend: no_c10d

criterion:
  _name: wav2vec
  infonce: true
  log_keys: ["prob_perplexity","code_perplexity","temp", "adjusted_mutual_info","normalized_mutual_info","adjusted_rand_score"]
  loss_weights: [0.1, 10]

optimization:
  max_update: 400000
  lr: [0.0003]

optimizer:
  _name: adam
  adam_betas: (0.9,0.98)
  adam_eps: 1e-06
  weight_decay: 0.01

lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 20000

model:
  _name: wav2vec2_scribblelens
  conv_feature_layers: '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]'
  quantize_targets: true
  final_dim: 256
  encoder_embed_dim: 768

  encoder_layerdrop: 0.05
  dropout: 0.1
  attention_dropout: 0.1
  dropout_input: 0.1
  dropout_features: 0.1
  feature_grad_mult: 0.1

  latent_vars: 320
  latent_groups: 2
  latent_temp: [2,0.5,0.999995]

  compute_alignment_metrics: true

  probe_defs:
    post_extract_proj_mlp:
      cls: Conv1DProbe
      module_name: post_extract_proj
      layer_dims: [768, 512, 73]
      kernel_size: 3
      output_selector: 'lambda x: {"output": x.transpose(1, 2)}'
      target_selector: 'lambda x: {"target":x["alignments"], "padding_mask": x["net_input"].get("padding_mask")}'
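
Each `conv_feature_layers` entry is an `(out_channels, kernel, stride, padding)` tuple for one 2-D conv layer. As a quick sanity check (a sketch; treating the first spatial axis as the one that `pad_to_multiples_of` constrains is my assumption), the cumulative stride along the first axis multiplies out to 4, matching `pad_to_multiples_of: 4`:

```python
import ast
from math import prod

conv_feature_layers = (
    '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), '
    '(256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), '
    '(512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), '
    '(512, (3,2), (2, 1), (1, 0))]'
)

layers = ast.literal_eval(conv_feature_layers)
strides = [stride for _, _, stride, _ in layers]
print(prod(s[0] for s in strides))  # 4  -> matches pad_to_multiples_of: 4
print(prod(s[1] for s in strides))  # 16 -> total downsampling on the other axis
```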
23 changes: 23 additions & 0 deletions uwr_related/experiments/gci/train.sh
@@ -0,0 +1,23 @@
python train.py --distributed-world-size 1 --update-freq 2 \
/pio/scratch/1/i283340/MGR/NewSetup/DistSup/data \
--save-dir /pio/scratch/1/i290956/try_sl2 --num-workers 0 \
--keep-last-epochs 3 \
--tensorboard-logdir /pio/scratch/1/i290956/runs/try_sl2 --log-format simple \
--task scribblelens --criterion wav2vec --arch wav2vec2_scribblelens \
--valid-subset test --pad-to-multiples-of 4 `#--max-sample-size 256` \
--log-keys '["prob_perplexity","code_perplexity","temp","adjusted_mutual_info","normalized_mutual_info","adjusted_rand_score"]' --quantize-targets --extractor-mode default \
--conv-feature-layers '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' \
--final-dim 256 \
--latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \
--optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \
--total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \
--mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \
--encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \
--loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \
--num-negatives 100 --cross-sample-negatives 0 \
`#--max-sample-size 250000 --min-sample-size 32000` \
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 10000 --max-update 400000 \
--skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \
--labels "a" \
--enable-padding # crashes without this option: all lines must be padded to the same size
