From 259e7775adb2b5cfe5d4ce75f7989f5298c08bb9 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Fri, 10 Jan 2025 12:42:06 +0100 Subject: [PATCH] vamos --- CHANGELOG.md | 6 ++++++ src/tabular/adult.jl | 14 ++++++-------- src/tabular/california_housing.jl | 9 +++++---- src/tabular/credit_default.jl | 6 ++++++ src/tabular/german_credit.jl | 13 +++++-------- src/tabular/gmsc.jl | 6 ++++++ src/utils.jl | 13 +++++++++++++ src/vision/cifar_10.jl | 9 ++++++++- src/vision/fashion_mnist.jl | 7 +++++++ src/vision/mnist.jl | 7 +++++++ 10 files changed, 69 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 487b6ab..37fab37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), *Note*: We try to adhere to these practices as of version [v1.0.1]. +## Version [1.1.2] - 2025-01-10 + +### Changed + +- Improved and streamlined some assertions related to dataset sizes. [#29] + ## Version [1.1.1] - 2025-01-10 ### Changed diff --git a/src/tabular/adult.jl b/src/tabular/adult.jl index d94c83c..985a0e5 100644 --- a/src/tabular/adult.jl +++ b/src/tabular/adult.jl @@ -4,13 +4,9 @@ Loads data from the UCI 'Adult' dataset. """ function load_uci_adult(n::Union{Nothing,Int}=1000; seed=data_seed) - # Throw an exception if n < 1: - if !isnothing(n) && n < 1 - throw(ArgumentError("n must be >= 1")) - end - if !isnothing(n) && n > 32000 - throw(ArgumentError("n must not exceed size of dataset (<=32000)")) - end + + # Assertions: + ensure_positive(n) # Load data df = CSV.read(joinpath(data_dir, "adult.csv"), DataFrames.DataFrame) @@ -41,9 +37,11 @@ function load_uci_adult(n::Union{Nothing,Int}=1000; seed=data_seed) X = MLJBase.transform(mach, df[:, DataFrames.Not(:target)]) X = Matrix(X) X = permutedims(X) - y = df.target + # Checks and warnings + request_more_than_available(n, size(X,2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2] diff --git a/src/tabular/california_housing.jl b/src/tabular/california_housing.jl index 36b8ccf..f6b4487 100644 --- a/src/tabular/california_housing.jl +++ b/src/tabular/california_housing.jl @@ -5,10 +5,8 @@ Loads California Housing data. """ function load_california_housing(n::Union{Nothing,Int}=5000; seed=data_seed) - # check that n is > 0 - if !isnothing(n) && n <= 0 - throw(ArgumentError("n must be > 0")) - end + # Assertions: + ensure_positive(n) # Load: df = CSV.read(joinpath(data_dir, "cal_housing.csv"), DataFrames.DataFrame) @@ -22,6 +20,9 @@ function load_california_housing(n::Union{Nothing,Int}=5000; seed=data_seed) # Counterfactual data: y = Int.(df.target) + # Checks and warnings + request_more_than_available(n, size(X, 2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2] diff --git a/src/tabular/credit_default.jl b/src/tabular/credit_default.jl index 1c942d1..ca5a698 100644 --- a/src/tabular/credit_default.jl +++ b/src/tabular/credit_default.jl @@ -5,6 +5,9 @@ Loads UCI Credit Default data. """ function load_credit_default(n::Union{Nothing,Int}=5000; seed=data_seed) + # Assertions: + ensure_positive(n) + # Load: df = CSV.read(joinpath(data_dir, "credit_default.csv"), DataFrames.DataFrame) @@ -28,6 +31,9 @@ function load_credit_default(n::Union{Nothing,Int}=5000; seed=data_seed) # X, y; features_categorical=features_categorical # ) + # Checks and warnings + request_more_than_available(n, size(X, 2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2] diff --git a/src/tabular/german_credit.jl b/src/tabular/german_credit.jl index 8ff2253..6a712f6 100644 --- a/src/tabular/german_credit.jl +++ b/src/tabular/german_credit.jl @@ -4,15 +4,9 @@ Loads UCI German Credit data. """ function load_german_credit(n::Union{Nothing,Int}=nothing; seed=data_seed) - # Throw an exception if n > 1000: - if !isnothing(n) && n > 1000 - throw(ArgumentError("n must be <= 1000")) - end - # Throw an exception if n < 1: - if !isnothing(n) && n < 1 - throw(ArgumentError("n must be >= 1")) - end + # Assertions: + ensure_positive(n) # Load: df = CSV.read(joinpath(data_dir, "german_credit.csv"), DataFrames.DataFrame) @@ -27,6 +21,9 @@ function load_german_credit(n::Union{Nothing,Int}=nothing; seed=data_seed) # Counterfactual data: y = convert(Vector, df.target) + # Checks and warnings + request_more_than_available(n, size(X, 2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2] diff --git a/src/tabular/gmsc.jl b/src/tabular/gmsc.jl index 4273e77..572f12f 100644 --- a/src/tabular/gmsc.jl +++ b/src/tabular/gmsc.jl @@ -5,6 +5,9 @@ Loads Give Me Some Credit (GMSC) data. """ function load_gmsc(n::Union{Nothing,Int}=5000; seed=data_seed) + # Assertions: + ensure_positive(n) + # Load: df = CSV.read(joinpath(data_dir, "gmsc.csv"), DataFrames.DataFrame) @@ -18,6 +21,9 @@ function load_gmsc(n::Union{Nothing,Int}=5000; seed=data_seed) # Counterfactual data: y = df.target + # Checks and warnings + request_more_than_available(n, size(X, 2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2] diff --git a/src/utils.jl b/src/utils.jl index d74ab33..eb147e1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -12,6 +12,19 @@ function get_rng(seed::Union{Int,AbstractRNG}) return seed end +function request_more_than_available(nreq, navailable) + if !isnothing(nreq) && nreq > navailable + @warn "Requested $nreq samples but only $navailable are available. Will resort to random oversampling." + end +end + + +function ensure_positive(n::Union{Nothing,Int}) + if !isnothing(n) && n < 1 + throw(ArgumentError("`n` must be >= 1")) + end +end + function subsample(rng::AbstractRNG, X::AbstractMatrix, y::AbstractVector, n::Int) # Get the unique classes in `y`. classes_ = unique(y) diff --git a/src/vision/cifar_10.jl b/src/vision/cifar_10.jl index 100fdf4..836115a 100644 --- a/src/vision/cifar_10.jl +++ b/src/vision/cifar_10.jl @@ -4,15 +4,22 @@ Loads data from the CIFAR-10 dataset. """ function load_cifar_10(n::Union{Nothing,Int}=nothing; seed=data_seed) + + # Assertions: + ensure_positive(n) + X, y = MLDatasets.CIFAR10()[:] # [:] gives us X, y X = Flux.flatten(X) - X = X .* 2 .- 1 # normalization between [-1, 1] + X = X .* 2 .- 1 # normalization between [-1, 1] y = MLJBase.categorical(y) y = DataAPI.unwrap.(y) # counterfactual_data = CounterfactualExplanations.CounterfactualData( # X, y; domain=(-1.0, 1.0), standardize=false # ) + # Checks and warnings + request_more_than_available(n, size(X, 2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2] diff --git a/src/vision/fashion_mnist.jl b/src/vision/fashion_mnist.jl index 0833f9c..392fcaf 100644 --- a/src/vision/fashion_mnist.jl +++ b/src/vision/fashion_mnist.jl @@ -4,6 +4,10 @@ Loads FashionMNIST data. """ function load_fashion_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed) + + # Assertions: + ensure_positive(n) + X, y = MLDatasets.FashionMNIST(:train)[:] X = Flux.flatten(X) X = X .* 2.0f0 .- 1.0f0 @@ -13,6 +17,9 @@ function load_fashion_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed) # X, y; domain=(-1.0, 1.0), standardize=false # ) + # Checks and warnings + request_more_than_available(n, size(X, 2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2] diff --git a/src/vision/mnist.jl b/src/vision/mnist.jl index 7718315..bc63727 100644 --- a/src/vision/mnist.jl +++ b/src/vision/mnist.jl @@ -4,6 +4,10 @@ Loads MNIST data. """ function load_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed) + + # Assertions: + ensure_positive(n) + X, y = MLDatasets.MNIST(:train)[:] X = Flux.flatten(X) X = X .* 2.0f0 .- 1.0f0 @@ -13,6 +17,9 @@ function load_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed) # X, y; domain=(-1.0, 1.0), standardize=false # ) + # Checks and warnings + request_more_than_available(n, size(X, 2)) + # Randomly under-/over-sample: rng = get_rng(seed) if !isnothing(n) && n != size(X)[2]