Source features support for V2.0 #2090

Merged · 24 commits · Sep 9, 2021
Changes from 3 commits
7 changes: 6 additions & 1 deletion onmt/bin/build_vocab.py
@@ -32,11 +32,13 @@ def build_vocab_main(opts):
transforms = make_transforms(opts, transforms_cls, fields)

logger.info(f"Counter vocab from {opts.n_sample} samples.")
src_counter, tgt_counter = build_vocab(
src_counter, tgt_counter, src_feats_counter = build_vocab(
opts, transforms, n_sample=opts.n_sample)

logger.info(f"Counters src:{len(src_counter)}")
logger.info(f"Counters tgt:{len(tgt_counter)}")
for k, v in src_feats_counter["src_feats"].items():
logger.info(f"Counters {k}:{len(v)}")

def save_counter(counter, save_path):
check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
@@ -52,6 +54,9 @@ def save_counter(counter, save_path):
else:
save_counter(src_counter, opts.src_vocab)
save_counter(tgt_counter, opts.tgt_vocab)

for k, v in src_feats_counter["src_feats"].items():
save_counter(v, opts.src_feats_vocab[k])


def _get_parser():
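
For context, a minimal sketch of the data shapes these build_vocab.py changes rely on: the extra return value nests one Counter per feature under a "src_feats" key, and the new save loop assumes opts.src_feats_vocab maps each feature name to an output path. The feature name "pos" and the path below are purely illustrative, not part of the PR.

from collections import Counter, defaultdict

# Shape of the new per-feature counter (one hypothetical feature, "pos"):
src_feats_counter = {"src_feats": defaultdict(Counter)}
src_feats_counter["src_feats"]["pos"].update("PRON VERB DET NOUN".split())

# Mirrors the new save loop: one vocabulary file per feature, keyed by the
# same feature name in opts.src_feats_vocab (illustrative path).
src_feats_vocab = {"pos": "vocab.pos"}
for name, counter in src_feats_counter["src_feats"].items():
    with open(src_feats_vocab[name], "w", encoding="utf-8") as f:
        for token, count in counter.most_common():
            f.write(f"{token}\t{count}\n")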
61 changes: 45 additions & 16 deletions onmt/inputters/corpus.py
@@ -7,7 +7,7 @@
from torchtext.data import Dataset as TorchtextDataset, \
Example as TorchtextExample

from collections import Counter
from collections import Counter, defaultdict
from contextlib import contextmanager

import multiprocessing as mp
@@ -74,6 +74,9 @@ def _process(item, is_train):
maybe_example['tgt'] = ' '.join(maybe_example['tgt'])
if 'align' in maybe_example:
maybe_example['align'] = ' '.join(maybe_example['align'])
if 'src_feats' in maybe_example:
for k in maybe_example['src_feats'].keys():
maybe_example['src_feats'][k] = ' '.join(maybe_example['src_feats'][k])
return maybe_example

def _maybe_add_dynamic_dict(self, example, fields):
@@ -107,23 +110,30 @@ def __call__(self, bucket):
class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(self, name, src, tgt, align=None):
def __init__(self, name, src, tgt, align=None, src_feats=None):
"""Initialize src & tgt side file path."""
self.id = name
self.src = src
self.tgt = tgt
self.align = align
self.src_feats = src_feats

def load(self, offset=0, stride=1):
"""
Load file and iterate by lines.
`offset` and `stride` allow to iterate only on every
`stride` example, starting from `offset`.
"""
#import pdb
#pdb.set_trace()
if self.src_feats:
features_files = [open(feat_path, mode='rb') for feat_name, feat_path in self.src_feats.items()]
else:
features_files = []
with exfile_open(self.src, mode='rb') as fs,\
exfile_open(self.tgt, mode='rb') as ft,\
exfile_open(self.align, mode='rb') as fa:
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
for i, (sline, tline, align, *features) in enumerate(zip(fs, ft, fa, *features_files)):
if (i % stride) == offset:
sline = sline.decode('utf-8')
tline = tline.decode('utf-8')
@@ -133,12 +143,18 @@ def load(self, offset=0, stride=1):
}
if align is not None:
example['align'] = align.decode('utf-8')
if features:
example["src_feats"] = dict()
for j, feat in enumerate(features):
example["src_feats"][list(self.src_feats.keys())[j]] = feat.decode("utf-8")
yield example
for f in features_files:
f.close()

def __str__(self):
cls_name = type(self).__name__
return '{}({}, {}, align={})'.format(
cls_name, self.src, self.tgt, self.align)
return '{}({}, {}, align={}, src_feats={})'.format(
cls_name, self.src, self.tgt, self.align, self.src_feats)


def get_corpora(opts, is_train=False):
@@ -150,7 +166,8 @@
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"])
corpus_dict["path_align"],
corpus_dict["src_feats"])
else:
if CorpusName.VALID in opts.data.keys():
corpora_dict[CorpusName.VALID] = ParallelCorpus(
@@ -193,6 +210,9 @@ def _tokenize(self, stream):
example['src'], example['tgt'] = src, tgt
if 'align' in example:
example['align'] = example['align'].strip('\n').split()
if 'src_feats' in example:
for k in example['src_feats'].keys():
example['src_feats'][k] = example['src_feats'][k].strip('\n').split()
yield example

def _transform(self, stream):
@@ -284,8 +304,11 @@ def write_files_from_queues(sample_path, queues):

def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
"""Build vocab on (strided) subpart of the data."""
#import pdb
#pdb.set_trace()
sub_counter_src = Counter()
sub_counter_tgt = Counter()
sub_counter_src_feats = {'src_feats': defaultdict(Counter)}
datasets_iterables = build_corpora_iters(
corpora, transforms, opts.data,
skip_empty_level=opts.skip_empty_level,
@@ -298,6 +321,9 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
build_sub_vocab.queues[c_name][offset].put("blank")
continue
src_line, tgt_line = maybe_example['src'], maybe_example['tgt']
if 'src_feats' in maybe_example:
for feat_name, feat_line in maybe_example["src_feats"].items():
sub_counter_src_feats['src_feats'][feat_name].update(feat_line.split(' '))
sub_counter_src.update(src_line.split(' '))
sub_counter_tgt.update(tgt_line.split(' '))
if opts.dump_samples:
@@ -309,7 +335,7 @@ def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
break
if opts.dump_samples:
build_sub_vocab.queues[c_name][offset].put("break")
return sub_counter_src, sub_counter_tgt
return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
@@ -333,6 +359,7 @@ def build_vocab(opts, transforms, n_sample=3):
corpora = get_corpora(opts, is_train=True)
counter_src = Counter()
counter_tgt = Counter()
counter_src_feats = {'src_feats': defaultdict(Counter)}
from functools import partial
queues = {c_name: [mp.Queue(opts.vocab_sample_queue_size)
for i in range(opts.num_threads)]
@@ -345,17 +372,19 @@ def build_vocab(opts, transforms, n_sample=3):
args=(sample_path, queues),
daemon=True)
write_process.start()
with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
for sub_counter_src, sub_counter_tgt in p.imap(
func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
#with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
func = partial(
build_sub_vocab, corpora, transforms,
opts, n_sample, opts.num_threads)
sub_counter_src, sub_counter_tgt, sub_counter_src_feats = func(0)
# for sub_counter_src, sub_counter_tgt in p.imap(
# func, range(0, opts.num_threads)):
counter_src.update(sub_counter_src)
counter_tgt.update(sub_counter_tgt)
counter_src_feats.update(sub_counter_src_feats)
if opts.dump_samples:
write_process.join()
return counter_src, counter_tgt
return counter_src, counter_tgt, counter_src_feats


def save_transformed_sample(opts, transforms, n_sample=3):
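
To make the new loading path concrete, here is a hedged sketch of the expected on-disk layout and of the example dict that ParallelCorpus.load now yields; the feature name "pos" and all file contents are illustrative, not taken from the PR.

# Each source feature lives in its own file, line-aligned with the source
# corpus, one feature token per source token:
#   src.txt       ->  she is a doctor
#   src.feat.pos  ->  PRON VERB DET NOUN
#
# With src_feats={"pos": "src.feat.pos"}, load() yields roughly:
example = {
    "src": "she is a doctor\n",
    "tgt": "sie ist Ärztin\n",
    "src_feats": {"pos": "PRON VERB DET NOUN\n"},
}
# ParallelCorpusIterator._tokenize() then strips the newline and splits both
# 'src' and every 'src_feats' entry into token lists.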
12 changes: 9 additions & 3 deletions onmt/inputters/fields.py
@@ -9,10 +9,10 @@

def _get_dynamic_fields(opts):
# NOTE: not support nfeats > 0 yet
src_nfeats = 0
tgt_nfeats = 0
#src_nfeats = 0
tgt_nfeats = None #0
with_align = hasattr(opts, 'lambda_align') and opts.lambda_align > 0.0
fields = get_fields('text', src_nfeats, tgt_nfeats,
fields = get_fields('text', opts.src_feats_vocab, tgt_nfeats,
dynamic_dict=opts.copy_attn,
src_truncate=opts.src_seq_length_trunc,
tgt_truncate=opts.tgt_seq_length_trunc,
@@ -33,6 +33,12 @@ def build_dynamic_fields(opts, src_specials=None, tgt_specials=None):
opts.src_vocab, 'src', counters,
min_freq=opts.src_words_min_frequency)

if opts.src_feats_vocab:
for feat_name, filepath in opts.src_feats_vocab.items():
_, _ = _load_vocab(
filepath, feat_name, counters,
min_freq=0)

if opts.tgt_vocab:
_tgt_vocab, _tgt_vocab_size = _load_vocab(
opts.tgt_vocab, 'tgt', counters,
45 changes: 31 additions & 14 deletions onmt/inputters/text_dataset.py
@@ -171,20 +171,37 @@ def text_fields(**kwargs):
eos = kwargs.get("eos", DefaultTokens.EOS)
truncate = kwargs.get("truncate", None)
fields_ = []
feat_delim = u"│" if n_feats > 0 else None
for i in range(n_feats + 1):
name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
tokenize = partial(
_feature_tokenize,
layer=i,
truncate=truncate,
feat_delim=feat_delim)
use_len = i == 0 and include_lengths
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=use_len)
fields_.append((name, feat))

feat_delim = None #u"│" if n_feats > 0 else None

# Base field
tokenize = partial(
_feature_tokenize,
layer=None,
truncate=truncate,
feat_delim=feat_delim)
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=include_lengths)
fields_.append((base_name, feat))

# Feats fields
#for i in range(n_feats + 1):
if n_feats:
for feat_name in n_feats.keys():
#name = base_name + "_feat_" + str(i - 1) if i > 0 else base_name
tokenize = partial(
_feature_tokenize,
layer=None,
truncate=truncate,
feat_delim=feat_delim)
feat = Field(
init_token=bos, eos_token=eos,
pad_token=pad, tokenize=tokenize,
include_lengths=False)
fields_.append((feat_name, feat))

assert fields_[0][0] == base_name # sanity check
field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
return field
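
A rough sketch of what text_fields builds after this change, assuming n_feats is the opts.src_feats_vocab mapping and the feature name is the hypothetical "pos": the base source field keeps include_lengths, each declared feature gets its own Field without lengths, and everything is wrapped in a single TextMultiField.

# Illustrative only; torchtext Field objects abbreviated:
#   fields_ = [("src", <Field include_lengths=True>),
#              ("pos", <Field include_lengths=False>)]
#   return TextMultiField("src", fields_[0][1], fields_[1:])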
5 changes: 5 additions & 0 deletions onmt/opts.py
@@ -132,6 +132,11 @@ def _add_dynamic_fields_opts(parser, build_vocab_only=False):
group.add("-share_vocab", "--share_vocab", action="store_true",
help="Share source and target vocabulary.")

group.add("-src_feats_vocab", "--src_feats_vocab",
help=("List of paths to save" if build_vocab_only else "List of paths to")
+ " src features vocabulary files. "
"Files format: one <word> or <word>\t<count> per line.")

if not build_vocab_only:
group.add("-src_vocab_size", "--src_vocab_size",
type=int, default=50000,
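
A hedged note on how the new option is consumed elsewhere in this PR: both fields.py and build_vocab.py iterate opts.src_feats_vocab with .items(), so its value is expected to be a mapping from feature name to vocabulary file path. The name and path below are examples only.

# Illustrative value for src_feats_vocab:
#   {"pos": "data/feat.pos.vocab"}
# where each file follows the format given in the help string, one
# "<word>" or "<word>\t<count>" per line, e.g.:
#   NOUN\t1250
#   VERB\t987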
3 changes: 3 additions & 0 deletions onmt/train_single.py
@@ -58,6 +58,9 @@ def main(opt, fields, transforms_cls, checkpoint, device_id,
"""Start training on `device_id`."""
# NOTE: It's important that ``opt`` has been validated and updated
# at this point.

#import pdb
#pdb.set_trace()
configure_process(opt, device_id)
init_logger(opt.log_file)

97 changes: 97 additions & 0 deletions onmt/transforms/features.py
@@ -0,0 +1,97 @@
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform, ObservableStats
import re
from collections import defaultdict


@register_transform(name='filterfeats')
class FilterFeatsTransform(Transform):
    """Filter out examples with a mismatch between source and features."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        pass

    def _parse_opts(self):
        pass

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Return None if mismatch"""

        if 'src_feats' not in example:
            # Do nothing
            return example

        for feat_name, feat_values in example['src_feats'].items():
            if len(example['src']) != len(feat_values):
                logger.warning(f"Skipping example due to mismatch between source and feature {feat_name}")
                return None
        return example

    def _repr_args(self):
        return ''


@register_transform(name='inferfeats')
class InferFeatsTransform(Transform):
    """Infer features for subword tokenization."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        pass

    def _parse_opts(self):
        pass

    def apply(self, example, is_train=False, stats=None, **kwargs):

        if "src_feats" not in example:
            # Do nothing
            return example

        feats_i = 0
        inferred_feats = defaultdict(list)
        for subword in example["src"]:
            next_ = False
            for k, v in example["src_feats"].items():
                # TODO: what about custom placeholders??

                # Placeholders
                if re.match(r'⦅\w+⦆', subword):
                    inferred_feat = "N"

                # Punctuation only
                elif not re.sub(r'(\W)+', '', subword).strip():
                    inferred_feat = "N"

                # Joiner annotate
                elif re.search("■", subword):
                    inferred_feat = v[feats_i]

                # Whole word
                else:
                    inferred_feat = v[feats_i]
                    next_ = True

                inferred_feats[k].append(inferred_feat)

            if next_:
                feats_i += 1

        # Check all features have been consumed
        for k, v in example["src_feats"].items():
            assert feats_i == len(v), f'Not all features consumed for {k}'

        for k, v in inferred_feats.items():
            example["src_feats"][k] = inferred_feats[k]
        return example

    def _repr_args(self):
        return ''
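
To illustrate the two new transforms, a small worked sketch; feature names, tags and tokens are made up, and the transforms are not instantiated here since that would require a full opts object.

# filterfeats: an example is dropped (apply() returns None) whenever a
# feature sequence length differs from the number of source tokens.
example = {
    "src": ["she", "is", "a", "doctor"],
    "tgt": ["sie", "ist", "Ärztin"],
    "src_feats": {"pos": ["PRON", "VERB", "DET", "NOUN"]},  # kept
    # {"pos": ["PRON", "VERB"]} would be filtered out (length 2 != 4)
}

# inferfeats: word-level features are re-expanded over the subword-tokenized
# source, with "N" assigned to placeholders and punctuation-only tokens, e.g.
#   src after subword tok.: ["however", "■,", "according", "to"]
#   word-level feats:       ["A", "B", "C"]
#   inferred feats:         ["A", "N", "B", "C"]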