Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Yelp Reviews Datasets #20

Merged
merged 10 commits into from
Dec 7, 2021
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Scidata currently supports the following training and test datasets:
- CIFAR100
- FashionMNIST
- IMDb Reviews
- Yelp Reviews (Full and Polarity)
- MNIST

Download or fetch datasets locally:
Expand Down
2 changes: 1 addition & 1 deletion lib/scidata/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ defmodule Scidata.Utils do

defp decode({request, response}) do
cond do
String.ends_with?(request.url, ".tar.gz") ->
String.ends_with?(request.url, ".tar.gz") or String.ends_with?(request.url, ".tgz") ->
{:ok, files} = :erl_tar.extract({:binary, response.body}, [:memory, :compressed])
response = %{response | body: files}
{request, response}
Expand Down
44 changes: 44 additions & 0 deletions lib/scidata/yelp_full_reviews.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
defmodule Scidata.YelpFullReviews do
  @moduledoc """
  Module for downloading the [Yelp Reviews dataset](https://www.yelp.com/dataset).

  The compressed archive is requested through `Scidata.Utils.get!/1`, the CSV
  entry for the requested split is selected by name, and every row is turned
  into a review (last column) and an integer rating (first column).
  """

  @base_url "https://s3.amazonaws.com/fast-ai-nlp/"

  @dataset_file "yelp_review_full_csv.tgz"

  alias Scidata.Utils
  alias NimbleCSV.RFC4180, as: CSV

  @doc """
  Downloads the Yelp Reviews training dataset or fetches it locally.

  Returns a map with two lists in row order: `:review` holds the last CSV
  column of each row and `:rating` holds the first column parsed as an integer.

  Raises if the archive has no entry whose name contains `train`.
  """
  @spec download() :: %{review: [binary(), ...], rating: [5 | 4 | 3 | 2 | 1]}
  def download(), do: download_dataset(:train)

  @doc """
  Downloads the Yelp Reviews test dataset or fetches it locally.

  Returns a map with two lists in row order: `:review` holds the last CSV
  column of each row and `:rating` holds the first column parsed as an integer.

  Raises if the archive has no entry whose name contains `test`.
  """
  @spec download_test() :: %{
          review: [binary(), ...],
          rating: [5 | 4 | 3 | 2 | 1]
        }
  def download_test(), do: download_dataset(:test)

  # Fetches the archive, picks the CSV entry for `dataset_type` (`:train` or
  # `:test`) and splits its rows into reviews and integer ratings.
  defp download_dataset(dataset_type) do
    files = Utils.get!(@base_url <> @dataset_file).body
    regex = ~r"#{dataset_type}"

    # Archive entry names are charlists, hence the List.to_string/1 conversion.
    matching =
      Enum.filter(files, fn {fname, _contents} -> List.to_string(fname) =~ regex end)

    # When several entries match, the last one wins. Only that entry is parsed,
    # and a missing entry fails here with a clear message instead of surfacing
    # later as a conversion error on an empty row.
    contents =
      case List.last(matching) do
        {_fname, contents} -> contents
        nil -> raise "no #{dataset_type} entry found in #{@dataset_file}"
      end

    # Single pass over the rows: first column is the rating, last the review.
    {ratings, reviews} =
      contents
      |> CSV.parse_string(skip_headers: false)
      |> Enum.map(fn [rating | _] = row -> {String.to_integer(rating), List.last(row)} end)
      |> Enum.unzip()

    %{review: reviews, rating: ratings}
  end
end
51 changes: 51 additions & 0 deletions lib/scidata/yelp_polarity_reviews.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
defmodule Scidata.YelpPolarityReviews do
  @moduledoc """
  Module for downloading the [Yelp Polarity Reviews dataset](https://course.fast.ai/datasets#nlp).

  The compressed archive is requested through `Scidata.Utils.get!/1`, the CSV
  entry for the requested split is selected by name, and every row is turned
  into a review (last column) and a sentiment derived from the first column:
  label `"1"` becomes `0` and label `"2"` becomes `1`.
  """

  @base_url "https://s3.amazonaws.com/fast-ai-nlp/"

  @dataset_file "yelp_review_polarity_csv.tgz"

  alias Scidata.Utils
  alias NimbleCSV.RFC4180, as: CSV

  @doc """
  Downloads the Yelp Polarity Reviews training dataset or fetches it locally.

  Returns a map with two lists in row order: `:review` holds the last CSV
  column of each row and `:sentiment` holds `0` or `1`.

  Raises if the archive has no entry whose name contains `train`, or if a row
  carries a label other than `"1"` or `"2"`.
  """
  @spec download() :: %{review: [binary(), ...], sentiment: [1 | 0]}
  def download(), do: download_dataset(:train)

  @doc """
  Downloads the Yelp Polarity Reviews test dataset or fetches it locally.

  Returns a map with two lists in row order: `:review` holds the last CSV
  column of each row and `:sentiment` holds `0` or `1`.

  Raises if the archive has no entry whose name contains `test`, or if a row
  carries a label other than `"1"` or `"2"`.
  """
  @spec download_test() :: %{
          review: [binary(), ...],
          sentiment: [1 | 0]
        }
  def download_test(), do: download_dataset(:test)

  # Fetches the archive, picks the CSV entry for `dataset_type` (`:train` or
  # `:test`) and splits its rows into reviews and 0/1 sentiments.
  defp download_dataset(dataset_type) do
    files = Utils.get!(@base_url <> @dataset_file).body
    regex = ~r"#{dataset_type}"

    # Archive entry names are charlists, hence the List.to_string/1 conversion.
    matching =
      Enum.filter(files, fn {fname, _contents} -> List.to_string(fname) =~ regex end)

    # When several entries match, the last one wins. Only that entry is parsed,
    # and a missing entry fails here with a clear message instead of surfacing
    # later as a clause error on an empty row.
    contents =
      case List.last(matching) do
        {_fname, contents} -> contents
        nil -> raise "no #{dataset_type} entry found in #{@dataset_file}"
      end

    # Single pass over the rows: first column is the label, last the review.
    {sentiments, reviews} =
      contents
      |> CSV.parse_string(skip_headers: false)
      |> Enum.map(fn [label | _] = row -> {get_rating(label), List.last(row)} end)
      |> Enum.unzip()

    %{review: reviews, sentiment: sentiments}
  end

  # Maps the raw CSV label to the sentiment value exposed by this module.
  # Any other label is deliberately left unmatched so bad data fails loudly.
  defp get_rating("1"), do: 0
  defp get_rating("2"), do: 1
