Skip to content

Commit

Permalink
Return raw data instead of CounterfactualData
Browse files Browse the repository at this point in the history
  • Loading branch information
kmariuszk committed Nov 22, 2023
1 parent 8023e53 commit 14a0a65
Show file tree
Hide file tree
Showing 20 changed files with 285 additions and 262 deletions.
11 changes: 7 additions & 4 deletions src/synthetic/blobs.jl → src/Raw/synthetic/blobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
Loads overlapping synthetic data.
"""
function load_blobs(n=250; seed=data_seed, k=2, centers=2, kwrgs...)
function load_blobs_raw(n=250; seed=data_seed, k=2, centers=2, kwrgs...)
if isa(seed, Random.AbstractRNG)
X, y = MLJBase.make_blobs(n, k; centers=centers, rng=seed, kwrgs...)
else
Random.seed!(seed)
X, y = MLJBase.make_blobs(n, k; centers=centers, kwrgs...)
end
counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
counterfactual_data.X = Float32.(counterfactual_data.X)
return counterfactual_data

# counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
# counterfactual_data.X = Float32.(counterfactual_data.X)
# return counterfactual_data

return (X, y)
end
11 changes: 7 additions & 4 deletions src/synthetic/circles.jl → src/Raw/synthetic/circles.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
Loads synthetic circles data.
"""
function load_circles(n=250; seed=data_seed, noise=0.15, factor=0.01)
function load_circles_raw(n=250; seed=data_seed, noise=0.15, factor=0.01)
if isa(seed, Random.AbstractRNG)
X, y = MLJBase.make_circles(n; rng=seed, noise=noise, factor=factor)
else
Random.seed!(seed)
X, y = MLJBase.make_circles(n; noise=noise, factor=factor)
end
counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
counterfactual_data.X = Float32.(counterfactual_data.X)
return counterfactual_data

# counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
# counterfactual_data.X = Float32.(counterfactual_data.X)
# return counterfactual_data

return (X, y)
end
13 changes: 13 additions & 0 deletions src/Raw/synthetic/linearly_separable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""
load_linearly_separable(n=250; seed=data_seed)
Loads linearly separable synthetic data.
"""
function load_linearly_separable_raw(n=250; seed=data_seed)
# counterfactual_data = load_blobs(n; seed=seed, centers=2, cluster_std=0.5)
# return counterfactual_data

raw_data = load_blobs(n; seed=seed, centers=2, cluster_std=0.5)

return raw_data
end
11 changes: 7 additions & 4 deletions src/synthetic/moons.jl → src/Raw/synthetic/moons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
Loads synthetic moons data.
"""
function load_moons(n=250; seed=data_seed, kwrgs...)
function load_moons_raw(n=250; seed=data_seed, kwrgs...)
if isa(seed, Random.AbstractRNG)
X, y = MLJBase.make_moons(n; rng=seed, kwrgs...)
else
Random.seed!(seed)
X, y = MLJBase.make_moons(n; kwrgs...)
end
counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
counterfactual_data.X = Float32.(counterfactual_data.X)
return counterfactual_data

# counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
# counterfactual_data.X = Float32.(counterfactual_data.X)
# return counterfactual_data

return (X, y)
end
14 changes: 14 additions & 0 deletions src/Raw/synthetic/multi_class.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
load_multi_class(n=250; seed=data_seed)
Loads multi-class synthetic data.
"""
function load_multi_class_raw(n=250; seed=data_seed, centers=4)
# counterfactual_data = load_blobs(n; seed=seed, centers=centers, cluster_std=0.5)

# return counterfactual_data

raw_data = load_blobs_raw(n; seed=seed, centers=centers, cluster_std=0.5)

return raw_data
end
14 changes: 14 additions & 0 deletions src/Raw/synthetic/overlapping.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
load_overlapping(n=250; seed=data_seed)
Loads overlapping synthetic data.
"""
function load_overlapping_raw(n=250; seed=data_seed)
# counterfactual_data = load_blobs(n; seed=seed, centers=2, cluster_std=2.0)

# return counterfactual_data

raw_data = load_blobs_raw(n; seed=seed, centers=2, cluster_std=2.0)

