diff --git a/libmultilabel/nn/data_utils.py b/libmultilabel/nn/data_utils.py
index 5c4c771b..7647cbab 100644
--- a/libmultilabel/nn/data_utils.py
+++ b/libmultilabel/nn/data_utils.py
@@ -123,11 +123,11 @@ def get_dataset_loader(
     return dataset_loader


-def _load_raw_data(path, is_test=False, tokenize_text=True, remove_no_label_data=False):
-    """Load and tokenize raw data.
+def _load_raw_data(data, is_test=False, tokenize_text=True, remove_no_label_data=False):
+    """Load and tokenize raw data from a file or a dataframe.

     Args:
-        path (str): Path to training, test, or validation data.
+        data (Union[str, pandas.DataFrame]): Training, test, or validation data given as a path or a dataframe.
         is_test (bool, optional): Whether the data is for test or not. Defaults to False.
         remove_no_label_data (bool, optional): Whether to remove training/validation instances
             that have no labels. This is effective only when is_test=False. Defaults to False.
@@ -135,9 +135,7 @@ def _load_raw_data(path, is_test=False, tokenize_text=True, remove_no_label_data
     Returns:
         pandas.DataFrame: Data composed of index, label, and tokenized text.
     """
-    logging.info(f'Load data from {path}.')
-    data = pd.read_csv(path, sep='\t', header=None,
-                       error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
+    data = data.astype(str)
     if data.shape[1] == 2:
         data.columns = ['label', 'text']
         data = data.reset_index()
@@ -166,21 +164,21 @@ def _load_raw_data(path, is_test=False, tokenize_text=True, remove_no_label_data


 def load_datasets(
-    training_file=None,
-    test_file=None,
-    val_file=None,
+    training_data=None,
+    test_data=None,
+    val_data=None,
     val_size=0.2,
     merge_train_val=False,
     tokenize_text=True,
     remove_no_label_data=False
 ):
-    """Load data from the specified data paths (i.e., `training_file`, `test_file`, and `val_file`).
-    If `valid.txt` does not exist but `val_size` > 0, the validation set will be split from the training dataset.
+    """Load data from the specified paths or the given dataframes.
+    If `val_data` is not provided but `val_size` > 0, the validation set is split from the training dataset.

     Args:
-        training_file (str, optional): Path to training data.
-        test_file (str, optional): Path to test data.
-        val_file (str, optional): Path to validation data.
+        training_data (Union[str, pandas.DataFrame], optional): Path to training data or a dataframe.
+        test_data (Union[str, pandas.DataFrame], optional): Path to test data or a dataframe.
+        val_data (Union[str, pandas.DataFrame], optional): Path to validation data or a dataframe.
         val_size (float, optional): Training-validation split: a ratio in [0, 1] or an integer for the size
             of the validation set. Defaults to 0.2.
         merge_train_val (bool, optional): Whether to merge the training and validation data.
@@ -192,22 +190,37 @@ def load_datasets(
     Returns:
         dict: A dictionary of datasets.
     """
-    assert training_file or test_file, "At least one of `training_file` and `test_file` must be specified."
+    assert training_data is not None or test_data is not None, \
+        "At least one of `training_data` and `test_data` must be specified."
     datasets = {}
-    if training_file is not None:
-        datasets['train'] = _load_raw_data(training_file, tokenize_text=tokenize_text,
+    if training_data is not None:
+        if isinstance(training_data, str):
+            logging.info(f'Load data from {training_data}.')
+            training_data = pd.read_csv(training_data, sep='\t', header=None,
+                                        error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
+        datasets['train'] = _load_raw_data(training_data, tokenize_text=tokenize_text,
                                            remove_no_label_data=remove_no_label_data)
-    if val_file is not None:
-        datasets['val'] = _load_raw_data(val_file, tokenize_text=tokenize_text,
-                                         remove_no_label_data=remove_no_label_data)
+    if val_data is not None:
+        if isinstance(val_data, str):
+            logging.info(f'Load data from {val_data}.')
+            val_data = pd.read_csv(val_data, sep='\t', header=None,
+                                   error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
+        datasets['val'] = _load_raw_data(val_data, tokenize_text=tokenize_text,
+                                         remove_no_label_data=remove_no_label_data)
     elif val_size > 0:
         datasets['train'], datasets['val'] = train_test_split(
             datasets['train'], test_size=val_size, random_state=42)
-    if test_file is not None:
-        datasets['test'] = _load_raw_data(test_file, is_test=True, tokenize_text=tokenize_text,
+    if test_data is not None:
+        if isinstance(test_data, str):
+            logging.info(f'Load data from {test_data}.')
+            test_data = pd.read_csv(test_data, sep='\t', header=None,
+                                    error_bad_lines=False, warn_bad_lines=True, quoting=csv.QUOTE_NONE).fillna('')
+        datasets['test'] = _load_raw_data(test_data, is_test=True, tokenize_text=tokenize_text,
                                           remove_no_label_data=remove_no_label_data)

     if merge_train_val:
diff --git a/torch_trainer.py b/torch_trainer.py
index 6516b7ea..166d82ad 100644
--- a/torch_trainer.py
+++ b/torch_trainer.py
@@ -56,9 +56,9 @@ def __init__(
         # Load dataset
         if datasets is None:
             self.datasets = data_utils.load_datasets(
-                training_file=config.training_file,
-                test_file=config.test_file,
-                val_file=config.val_file,
+                training_data=config.training_file,
+                test_data=config.test_file,
+                val_data=config.val_file,
                 val_size=config.val_size,
                 merge_train_val=config.merge_train_val,
                 tokenize_text=tokenize_text,
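After this change, `load_datasets` accepts either file paths or in-memory dataframes, and the two can be mixed. The following is a minimal sketch of how the new interface might be exercised; the two-column (label, text) layout mirrors what `_load_raw_data` expects, while the label format and the test-file path are illustrative assumptions, not part of the patch.

```python
import pandas as pd
from libmultilabel.nn import data_utils

# Build a small in-memory dataset instead of pointing at a TSV file.
# Two columns (label, text) match the shape[1] == 2 branch of
# _load_raw_data; the space-separated labels here are an assumption.
train_df = pd.DataFrame({
    'label': ['sci med', 'rec sport'],
    'text': ['the study of enzyme kinetics', 'the final score of the game'],
})

# Paths and dataframes can now be mixed freely.
datasets = data_utils.load_datasets(
    training_data=train_df,          # a dataframe...
    test_data='data/rcv1/test.txt',  # ...or a path (hypothetical), as before
    val_size=0.2,                    # val split is carved from train_df
)
```

Because `isinstance(..., str)` gates the `pd.read_csv` call, dataframe inputs skip file loading entirely and go straight to `_load_raw_data`, which now only casts columns to `str` and tokenizes.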