end
3 changes: 2 additions & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ defmodule Scidata.MixProject do

defp deps do
[
{:ex_doc, ">= 0.24.0", only: :dev, runtime: false}
{:ex_doc, ">= 0.24.0", only: :dev, runtime: false},
{:nimble_csv, "~> 1.1"}
]
end

Expand Down
1 change: 1 addition & 0 deletions mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
"makeup": {:hex, :makeup, "1.0.5", "d5a830bc42c9800ce07dd97fa94669dfb93d3bf5fcf6ea7a0c67b2e0e4a7f26c", [:mix], [{:nimble_parsec, "~> 0.5 or ~> 1.0", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cfa158c02d3f5c0c665d0af11512fed3fba0144cf1aadee0f2ce17747fba2ca9"},
"makeup_elixir": {:hex, :makeup_elixir, "0.15.1", "b5888c880d17d1cc3e598f05cdb5b5a91b7b17ac4eaf5f297cb697663a1094dd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.1", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "db68c173234b07ab2a07f645a5acdc117b9f99d69ebf521821d89690ae6c6ec8"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.1", "3fcb7f09eb9d98dc4d208f49cc955a34218fc41ff6b84df7c75b3e6e533cc65f", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "174d0809e98a4ef0b3309256cbf97101c6ec01c4ab0b23e926a9e17df2077cbb"},
"nimble_csv": {:hex, :nimble_csv, "1.1.0", "b1dba4a86be9e03065c9de829050468e591f569100332db949e7ce71be0afc25", [:mix], [], "hexpm", "e986755bc302832cac429be6deda0fc9d82d3c82b47abefb68b3c17c9d949a3f"},
"nimble_parsec": {:hex, :nimble_parsec, "1.1.0", "3a6fca1550363552e54c216debb6a9e95bd8d32348938e13de5eda962c0d7f89", [:mix], [], "hexpm", "08eb32d66b706e913ff748f11694b17981c0b04a33ef470e33e11b3d3ac8f54b"},
}
23 changes: 23 additions & 0 deletions test/yelp_full_reviews_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
defmodule YelpFullReviewsTest do
  use ExUnit.Case

  # Every test calls the real download functions, so raise the per-test
  # timeout to 120_000 ms.
  @moduletag timeout: 120_000

  describe "download" do
    test "retrieves training set" do
      result = Scidata.YelpFullReviews.download()
      %{review: reviews, rating: ratings} = result

      # Both lists must cover every training row.
      assert length(reviews) == 650_000
      assert length(ratings) == 650_000
      # All five star ratings must be present, and nothing else.
      assert Enum.sort(Enum.uniq(ratings)) == [1, 2, 3, 4, 5]
    end

    test "retrieves test set" do
      result = Scidata.YelpFullReviews.download_test()
      %{review: reviews, rating: ratings} = result

      # Both lists must cover every test row.
      assert length(reviews) == 50_000
      assert length(ratings) == 50_000
      # All five star ratings must be present, and nothing else.
      assert Enum.sort(Enum.uniq(ratings)) == [1, 2, 3, 4, 5]
    end
  end
end
24 changes: 24 additions & 0 deletions test/yelp_polarity_reviews_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
defmodule YelpPolarityReviewsTest do
  use ExUnit.Case

  # Every test calls the real download functions, so raise the per-test
  # timeout to 120_000 ms.
  @moduletag timeout: 120_000

  describe "download" do
    test "retrieves training set" do
      result = Scidata.YelpPolarityReviews.download()
      %{review: reviews, sentiment: sentiments} = result

      # Both lists must cover every training row.
      assert length(reviews) == 560_000
      assert length(sentiments) == 560_000
      # Both sentiment classes must be present, and nothing else.
      assert Enum.sort(Enum.uniq(sentiments)) == [0, 1]
    end

    test "retrieves test set" do
      result = Scidata.YelpPolarityReviews.download_test()
      %{review: reviews, sentiment: sentiments} = result

      # Both lists must cover every test row.
      assert length(reviews) == 38_000
      assert length(sentiments) == 38_000
      # Both sentiment classes must be present, and nothing else.
      assert Enum.sort(Enum.uniq(sentiments)) == [0, 1]
    end
  end
end