Skip to content

Commit

Permalink
vamos
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Jan 10, 2025
1 parent 6e0c132 commit 259e777
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 21 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),

*Note*: We try to adhere to these practices as of version [v1.0.1].

## Version [1.1.2] - 2025-01-10

### Changed

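- Improved and streamlined the assertions related to dataset sizes: all data loaders now share a single check that `n >= 1`, and requesting more samples than a dataset contains now emits a warning (and falls back to random oversampling) instead of throwing an error. [#29]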

## Version [1.1.1] - 2025-01-10

### Changed
Expand Down
14 changes: 6 additions & 8 deletions src/tabular/adult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@
Loads data from the UCI 'Adult' dataset.
"""
function load_uci_adult(n::Union{Nothing,Int}=1000; seed=data_seed)
# Throw an exception if n < 1:
if !isnothing(n) && n < 1
throw(ArgumentError("n must be >= 1"))
end
if !isnothing(n) && n > 32000
throw(ArgumentError("n must not exceed size of dataset (<=32000)"))
end

# Assertions:
ensure_positive(n)

# Load data
df = CSV.read(joinpath(data_dir, "adult.csv"), DataFrames.DataFrame)
Expand Down Expand Up @@ -41,9 +37,11 @@ function load_uci_adult(n::Union{Nothing,Int}=1000; seed=data_seed)
X = MLJBase.transform(mach, df[:, DataFrames.Not(:target)])
X = Matrix(X)
X = permutedims(X)

y = df.target

# Checks and warnings
request_more_than_available(n, size(X,2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down
9 changes: 5 additions & 4 deletions src/tabular/california_housing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ Loads California Housing data.
"""
function load_california_housing(n::Union{Nothing,Int}=5000; seed=data_seed)

# check that n is > 0
if !isnothing(n) && n <= 0
throw(ArgumentError("n must be > 0"))
end
# Assertions:
ensure_positive(n)

# Load:
df = CSV.read(joinpath(data_dir, "cal_housing.csv"), DataFrames.DataFrame)
Expand All @@ -22,6 +20,9 @@ function load_california_housing(n::Union{Nothing,Int}=5000; seed=data_seed)
# Counterfactual data:
y = Int.(df.target)

# Checks and warnings
request_more_than_available(n, size(X, 2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down
6 changes: 6 additions & 0 deletions src/tabular/credit_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ Loads UCI Credit Default data.
"""
function load_credit_default(n::Union{Nothing,Int}=5000; seed=data_seed)

# Assertions:
ensure_positive(n)

# Load:
df = CSV.read(joinpath(data_dir, "credit_default.csv"), DataFrames.DataFrame)

Expand All @@ -28,6 +31,9 @@ function load_credit_default(n::Union{Nothing,Int}=5000; seed=data_seed)
# X, y; features_categorical=features_categorical
# )

# Checks and warnings
request_more_than_available(n, size(X, 2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down
13 changes: 5 additions & 8 deletions src/tabular/german_credit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,9 @@
Loads UCI German Credit data.
"""
function load_german_credit(n::Union{Nothing,Int}=nothing; seed=data_seed)
# Throw an exception if n > 1000:
if !isnothing(n) && n > 1000
throw(ArgumentError("n must be <= 1000"))
end

# Throw an exception if n < 1:
if !isnothing(n) && n < 1
throw(ArgumentError("n must be >= 1"))
end
# Assertions:
ensure_positive(n)

# Load:
df = CSV.read(joinpath(data_dir, "german_credit.csv"), DataFrames.DataFrame)
Expand All @@ -27,6 +21,9 @@ function load_german_credit(n::Union{Nothing,Int}=nothing; seed=data_seed)
# Counterfactual data:
y = convert(Vector, df.target)

# Checks and warnings
request_more_than_available(n, size(X, 2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down
6 changes: 6 additions & 0 deletions src/tabular/gmsc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ Loads Give Me Some Credit (GMSC) data.
"""
function load_gmsc(n::Union{Nothing,Int}=5000; seed=data_seed)

# Assertions:
ensure_positive(n)

# Load:
df = CSV.read(joinpath(data_dir, "gmsc.csv"), DataFrames.DataFrame)

Expand All @@ -18,6 +21,9 @@ function load_gmsc(n::Union{Nothing,Int}=5000; seed=data_seed)
# Counterfactual data:
y = df.target

# Checks and warnings
request_more_than_available(n, size(X, 2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down
13 changes: 13 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ function get_rng(seed::Union{Int,AbstractRNG})
return seed
end

"""
    request_more_than_available(nreq, navailable)

Emit a warning when the number of requested samples `nreq` is larger than the
number of available samples `navailable`.

Nothing is logged when `nreq` is `nothing` or when `nreq` does not exceed
`navailable`. This function only logs; it does not resample anything itself.
Always returns `nothing`.
"""
function request_more_than_available(nreq, navailable)
    # `nothing` means no specific sample count was requested.
    isnothing(nreq) && return nothing

    # The request can be met from the available samples: stay silent.
    nreq > navailable || return nothing

    @warn "Requested $nreq samples but only $navailable are available. Will resort to random oversampling."
    return nothing
end


"""
    ensure_positive(n::Union{Nothing,Integer})

Validate a requested sample count `n`.

# Arguments
- `n`: the requested number of samples, or `nothing` if no specific number was
  requested. Any `Integer` subtype is accepted (e.g. `Int`, `Int32`, `BigInt`).

# Returns
`nothing` when `n` is `nothing` or `n >= 1`.

# Throws
- `ArgumentError`: if `n` is an integer smaller than 1.
"""
function ensure_positive(n::Union{Nothing,Integer})
    # `nothing` means "no specific count requested", so there is nothing to check.
    isnothing(n) && return nothing

    n < 1 && throw(ArgumentError("`n` must be >= 1"))
    return nothing
end

function subsample(rng::AbstractRNG, X::AbstractMatrix, y::AbstractVector, n::Int)
# Get the unique classes in `y`.
classes_ = unique(y)
Expand Down
9 changes: 8 additions & 1 deletion src/vision/cifar_10.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@
Loads data from the CIFAR-10 dataset.
"""
function load_cifar_10(n::Union{Nothing,Int}=nothing; seed=data_seed)

# Assertions:
ensure_positive(n)

X, y = MLDatasets.CIFAR10()[:] # [:] gives us X, y
X = Flux.flatten(X)
X = X .* 2 .- 1 # normalization between [-1, 1]
X = X .* 2 .- 1 # normalization between [-1, 1]
y = MLJBase.categorical(y)
y = DataAPI.unwrap.(y)
# counterfactual_data = CounterfactualExplanations.CounterfactualData(
# X, y; domain=(-1.0, 1.0), standardize=false
# )

# Checks and warnings
request_more_than_available(n, size(X, 2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down
7 changes: 7 additions & 0 deletions src/vision/fashion_mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
Loads FashionMNIST data.
"""
function load_fashion_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed)

# Assertions:
ensure_positive(n)

X, y = MLDatasets.FashionMNIST(:train)[:]
X = Flux.flatten(X)
X = X .* 2.0f0 .- 1.0f0
Expand All @@ -13,6 +17,9 @@ function load_fashion_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed)
# X, y; domain=(-1.0, 1.0), standardize=false
# )

# Checks and warnings
request_more_than_available(n, size(X, 2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down
7 changes: 7 additions & 0 deletions src/vision/mnist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
Loads MNIST data.
"""
function load_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed)

# Assertions:
ensure_positive(n)

X, y = MLDatasets.MNIST(:train)[:]
X = Flux.flatten(X)
X = X .* 2.0f0 .- 1.0f0
Expand All @@ -13,6 +17,9 @@ function load_mnist(n::Union{Nothing,Int}=nothing; seed=data_seed)
# X, y; domain=(-1.0, 1.0), standardize=false
# )

# Checks and warnings
request_more_than_available(n, size(X, 2))

# Randomly under-/over-sample:
rng = get_rng(seed)
if !isnothing(n) && n != size(X)[2]
Expand Down

0 comments on commit 259e777

Please sign in to comment.