Skip to content

Commit

Permalink
Add Kuzushiji MNIST dataset (#22)
Browse files Browse the repository at this point in the history
* Add support for Kuzushiji MNIST dataset

* Update readme

* Add unit test for MNIST dataset loader

* Fix incorrect documentation for download

Co-authored-by: Tom Rutten <[email protected]>
  • Loading branch information
goodhamgupta and t-rutten authored Dec 7, 2021
1 parent a7d1db2 commit f1fd2c8
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Scidata currently supports the following training and test datasets:
- IMDb Reviews
- Yelp Reviews (Full and Polarity)
- MNIST
- Kuzushiji-MNIST (KMNIST)

Download or fetch datasets locally:

Expand Down
52 changes: 52 additions & 0 deletions lib/scidata/kuzushiji_mnist.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
defmodule Scidata.KuzushijiMNIST do
  @moduledoc """
  Module for downloading the [Kuzushiji-MNIST dataset](https://github.com/rois-codh/kmnist).

  `download/0` and `download_test/0` return the raw image and label bytes
  together with the element type and the shape read from each file's header.
  """

  alias Scidata.Utils

  @base_url "http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
  @train_image_file "train-images-idx3-ubyte.gz"
  @train_label_file "train-labels-idx1-ubyte.gz"
  @test_image_file "t10k-images-idx3-ubyte.gz"
  @test_label_file "t10k-labels-idx1-ubyte.gz"

  @doc """
  Downloads the Kuzushiji MNIST training dataset or fetches it locally.

  Returns a tuple of format:

      {{images_binary, images_type, images_shape},
       {labels_binary, labels_type, labels_shape}}

  where both types are `{:u, 8}`, `images_shape` is
  `{n_images, 1, n_rows, n_cols}` and `labels_shape` is `{n_labels}`.

  Raises if a file does not start with the expected IDX header or if its
  payload length disagrees with the sizes announced in that header.

  If you want to one-hot encode the labels, you can:

      labels_binary
      |> Nx.from_binary(labels_type)
      |> Nx.new_axis(-1)
      |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  """
  @spec download() ::
          {{binary(), {:u, 8}, {non_neg_integer(), 1, non_neg_integer(), non_neg_integer()}},
           {binary(), {:u, 8}, {non_neg_integer()}}}
  def download() do
    {download_images(@train_image_file), download_labels(@train_label_file)}
  end

  @doc """
  Downloads the Kuzushiji MNIST test dataset or fetches it locally.

  Returns a tuple of the same format as `download/0`:

      {{images_binary, images_type, images_shape},
       {labels_binary, labels_type, labels_shape}}

  Raises under the same conditions as `download/0`.
  """
  @spec download_test() ::
          {{binary(), {:u, 8}, {non_neg_integer(), 1, non_neg_integer(), non_neg_integer()}},
           {binary(), {:u, 8}, {non_neg_integer()}}}
  def download_test() do
    {download_images(@test_image_file), download_labels(@test_label_file)}
  end

  # Fetches an IDX image file and splits it into header fields and pixel bytes.
  # NOTE(review): assumes `Utils.get!/1` returns the already-decompressed body;
  # the header match below depends on that - confirm against `Scidata.Utils`.
  defp download_images(image_file) do
    # 2051 (0x00000803) is the IDX magic number for unsigned-byte data with
    # three dimensions. Matching it assertively makes a wrong or corrupt file
    # fail here instead of being misread as image counts.
    <<2051::32, n_images::32, n_rows::32, n_cols::32, images::binary>> =
      Utils.get!(@base_url <> image_file).body

    verify_payload!(image_file, images, n_images * n_rows * n_cols)
    {images, {:u, 8}, {n_images, 1, n_rows, n_cols}}
  end

  # Fetches an IDX label file and splits it into the label count and label bytes.
  defp download_labels(label_file) do
    # 2049 (0x00000801) is the IDX magic number for unsigned-byte data with
    # one dimension.
    <<2049::32, n_labels::32, labels::binary>> = Utils.get!(@base_url <> label_file).body

    verify_payload!(label_file, labels, n_labels)
    {labels, {:u, 8}, {n_labels}}
  end

  # Ensures the payload holds exactly the number of bytes the header announced,
  # so a truncated download cannot be returned with a shape it does not fill.
  defp verify_payload!(_file, payload, expected_bytes)
       when byte_size(payload) == expected_bytes,
       do: :ok

  defp verify_payload!(file, payload, expected_bytes) do
    raise "expected #{expected_bytes} bytes of data in #{file}, got #{byte_size(payload)}"
  end
end
27 changes: 27 additions & 0 deletions test/kuzushiji_mnist_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
defmodule KuzushijiMNISTTest do
  use ExUnit.Case

  # Per-test timeout of 120 seconds (ExUnit timeouts are in milliseconds);
  # presumably raised because each test may fetch the dataset over the network.
  @moduletag timeout: 120_000

  describe "download" do
    test "retrieves training set" do
      {image_data, label_data} = Scidata.KuzushijiMNIST.download()

      # Each element is {binary, type, shape}; only type and shape are checked.
      assert {_pixels, {:u, 8}, {60_000, 1, 28, 28}} = image_data
      assert {_classes, {:u, 8}, {60_000}} = label_data
    end

    test "retrieves test set" do
      {image_data, label_data} = Scidata.KuzushijiMNIST.download_test()

      # Each element is {binary, type, shape}; only type and shape are checked.
      assert {_pixels, {:u, 8}, {10_000, 1, 28, 28}} = image_data
      assert {_classes, {:u, 8}, {10_000}} = label_data
    end
  end
end
27 changes: 27 additions & 0 deletions test/mnist_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
defmodule MNISTTest do
  use ExUnit.Case

  # Per-test timeout of 120 seconds (ExUnit timeouts are in milliseconds);
  # presumably raised because each test may fetch the dataset over the network.
  @moduletag timeout: 120_000

  describe "download" do
    test "retrieves training set" do
      {image_data, label_data} = Scidata.MNIST.download()

      # Each element is {binary, type, shape}; only type and shape are checked.
      assert {_pixels, {:u, 8}, {60_000, 1, 28, 28}} = image_data
      assert {_classes, {:u, 8}, {60_000}} = label_data
    end

    test "retrieves test set" do
      {image_data, label_data} = Scidata.MNIST.download_test()

      # Each element is {binary, type, shape}; only type and shape are checked.
      assert {_pixels, {:u, 8}, {10_000, 1, 28, 28}} = image_data
      assert {_classes, {:u, 8}, {10_000}} = label_data
    end
  end
end

0 comments on commit f1fd2c8

Please sign in to comment.