Skip to content

Commit

Permalink
Improve code + add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
laurens88 committed Apr 4, 2024
1 parent 9088356 commit ad6e6c3
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 36 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ LAB](https://github.com/asreview/asreview) that can be used to:
- [**Stack**](#data-vstack-experimental) multiple datasets
- [**Compose**](#data-compose-experimental) a single (labeled, partly labeled, or unlabeled) dataset from multiple datasets
- [**Snowball**](#snowball) a dataset to find incoming or outgoing citations
- [**Sample**](#sample) n old, n random, and n new papers in order to check if the terminology has changed over time.
- [**Sample**](#sample) old, random, and new papers in order to check if the terminology has changed over time.

Several [tutorials](Tutorials.md) are available that show how
`ASReview-Datatools` can be used in different scenarios.
Expand Down Expand Up @@ -285,10 +285,10 @@ asreview data snowball input_dataset.csv output_dataset.csv --backward --email m

## Sample

This datatool is used to sample n old, random and new records from your dataset by using the `asreview data sample` command. The sampled records are then stored in an output file. This can be useful for detecting concept drift, meaning that the words used for certain concepts change over time. This script assumes that the dataset includes a column named `publication_year`. An example would be:
This datatool is used to sample old, random and new records from your dataset by using the `asreview data sample` command. The sampled records are then stored in an output file. This can be useful for detecting concept drift, meaning that the words used for certain concepts change over time. This script assumes that the dataset includes a column named `publication_year`. An example would be:

```bash
asreview data sample output_dataset.xslx input_dataset.xlsx 50
asreview data sample input_dataset.xlsx output_dataset.xlsx 50
```
This samples the `50` oldest and `50` newest records from `input_dataset.xlsx` and samples `50` records randomly (without overlap from the old and new partitions!). The resulting 150 records are written to `output_dataset.xlsx`.

Expand Down
53 changes: 20 additions & 33 deletions asreviewcontrib/datatools/sample.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,38 @@
import argparse
from pathlib import Path

import pandas as pd
from asreview import ASReviewData
from asreview.data.base import load_data


def _check_suffix(input_file, output_file):
# Also raises ValueError on URLs that do not end with a file extension
suffixes = [Path(input_file).suffix, Path(output_file).suffix]

set_ris = {".txt", ".ris"}
set_tabular = {".csv", ".tab", ".tsv", ".xlsx"}
set_suffixes = set(suffixes)

if len(set(suffixes)) > 1:
if not (set_suffixes.issubset(set_ris) or set_suffixes.issubset(set_tabular)):
raise ValueError(
"• Several file types were given; The input and the output file"
"should be of the same type. "
)


def sample(output_path, input_path, nr_records):
_check_suffix(input_path, output_path)

def sample(input_path, output_path, nr_records, year_column="publication_year"):
df_input = load_data(input_path).df

# Check for presence of any variation of a year column
if "publication_year" not in df_input.columns:
raise ValueError("• The input file should have a 'publication_year' column.")
if year_column not in df_input.columns:
raise ValueError(f"• The input file should have a {year_column} column.")

# Check if k is not too large
if nr_records*3 > len(df_input):
if nr_records * 3 > len(df_input):
raise ValueError(
f"• The number of records to sample is too large."
f"Only {len(df_input)} records are present in the input file."
f" You are trying to sample {nr_records*3} records."
)

if nr_records < 1:
raise ValueError("• The number of records to sample should be at least 1.")

# Sort by year
dated_records = df_input[df_input["publication_year"].notnull()]
dated_records = df_input[df_input[year_column].notnull()]

if dated_records.empty:
raise ValueError("• The input file has no publication_year values.")
if len(dated_records) < nr_records*2:
raise ValueError(f"• The input file has no {year_column} values.")

if len(dated_records) < nr_records * 2:
raise ValueError("• Not enough dated records to sample from.")

sorted_records = dated_records.sort_values("publication_year", ascending=True)
sorted_records = dated_records.sort_values(year_column, ascending=True)

# Take k old and k new records
old_records = sorted_records.head(nr_records)
Expand All @@ -65,19 +46,25 @@ def sample(output_path, input_path, nr_records):

# Combine old, new, and sampled records
df_out = pd.concat([old_records, sampled_records, new_records])

asdata = ASReviewData(df=df_out)
asdata.to_file(output_path)


def _parse_arguments_sample():
parser = argparse.ArgumentParser(prog="asreview data sample")
parser.add_argument("output_path", type=str, help="The output file path.")
parser.add_argument("input_path", type=str, help="The input file path.")
parser.add_argument("output_path", type=str, help="The output file path.")
parser.add_argument(
"nr_records",
type=int,
help="The amount of records for old, random, and new records each."
help="The amount of records for old, random, and new records each.",
)
parser.add_argument(
"--year_column",
default="publication_year",
type=str,
help="The name of the column containing the publication year.",
)

return parser
return parser
7 changes: 7 additions & 0 deletions tests/demo_data/sample_data.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
title, doi, publication_year
title1, doi1, 2005
title2, doi2, 2001
title3, doi3,
title4, doi4, 2003
title5, doi5, 2004
title6, doi6, 2000
17 changes: 17 additions & 0 deletions tests/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# create unit tests for the sample.py file
from pathlib import Path

import pandas as pd

from asreviewcontrib.datatools.sample import sample

INPUT_DIR = Path(__file__).parent / "demo_data" / "sample_data.csv"


def test_sample(tmpdir):
    """Sampling 1 old, 1 random, and 1 new record yields 3 rows spanning the year range."""
    output_path = tmpdir / "output.csv"
    sample(INPUT_DIR, output_path, 1, "publication_year")

    result = pd.read_csv(output_path)

    assert len(result) == 3
    assert "publication_year" in result.columns
    # Output is ordered oldest-first: 2000 is the oldest dated record in the
    # fixture, 2005 the newest.
    assert result.iloc[0]["publication_year"] == 2000
    assert result.iloc[2]["publication_year"] == 2005

0 comments on commit ad6e6c3

Please sign in to comment.