
Commit

Merge pull request #255 from ASUS-AICS/NN-Data_util-loading-dataframe
NN load_datasets loading dataframe.
Gordon119 authored Feb 16, 2023
2 parents 0b80898 + 2c70516 commit a9e2cae
Showing 2 changed files with 38 additions and 25 deletions.
57 changes: 35 additions & 22 deletions libmultilabel/nn/data_utils.py
@@ -123,21 +123,19 @@ def get_dataset_loader(
     return dataset_loader


-def _load_raw_data(path, is_test=False, tokenize_text=True, remove_no_label_data=False):
-    """Load and tokenize raw data.
+def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data=False):
+    """Load and tokenize raw data in file or dataframe.

     Args:
-        path (str): Path to training, test, or validation data.
+        data (Union[str, pandas.DataFrame]): Training, test, or validation data in file or dataframe.
         is_test (bool, optional): Whether the data is for test or not. Defaults to False.
         remove_no_label_data (bool, optional): Whether to remove training/validation instances that have no labels.
            This is effective only when is_test=False. Defaults to False.

     Returns:
         pandas.DataFrame: Data composed of index, label, and tokenized text.
     """
-    logging.info(f'Load data from {path}.')
-    data = pd.read_csv(path, sep='\t', header=None,
-                       error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
     data = data.astype(str)
     if data.shape[1] == 2:
         data.columns = ['label', 'text']
         data = data.reset_index()
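
For context, `pd.read_csv` moves out of this helper: after this change, `load_datasets` reads the file when it is given a path and always hands `_load_raw_data` a dataframe. A minimal sketch of the two-column frame the helper expects, assuming the label-then-text column order implied by the `data.columns = ['label', 'text']` assignment above (sample rows are hypothetical):

import pandas as pd

# Space-separated labels in the first column, raw text in the second;
# _load_raw_data names the columns ['label', 'text'] when shape[1] == 2.
df = pd.DataFrame([
    ['sports politics', 'The match was postponed after the vote.'],
    ['economy', 'Markets rallied on the announcement.'],
])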
@@ -166,21 +164,21 @@ def _load_raw_data(path, is_test=False, tokenize_text=True, remove_no_label_data


 def load_datasets(
-        training_file=None,
-        test_file=None,
-        val_file=None,
+        training_data=None,
+        test_data=None,
+        val_data=None,
         val_size=0.2,
         merge_train_val=False,
         tokenize_text=True,
         remove_no_label_data=False
 ):
-    """Load data from the specified data paths (i.e., `training_file`, `test_file`, and `val_file`).
-    If `valid.txt` does not exist but `val_size` > 0, the validation set will be split from the training dataset.
+    """Load data from the specified data paths or the given dataframes.
+    If `val_data` does not exist but `val_size` > 0, the validation set will be split from the training dataset.

     Args:
-        training_file (str, optional): Path to training data.
-        test_file (str, optional): Path to test data.
-        val_file (str, optional): Path to validation data.
+        training_data (Union[str, pandas.DataFrame], optional): Path to training data or a dataframe.
+        test_data (Union[str, pandas.DataFrame], optional): Path to test data or a dataframe.
+        val_data (Union[str, pandas.DataFrame], optional): Path to validation data or a dataframe.
         val_size (float, optional): Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set.
             Defaults to 0.2.
         merge_train_val (bool, optional): Whether to merge the training and validation data.
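
With the renamed parameters, callers can pass either paths or pre-loaded dataframes. A sketch of the two call styles this enables (file names and frame contents are illustrative only, not from this repository's data):

import pandas as pd
from libmultilabel.nn import data_utils

# Style 1: TSV paths, exactly as before the rename (paths are hypothetical).
datasets = data_utils.load_datasets(training_data='train.txt', test_data='test.txt')

# Style 2: dataframes in the same label/text layout, skipping the files entirely.
train_df = pd.DataFrame([['label1', 'text one'], ['label2', 'text two'],
                         ['label1 label2', 'text three'], ['label2', 'text four'],
                         ['label1', 'text five']])
test_df = pd.DataFrame([['label1', 'held-out text']])
datasets = data_utils.load_datasets(training_data=train_df, test_data=test_df)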
@@ -192,22 +190,37 @@ def load_datasets(
     Returns:
         dict: A dictionary of datasets.
     """
-    assert training_file or test_file, "At least one of `training_file` and `test_file` must be specified."
+    if isinstance(training_data, str) or isinstance(test_data, str):
+        assert training_data or test_data, "At least one of `training_data` and `test_data` must be specified."
+    elif isinstance(training_data, pd.DataFrame) or isinstance(test_data, pd.DataFrame):
+        assert not training_data.empty or not test_data.empty, "At least one of `training_data` and `test_data` must be specified."

     datasets = {}
-    if training_file is not None:
-        datasets['train'] = _load_raw_data(training_file, tokenize_text=tokenize_text,
+    if training_data is not None:
+        if isinstance(training_data, str):
+            logging.info(f'Load data from {training_data}.')
+            training_data = pd.read_csv(training_data, sep='\t', header=None,
+                                        error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
+        datasets['train'] = _load_raw_data(training_data, tokenize_text=tokenize_text,
                                            remove_no_label_data=remove_no_label_data)

-    if val_file is not None:
-        datasets['val'] = _load_raw_data(val_file, tokenize_text=tokenize_text,
-                                         remove_no_label_data=remove_no_label_data)
+    if val_data is not None:
+        if isinstance(val_data, str):
+            logging.info(f'Load data from {val_data}.')
+            val_data = pd.read_csv(val_data, sep='\t', header=None,
+                                   error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
+        datasets['val'] = _load_raw_data(val_data, tokenize_text=tokenize_text,
+                                         remove_no_label_data=remove_no_label_data)
     elif val_size > 0:
         datasets['train'], datasets['val'] = train_test_split(
             datasets['train'], test_size=val_size, random_state=42)

-    if test_file is not None:
-        datasets['test'] = _load_raw_data(test_file, is_test=True, tokenize_text=tokenize_text,
+    if test_data is not None:
+        if isinstance(test_data, str):
+            logging.info(f'Load data from {test_data}.')
+            test_data = pd.read_csv(test_data, sep='\t', header=None,
+                                    error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
+        datasets['test'] = _load_raw_data(test_data, is_test=True, tokenize_text=tokenize_text,
                                           remove_no_label_data=remove_no_label_data)

     if merge_train_val:
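
As before, when no `val_data` is supplied and `val_size` > 0, the validation set is split off the training data by `train_test_split` with `random_state=42`. A small sketch of that path (frame contents hypothetical):

import pandas as pd
from libmultilabel.nn import data_utils

train_df = pd.DataFrame([[f'label{i % 3}', f'document number {i}'] for i in range(10)])
datasets = data_utils.load_datasets(training_data=train_df, val_size=0.2)
# datasets['train'] keeps 8 instances and datasets['val'] gets 2,
# reproducibly, because the split is seeded with random_state=42.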
6 changes: 3 additions & 3 deletions torch_trainer.py
@@ -56,9 +56,9 @@ def __init__(
         # Load dataset
         if datasets is None:
             self.datasets = data_utils.load_datasets(
-                training_file=config.training_file,
-                test_file=config.test_file,
-                val_file=config.val_file,
+                training_data=config.training_file,
+                test_data=config.test_file,
+                val_data=config.val_file,
                 val_size=config.val_size,
                 merge_train_val=config.merge_train_val,
                 tokenize_text=tokenize_text,
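
On the trainer side only the keyword names change; the config keeps its old `*_file` keys and still supplies paths, so dataframe input is available to API users who call `load_datasets` directly or pass their own `datasets` to the trainer. A hedged sketch of the equivalent direct call (paths illustrative):

from libmultilabel.nn import data_utils

datasets = data_utils.load_datasets(
    training_data='data/rcv1/train.txt',  # from config.training_file
    test_data='data/rcv1/test.txt',       # from config.test_file
    val_data=None,                        # from config.val_file
    val_size=0.2,
)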