return raw_data
end
20 changes: 10 additions & 10 deletions src/tabular/adult.jl → src/Raw/tabular/adult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Load and preprocesses data from the UCI 'Adult' dataset
# Example
data = load_uci_adult(20) # loads and preprocesses 20 samples from the Adult dataset
"""
function load_uci_adult(n::Union{Nothing,Int}=1000)
function load_uci_adult_raw(n::Union{Nothing,Int}=1000)
# Throw an exception if n < 1:
if !isnothing(n) && n < 1
throw(ArgumentError("n must be >= 1"))
Expand Down Expand Up @@ -45,22 +45,22 @@ function load_uci_adult(n::Union{Nothing,Int}=1000)
)

# Preprocessing
transformer = Standardizer(; count=true)
transformer = MLJModels.Standardizer(; count=true)
mach = MLJBase.fit!(machine(transformer, df[:, DataFrames.Not(:target)]))
X = MLJBase.transform(mach, df[:, DataFrames.Not(:target)])
X = Matrix(X)
X = permutedims(X)
X = Float32.(X)

y = df.target
counterfactual_data = CounterfactualData(X, y)
# counterfactual_data = CounterfactualData(X, y)

# Undersample:
if !isnothing(n)
counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
counterfactual_data, n
)
end
# # Undersample:
# if !isnothing(n)
# counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
# counterfactual_data, n
# )
# end

return counterfactual_data
return (X, y)
end
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Loads and pre-processes California Housing data.
"""
function load_california_housing(n::Union{Nothing,Int}=5000)
function load_california_housing_raw(n::Union{Nothing,Int}=5000)

# check that n is > 0
if !isnothing(n) && n <= 0
Expand All @@ -21,14 +21,15 @@ function load_california_housing(n::Union{Nothing,Int}=5000)

# Counterfactual data:
y = Int.(df.target)
counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
counterfactual_data.X = Float32.(counterfactual_data.X)
# counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
# counterfactual_data.X = Float32.(counterfactual_data.X)

# Undersample:
if !isnothing(n)
counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
counterfactual_data, n
)
end
return counterfactual_data
# if !isnothing(n)
# counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
# counterfactual_data, n
# )
# end

return (X, y)
end
22 changes: 11 additions & 11 deletions src/tabular/credit_default.jl → src/Raw/tabular/credit_default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Loads and pre-processes UCI Credit Default data.
"""
function load_credit_default(n::Union{Nothing,Int}=5000)
function load_credit_default_raw(n::Union{Nothing,Int}=5000)

# Load:
df = CSV.read(joinpath(data_dir, "credit_default.csv"), DataFrames.DataFrame)
Expand All @@ -24,17 +24,17 @@ function load_credit_default(n::Union{Nothing,Int}=5000)

# Counterfactual data:
y = df.target
counterfactual_data = CounterfactualExplanations.CounterfactualData(
X, y; features_categorical=features_categorical
)
counterfactual_data.X = Float32.(counterfactual_data.X)
# counterfactual_data = CounterfactualExplanations.CounterfactualData(
# X, y; features_categorical=features_categorical
# )
# counterfactual_data.X = Float32.(counterfactual_data.X)

# Undersample:
if !isnothing(n)
counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
counterfactual_data, n
)
end
# if !isnothing(n)
# counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
# counterfactual_data, n
# )
# end

return counterfactual_data
return (X, y)
end
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Loads and pre-processes UCI German Credit data.
data = load_german_credit(500) # loads and preprocesses 500 samples from the German Credit dataset
"""
function load_german_credit(n::Union{Nothing,Int}=nothing)
function load_german_credit_raw(n::Union{Nothing,Int}=nothing)
# Throw an exception if n > 1000:
if !isnothing(n) && n > 1000
throw(ArgumentError("n must be <= 1000"))
Expand All @@ -37,14 +37,14 @@ function load_german_credit(n::Union{Nothing,Int}=nothing)

# Counterfactual data:
y = df.target
counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
# counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)

# Undersample:
if !isnothing(n)
counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
counterfactual_data, n
)
end
# if !isnothing(n)
# counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
# counterfactual_data, n
# )
# end

