Skip to content

Commit

Permalink
added xai
Browse files (browse the repository at this point in the history)
  • Loading branch information
axsaucedo committed Mar 24, 2019
1 parent be5093d commit 3db24e3
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions xai/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -70,7 +70,7 @@ def convert_categories(
tmp_df = df.copy()

if not len(categorical_cols):
categorical_cols = df.select_dtypes(include=[np.object]).columns
categorical_cols = df.select_dtypes(include=[np.object, np.bool]).columns

tmp_df[categorical_cols] = tmp_df[categorical_cols].astype('category')
tmp_df[categorical_cols] = tmp_df[categorical_cols].apply(lambda x: x.cat.codes)
Expand Down Expand Up @@ -113,7 +113,7 @@ def group_by_columns(
"""

if not len(categorical_cols):
categorical_cols = df.select_dtypes(include=[np.object]).columns
categorical_cols = df.select_dtypes(include=[np.object, np.bool]).columns

group_list = []
for c in columns:
Expand Down Expand Up @@ -175,7 +175,7 @@ def show_imbalance(
"""

if not len(categorical_cols):
categorical_cols = df.select_dtypes(include=[np.object]).columns
categorical_cols = df.select_dtypes(include=[np.object, np.bool]).columns

cols = cross + [column_name]
grouped = group_by_columns(
Expand Down Expand Up @@ -249,7 +249,7 @@ def show_imbalances(
columns = df.columns

if not len(categorical_cols):
categorical_cols = df.select_dtypes(include=[np.object]).columns
categorical_cols = df.select_dtypes(include=[np.object, np.bool]).columns

if cross and any([x in columns for x in cross]):
raise("Error: Columns in 'cross' are also in 'columns'")
Expand Down Expand Up @@ -318,7 +318,7 @@ def balance(
"""

if not len(categorical_cols):
categorical_cols = df.select_dtypes(include=[np.object]).columns
categorical_cols = df.select_dtypes(include=[np.object, np.bool]).columns

cols = cross + [column_name]
grouped = group_by_columns(
Expand Down Expand Up @@ -407,7 +407,7 @@ def correlations(
else:

if not len(categorical_cols):
categorical_cols = df.select_dtypes(include=[np.object]).columns
categorical_cols = df.select_dtypes(include=[np.object, np.bool]).columns

cols = [c for c in df.columns if c not in categorical_cols]

Expand Down Expand Up @@ -451,7 +451,7 @@ def balanced_train_test_split(
cross = ["target"] + cross

if not len(categorical_cols):
categorical_cols = list(tmp_df.select_dtypes(include=[np.object]).columns)
categorical_cols = list(tmp_df.select_dtypes(include=[np.object, np.bool]).columns)

# TODO: Enable for non-categorical targets
categorical_cols = ["target"] + categorical_cols
Expand Down Expand Up @@ -593,7 +593,7 @@ def metrics_imbalances(
columns = x_test.columns

if not len(categorical_cols):
categorical_cols = x_test.select_dtypes(include=[np.object]).columns
categorical_cols = x_test.select_dtypes(include=[np.object, np.bool]).columns

results = []
for col in columns:
Expand Down Expand Up @@ -678,7 +678,7 @@ def roc_imbalances(
columns = x_test.columns

if not len(categorical_cols):
categorical_cols = x_test.select_dtypes(include=[np.object]).columns
categorical_cols = x_test.select_dtypes(include=[np.object, np.bool]).columns

results = []
for col in columns:
Expand Down Expand Up @@ -761,7 +761,7 @@ def pr_imbalances(
columns = x_test.columns

if not len(categorical_cols):
categorical_cols = x_test.select_dtypes(include=[np.object]).columns
categorical_cols = x_test.select_dtypes(include=[np.object, np.bool]).columns

results = []
for col in columns:
Expand Down

0 comments on commit 3db24e3

Please sign in to comment.