return counterfactual_data
return (X, y)
end
18 changes: 9 additions & 9 deletions src/tabular/gmsc.jl → src/Raw/tabular/gmsc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Loads and pre-processes Give Me Some Credit (GMSC) data.
"""
function load_gmsc(n::Union{Nothing,Int}=5000)
function load_gmsc_raw(n::Union{Nothing,Int}=5000)

# Load:
df = CSV.read(joinpath(data_dir, "gmsc.csv"), DataFrames.DataFrame)
Expand All @@ -17,15 +17,15 @@ function load_gmsc(n::Union{Nothing,Int}=5000)

# Counterfactual data:
y = df.target
counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
counterfactual_data.X = Float32.(counterfactual_data.X)
# counterfactual_data = CounterfactualExplanations.CounterfactualData(X, y)
# counterfactual_data.X = Float32.(counterfactual_data.X)

# Undersample:
if !isnothing(n)
counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
counterfactual_data, n
)
end
# if !isnothing(n)
# counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
# counterfactual_data, n
# )
# end

return counterfactual_data
return (X, y)
end
30 changes: 15 additions & 15 deletions src/vision/cifar_10.jl → src/Raw/vision/cifar_10.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@ Loads and preprocesses data from the CIFAR-10 dataset for use in counterfactual
data = load_cifar_10(1000) # loads and preprocesses 1000 samples from the CIFAR-10 dataset
"""
function load_cifar_10(n::Union{Nothing,Int}=nothing)
function load_cifar_10_raw(n::Union{Nothing,Int}=nothing)
X, y = MLDatasets.CIFAR10()[:] # [:] gives us X, y
X = Flux.flatten(X)
X = X .* 2 .- 1 # normalization between [-1, 1]
y = MLJBase.categorical(y)
counterfactual_data = CounterfactualExplanations.CounterfactualData(
X, y; domain=(-1.0, 1.0), standardize=false
)
if !isnothing(n)
counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
counterfactual_data, n
)
end
return counterfactual_data
# counterfactual_data = CounterfactualExplanations.CounterfactualData(
# X, y; domain=(-1.0, 1.0), standardize=false
# )
# if !isnothing(n)
# counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
# counterfactual_data, n
# )
# end
return (X, y)
end

"""
Expand All @@ -41,13 +41,13 @@ Loads and preprocesses test data from the CIFAR-10 dataset for use in counterfac
test_data = load_cifar_10_test() # loads and preprocesses test data from the CIFAR-10 dataset
"""
function load_cifar_10_test()
function load_cifar_10_test_raw()
X, y = MLDatasets.CIFAR10(:test)[:]
X = Flux.flatten(X)
X = X .* 2 .- 1 # normalization between [-1, 1]
y = MLJBase.categorical(y)
counterfactual_data = CounterfactualExplanations.CounterfactualData(
X, y; domain=(-1.0, 1.0)
)
return counterfactual_data
# counterfactual_data = CounterfactualExplanations.CounterfactualData(
# X, y; domain=(-1.0, 1.0)
# )
return (X, y)
end
39 changes: 39 additions & 0 deletions src/Raw/vision/fashion_mnist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
load_fashion_mnist(n::Union{Nothing,Int}=nothing)
Loads and prepares FashionMNIST data.
"""
function load_fashion_mnist_raw(n::Union{Nothing,Int}=nothing)
X, y = MLDatasets.FashionMNIST(:train)[:]
X = Flux.flatten(X)
X = X .* 2.0f0 .- 1.0f0
y = MLJBase.categorical(y)
# counterfactual_data = CounterfactualExplanations.CounterfactualData(
# X, y; domain=(-1.0, 1.0), standardize=false
# )
# counterfactual_data.X = Float32.(counterfactual_data.X)
# # Undersample:
# if !isnothing(n)
# counterfactual_data = CounterfactualExplanations.DataPreprocessing.subsample(
# counterfactual_data, n
# )
# end
return (X, y)
end

"""
load_fashion_mnist_test()
Loads and prepares FashionMNIST test data.
"""
function load_fashion_mnist_test_raw()
X, y = MLDatasets.FashionMNIST(:test)[:]
X = Flux.flatten(X)
X = X .* 2.0f0 .- 1.0f0
y = MLJBase.categorical(y)
# counterfactual_data = CounterfactualExplanations.CounterfactualData(
# X, y; domain=(-1.0, 1.0)
# )
# counterfactual_data.X = Float32.(counterfactual_data.X)
return (X, y)
end
Loading

0 comments on commit 14a0a65

Please sign in to comment.