diff --git a/DemoSyntheticDataInferred.ipynb b/DemoSyntheticDataInferred.ipynb
new file mode 100644
index 0000000..833e43a
--- /dev/null
+++ b/DemoSyntheticDataInferred.ipynb
@@ -0,0 +1,1164 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "25ca67d9",
+ "metadata": {},
+ "source": [
+ "# Infering Protected Attribute Using a Linked Feature\n",
+ "\n",
+ "The background to the problem is that of inferring protected features, typically demographic data such as gender, for datasets where that information is absent. Our approach works on the assumption that the dataset contains a \"linked\" feature whose values cam give us probabilistic information about the protected feature values. For example, sssume that we have data about customer purchases and the stores from which the purchases were made. It is reasonable to assume that for certain stores women are more likely to buy items (the store may sell items of more interest to women) and for other stores men are more likely to buy items (the stores sell items of more interest to men). \n",
+ "\n",
+ "Let's now examine the transactions for a single customer $X$. A priori we do not know if the customer is male or female so we have the probabilities of the customer being male of female, in the absence of any other data, equal to each other i.e.\n",
+ "\n",
+ "$$P(X=Male) = P(X=Female) = 0.5$$\n",
+ "\n",
+ "Let $S_i$ be the store at which transaction $i$ has been made, the probability of a man making a purchase at a store $S$ be $P(S|X=Male)$ and the probability of woman making a purchase at a store ($P(S|X=Female)$). The probability of the customer being male given a list of his $N$ transactions can be calculated using Bayes theorem (and assuming that purchases are conditionally independent) to give\n",
+ "\n",
+ "$$\n",
+ "\\begin{split}\n",
+ "P(X=Male| S_1...S_n) & = \\frac{P(X=Male) P(S_1...S_n|X=Male)}{K}\\\\\n",
+ " & = \\frac{P(X=Male) \\prod_{i=1}^n P(S_i|X=Male)}{K}\n",
+ "\\end{split}\n",
+ "$$\n",
+ "\n",
+ "and\n",
+ "\n",
+ "$$\n",
+ "\\begin{split}\n",
+ "P(X=Female| S_1...S_n) & = \\frac{P(X=Female) P(S_1...S_n|X=Female)}{K}\\\\\n",
+ " & = \\frac{P(X=Female) \\prod_{i=1}^n P(S_i|X=Female)}{K}\n",
+ "\\end{split}\n",
+ "$$\n",
+ "\n",
+ "where $K$ is a common normalizing constant. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "96992752",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Thanks for trying out the ETIQ.ai toolkit!\n",
+ "\n",
+ "Visit our getting started documentation at https://docs.etiq.ai/\n",
+ "\n",
+ "Visit our Slack channel at https://etiqcore.slack.com/ for support or feedback.\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from etiq_core import *\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "from sklearn.metrics import accuracy_score, confusion_matrix\n",
+ "import warnings\n",
+ "warnings.filterwarnings('ignore')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e250122f",
+ "metadata": {},
+ "source": [
+ "# Generating Synthetic Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28b7515b",
+ "metadata": {},
+ "source": [
+ "In order to test our inference pipeline synthetic transactions data is generated. Using the following functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "d3fa88d5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import Counter\n",
+ "\n",
+ "suspect_stores = ['MCC3', 'MCC5', 'MCC7', 'MCC11']\n",
+ "\n",
+ "# A utility function to sample from a categorical disbtribution given the categories and the probability of each\n",
+ "# category\n",
+ "def prob_categorical(cats, p):\n",
+ " return(np.random.choice(cats, 1, p=p)[0])\n",
+ "\n",
+ "# This function generates a dataframe consisting of Customers (identified by a unique ID) and their genders\n",
+ "# The probability of a random customer being a Male is equal to the probability of the same customer being female\n",
+ "def generate_customers(num_customers: int = 100, random_seed: int = 3):\n",
+ " np.random.seed(random_seed)\n",
+ " customers = pd.DataFrame()\n",
+ " customers['id'] = ['Cust' + str(i) for i in range(1,num_customers+1)]\n",
+ " customers['gender'] = customers.apply(lambda row: 'Male' if np.random.uniform() >= 0.5 else 'Female', axis=1)\n",
+ " return customers\n",
+ "\n",
+ "# This function is used to determine whether a transactions should be flagged\n",
+ "def flag_rule(row):\n",
+ " # Flag all transactions from four stores greater than 60\n",
+ " if row['MCC'] in suspect_stores:\n",
+ " return (1 if row['amount'] > 60 else 0)\n",
+ " return 0\n",
+ "\n",
+ "# This function generates a dataframe of synthetic transactions\n",
+ "def generate_transactions(num_transactions: int = 1000, customers: pd.DataFrame = None, linked_feature_probability: pd.DataFrame = None, random_seed: int = 4):\n",
+ " np.random.seed(random_seed)\n",
+ " transactions = pd.DataFrame()\n",
+ " if (customers is None) or (linked_feature_probability is None):\n",
+ " return transactions\n",
+ " customers_dict = {row['id']: row['gender'] for _,row in customers.iterrows()}\n",
+ " male_prob = list(linked_feature_probability['Male']/np.sum(linked_feature_probability['Male']))\n",
+ " female_prob = list(linked_feature_probability['Female']/np.sum(linked_feature_probability['Female']))\n",
+ " cats = list(linked_feature_probability['ID'])\n",
+ " mcc_rule = lambda row: prob_categorical(cats, male_prob) if row['gender'] == 'Male' else prob_categorical(cats, female_prob)\n",
+ " amount_rule = lambda row: np.random.uniform(20.0,100)\n",
+ " transactions['customerID'] = np.random.choice(customers['id'], num_transactions)\n",
+ " transactions['gender'] = transactions.apply(lambda row: customers_dict[row['customerID']], axis=1) \n",
+ " transactions['MCC'] = transactions.apply(mcc_rule, axis=1)\n",
+ " transactions['amount'] = transactions.apply(amount_rule, axis=1)\n",
+ " transactions['flag'] = transactions.apply(flag_rule, axis=1)\n",
+ " return transactions\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2c2fdadf",
+ "metadata": {},
+ "source": [
+ "We generate 1000 customers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "1bb56108",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " id | \n",
+ " gender | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Cust1 | \n",
+ " Male | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Cust2 | \n",
+ " Female | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Cust3 | \n",
+ " Male | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Cust4 | \n",
+ " Male | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Cust5 | \n",
+ " Male | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " id gender\n",
+ "0 Cust1 Male\n",
+ "1 Cust2 Female\n",
+ "2 Cust3 Male\n",
+ "3 Cust4 Male\n",
+ "4 Cust5 Male"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "customers = generate_customers(1000, random_seed=13)\n",
+ "customers.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "9df569e9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({'Male': 484, 'Female': 516})"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Get the number of Male and Female customers in the list of 1000 customers\n",
+ "Counter(customers['gender'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "73c0ee14",
+ "metadata": {},
+ "source": [
+ "## Generate Transaction Data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5446204b",
+ "metadata": {},
+ "source": [
+ "Create probabilities for protected feature given linked feature values for a set of 11 stores."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "e92a3c88",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mcc_strong = pd.DataFrame()\n",
+ "mcc_strong['ID'] = ['MCC1', 'MCC2', 'MCC3', 'MCC4', 'MCC5', 'MCC6', 'MCC7', 'MCC8', 'MCC9', 'MCC10', 'MCC11']\n",
+ "mcc_strong['Female'] = [0.5, 0.3, 0.7, 0.5, 0.75, 0.25, 0.5, 0.52, 0.42, 0.9, 0.1]\n",
+ "mcc_strong['Male'] = [0.5, 0.7, 0.3, 0.5, 0.25, 0.75, 0.5, 0.48, 0.58, 0.1, 0.9]\n",
+ "\n",
+ "mcc_weak = pd.DataFrame()\n",
+ "mcc_weak['ID'] = ['MCC1', 'MCC2', 'MCC3', 'MCC4', 'MCC5', 'MCC6', 'MCC7', 'MCC8', 'MCC9', 'MCC10', 'MCC11']\n",
+ "mcc_weak['Female'] = [0.5, 0.45, 0.55, 0.5, 0.55, 0.45, 0.5, 0.55, 0.45, 0.55, 0.45]\n",
+ "mcc_weak['Male'] = [0.5, 0.55, 0.45, 0.5, 0.45, 0.55, 0.5, 0.45, 0.55, 0.45, 0.55]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34918186",
+ "metadata": {},
+ "source": [
+ "Create transaction data using the probabilities above"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "f115a7b7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " customerID | \n",
+ " MCC | \n",
+ " amount | \n",
+ " flag | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Cust875 | \n",
+ " MCC3 | \n",
+ " 28.065022 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Cust665 | \n",
+ " MCC3 | \n",
+ " 58.534485 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Cust250 | \n",
+ " MCC4 | \n",
+ " 56.476419 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Cust644 | \n",
+ " MCC2 | \n",
+ " 83.464443 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Cust953 | \n",
+ " MCC5 | \n",
+ " 94.138779 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " customerID MCC amount flag\n",
+ "0 Cust875 MCC3 28.065022 0\n",
+ "1 Cust665 MCC3 58.534485 0\n",
+ "2 Cust250 MCC4 56.476419 0\n",
+ "3 Cust644 MCC2 83.464443 0\n",
+ "4 Cust953 MCC5 94.138779 1"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "full_data_strong = generate_transactions(40000, customers, mcc_strong, random_seed=3)\n",
+ "# Drop the actual gender from the transactions\n",
+ "transactions_strong = full_data_strong.drop(['gender'], axis=1)\n",
+ "transactions_strong.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "0a7d75bc",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " customerID | \n",
+ " MCC | \n",
+ " amount | \n",
+ " flag | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " Cust868 | \n",
+ " MCC9 | \n",
+ " 84.300496 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " Cust207 | \n",
+ " MCC3 | \n",
+ " 26.826899 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " Cust702 | \n",
+ " MCC3 | \n",
+ " 43.868908 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " Cust999 | \n",
+ " MCC3 | \n",
+ " 53.864379 | \n",
+ " 0 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " Cust119 | \n",
+ " MCC3 | \n",
+ " 93.994402 | \n",
+ " 1 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " customerID MCC amount flag\n",
+ "0 Cust868 MCC9 84.300496 0\n",
+ "1 Cust207 MCC3 26.826899 0\n",
+ "2 Cust702 MCC3 43.868908 0\n",
+ "3 Cust999 MCC3 53.864379 0\n",
+ "4 Cust119 MCC3 93.994402 1"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "full_data_weak = generate_transactions(40000, customers, mcc_weak, random_seed=5)\n",
+ "# Drop the actual gender from the transactions\n",
+ "transactions_weak = full_data_weak.drop(['gender'], axis=1)\n",
+ "transactions_weak.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "9143f571",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dict_mcc_strong = {row['ID']: (row['Female'], row['Male']) for _,row in mcc_strong.iterrows()}\n",
+ "dict_mcc_weak = {row['ID']: (row['Female'], row['Male']) for _,row in mcc_weak.iterrows()}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3f0cced4",
+ "metadata": {},
+ "source": [
+ "# Setup Inferred data pipelines"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "0eb87dd6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:etiq_core.pipeline.InferProtectedPipeline0349:Etiq removed the column encoding the protected attribute values from the dataset. The models are fitted and metrics are computed on a dataset without the protected attribute column. The protected attribute values can be found in the protected_train or protected_valid fields of each dataset\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0349:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0349:Infering protected feature \"gender\" using feature \"MCC\"\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0349:Fitting model\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0349:Computed metrics for the initial dataset\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0349:Completed pipeline\n",
+ "WARNING:etiq_core.pipeline.InferProtectedPipeline0365:Etiq removed the column encoding the protected attribute values from the dataset. The models are fitted and metrics are computed on a dataset without the protected attribute column. The protected attribute values can be found in the protected_train or protected_valid fields of each dataset\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0365:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0365:Infering protected feature \"gender\" using feature \"MCC\"\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0365:Fitting model\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0365:Computed metrics for the initial dataset\n",
+ "INFO:etiq_core.pipeline.InferProtectedPipeline0365:Completed pipeline\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Specify the categorical and continuous features\n",
+ "cat_vars = ['customerID','MCC', 'flag']\n",
+ "cont_vars = ['amount']\n",
+ "\n",
+ "transforms = [Dropna, EncodeLabels] \n",
+ "# Note that we don't have the protected feature so the BiasParams protected field should be set to None\n",
+ "debias_param = BiasParams(protected=None, privileged=1, unprivileged=2, \n",
+ " positive_outcome_label='0', negative_outcome_label='1')\n",
+ "\n",
+ "dl_strong = DatasetLoader(data=transactions_strong, label='flag', transforms=transforms, bias_params=debias_param,\n",
+ " train_valid_test_splits=[0.8, 0.1, 0.1], cat_col=cat_vars,\n",
+ " cont_col=cont_vars, names_col = transactions_strong.columns.values)\n",
+ "\n",
+ "dl_weak = DatasetLoader(data=transactions_weak, label='flag', transforms=transforms, bias_params=debias_param,\n",
+ " train_valid_test_splits=[0.8, 0.1, 0.1], cat_col=cat_vars,\n",
+ " cont_col=cont_vars, names_col = transactions_weak.columns.values)\n",
+ "\n",
+ "# Model\n",
+ "xgb_strong = DefaultXGBoostClassifier()\n",
+ "xgb_weak = DefaultXGBoostClassifier()\n",
+ "\n",
+ "# \"Strong\" inferred data pipeline\n",
+ "metrics_initial = [accuracy, equal_opportunity, demographic_parity]\n",
+ "pipeline_infered_strong = InferProtectedDataPipeline(dataset_loader=dl_strong, model=xgb_strong, \n",
+ " metrics=metrics_initial, infer_feature='gender',\n",
+ " linked_feature='MCC', data_key_column='customerID',\n",
+ " feature_prob_lookup=dict_mcc_strong,\n",
+ " privileged_class=1)\n",
+ "pipeline_infered_strong.run()\n",
+ "\n",
+ "# \"Weak\" inferred data pipeline\n",
+ "pipeline_infered_weak = InferProtectedDataPipeline(dataset_loader=dl_weak, model=xgb_weak, \n",
+ " metrics=metrics_initial, infer_feature='gender',\n",
+ " linked_feature='MCC', data_key_column='customerID',\n",
+ " feature_prob_lookup=dict_mcc_weak,\n",
+ " privileged_class=1)\n",
+ "pipeline_infered_weak.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "646ae473",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'InferProtectedPipeline0349': [{'accuracy': ('privileged',\n",
+ " 1.0,\n",
+ " 'unprivileged',\n",
+ " 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.7865279841505696,\n",
+ " 'unprivileged',\n",
+ " 0.8244197780020182)}]}"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipeline_infered_strong.get_protected_metrics()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "3ced96ec",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'InferProtectedPipeline0365': [{'accuracy': ('privileged',\n",
+ " 1.0,\n",
+ " 'unprivileged',\n",
+ " 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.7896440129449838,\n",
+ " 'unprivileged',\n",
+ " 0.8275299238302503)}]}"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pipeline_infered_weak.get_protected_metrics()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "40341cfd",
+ "metadata": {},
+ "source": [
+ "## How well did we do at inferring the protected characteristic?"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "7ec94113",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load the full dataset\n",
+ "customers_dict = {row['id']: row['gender'] for _,row in customers.iterrows()}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "654bf3c7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get all the customer IDs from the training set\n",
+ "customer_idx = np.where(pipeline_infered_strong.get_dataset().get_dataset_column_names()== 'customerID')[0][0]\n",
+ "c = pipeline_infered_strong.get_dataset().x_train[:,customer_idx].astype(int)\n",
+ "train_customers_strong = pipeline_infered_strong.store.encoder['customerID'].inverse_transform(c)\n",
+ "\n",
+ "customer_idx = np.where(pipeline_infered_weak.get_dataset().get_dataset_column_names()== 'customerID')[0][0]\n",
+ "c = pipeline_infered_weak.get_dataset().x_train[:,customer_idx].astype(int)\n",
+ "train_customers_weak = pipeline_infered_weak.store.encoder['customerID'].inverse_transform(c)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "7fc289a2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the gender of the customer for each of the training set transactions\n",
+ "actual_train_gender_strong = [customers_dict[acustid] for acustid in train_customers_strong]\n",
+ "actual_train_gender_strong_encoded = [1 if x=='Male' else 0 for x in actual_train_gender_strong]\n",
+ "\n",
+ "# Get the gender of the customer for each of the training set transactions\n",
+ "actual_train_gender_weak = [customers_dict[acustid] for acustid in train_customers_weak]\n",
+ "actual_train_gender_weak_encoded = [1 if x=='Male' else 0 for x in actual_train_gender_weak]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "e6d912ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the inferred gender of each customer in the training \n",
+ "infered_train_gender_strong = pipeline_infered_strong.get_dataset().protected_train.astype(int)\n",
+ "\n",
+ "# Get the inferred gender of each customer in the training \n",
+ "infered_train_gender_weak = pipeline_infered_weak.get_dataset().protected_train.astype(int)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "f3de3f2e",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[16534, 66],\n",
+ " [ 0, 15399]])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Get the confusion matrix of the gender of the customer in each of the training set transactions\n",
+ "confusion_matrix(actual_train_gender_strong_encoded , infered_train_gender_strong)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "a8a62b24",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[11712, 4808],\n",
+ " [ 5200, 10279]])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Get the confusion matrix of the gender of the customer in each of the training set transactions\n",
+ "confusion_matrix(actual_train_gender_weak_encoded , infered_train_gender_weak)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "df896e76",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "actual_train_customers_dict_strong = {a: customers_dict[a] for a in train_customers_strong}\n",
+ "inferred_train_customers_strong_d = {a: ('Male' if infered_train_gender_strong[idx]==1 else 'Female') for idx,a in enumerate(train_customers_strong)}\n",
+ "\n",
+ "actual_train_customers_dict_weak = {a: customers_dict[a] for a in train_customers_weak}\n",
+ "inferred_train_customers_weak_d = {a: ('Male' if infered_train_gender_weak[idx]==1 else 'Female') for idx,a in enumerate(train_customers_weak)}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "973590d0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[514, 2],\n",
+ " [ 0, 484]])"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "actual_train_customers_list = list(actual_train_customers_dict_strong.values())\n",
+ "inferred_train_customers_list = list(inferred_train_customers_strong_d.values())\n",
+ "confusion_matrix(actual_train_customers_list, inferred_train_customers_list)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8537edb2",
+ "metadata": {},
+ "source": [
+ "In summary where there is strong gender disparity in shopping habits our technique manages to correctly infer the gender of 998 out of 1000 customers. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "d03f76e8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[367, 149],\n",
+ " [167, 317]])"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "actual_train_customers_list = list(actual_train_customers_dict_weak.values())\n",
+ "inferred_train_customers_list = list(inferred_train_customers_weak_d.values())\n",
+ "confusion_matrix(actual_train_customers_list, inferred_train_customers_list)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0ae58151",
+ "metadata": {},
+ "source": [
+ "On the other hand where there is weak gender disparity in shopping habits out technique only manages to correctly infer the gender of 684 out of our 1000 customers."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3287aa80",
+ "metadata": {},
+ "source": [
+ "# Running a Debias pipeline over an inferred data pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "e773c389",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:etiq_core.pipeline.DebiasPipeline0296:Warning: potential data leak. You are evaluating a fitted model. If you use the same dataset that you used to train the model and you pass on just the training/validation/test split without passing on which was your validation dataset when you fitted the model, some observations that were previously in train can now be in test - which can skew your results\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Start Phase IdentifyPipeline0575\n",
+ "INFO:etiq_core.pipeline.IdentifyPipeline0575:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.IdentifyPipeline0575:Completed pipeline\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Completed Phase IdentifyPipeline0575\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Start Phase RepairPipeline0743\n",
+ "INFO:etiq_core.pipeline.RepairPipeline0743:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.RepairPipeline0743:Completed pipeline\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Completed Phase RepairPipeline0743\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Refitting model\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Computed metrics for the repaired dataset\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Compare pipeline predictions\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0296:Completed pipeline\n"
+ ]
+ }
+ ],
+ "source": [
+ "# the DebiasPipeline aims to identify sources of bias by applying analyses formalized in the Identify pipelines\n",
+ "# the Identify pipeline is looking for 3 sources of bias (limited features, poor sampling and proxies)\n",
+ "\n",
+ "identify_pipeline = IdentifyBiasSources(nr_groups=20, # nr of segments based on using unsupervised learning to group similar rows\n",
+ " train_model_segment=True,\n",
+ " group_def=['unsupervised'],\n",
+ " fit_metrics=[accuracy, equal_opportunity])\n",
+ " \n",
+ "# the DebiasPipeline aims to mitigate sources of bias by applying different types of repair algorithms\n",
+ "# the library offers implementations of repair algorithms described in the academic fairness literature\n",
+ "repair_pipeline = RepairResamplePipeline(steps=[ResampleUnbiasedSegmentsStep(ratio_resample=1)], random_seed=4)\n",
+ "\n",
+ "debias_pipeline = DebiasPipeline(data_pipeline=pipeline_infered_strong, \n",
+ " model=xgb_strong,\n",
+ " metrics=metrics_initial,\n",
+ " identify_pipeline=identify_pipeline,\n",
+ " repair_pipeline=repair_pipeline)\n",
+ "debias_pipeline.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "15769fe1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'InferProtectedPipeline0349': [{'accuracy': ('privileged',\n",
+ " 1.0,\n",
+ " 'unprivileged',\n",
+ " 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.7865279841505696,\n",
+ " 'unprivileged',\n",
+ " 0.8244197780020182)}],\n",
+ " 'DebiasPipeline0296': [{'accuracy': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.8244197780020182,\n",
+ " 'unprivileged',\n",
+ " 0.7865279841505696)}]}"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "debias_pipeline.get_protected_metrics()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "358a0eb2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " issue | \n",
+ " features | \n",
+ " segments | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " correlation_issue | \n",
+ " MCC | \n",
+ " [1, 2, 3, 4, 6, 7, 10, 11, 14, 15, 16, 17, 18,... | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " low_priv_sample | \n",
+ " N/A | \n",
+ " [0] | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " missing_sample | \n",
+ " NaN | \n",
+ " [5] | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " issue features \\\n",
+ "0 correlation_issue MCC \n",
+ "1 low_priv_sample N/A \n",
+ "0 missing_sample NaN \n",
+ "\n",
+ " segments \n",
+ "0 [1, 2, 3, 4, 6, 7, 10, 11, 14, 15, 16, 17, 18,... \n",
+ "1 [0] \n",
+ "0 [5] "
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "debias_pipeline.get_issues_summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a08c22b5",
+ "metadata": {},
+ "source": [
+ "## Compare Against the Non-Inferred Debias Pipeline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "e2bdc712",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:etiq_core.pipeline.DataPipeline0493:Etiq removed the column encoding the protected attribute values from the dataset. The models are fitted and metrics are computed on a dataset without the protected attribute column. The protected attribute values can be found in the protected_train or protected_valid fields of each dataset\n",
+ "INFO:etiq_core.pipeline.DataPipeline0493:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.DataPipeline0493:Fitting model\n",
+ "INFO:etiq_core.pipeline.DataPipeline0493:Computed metrics for the initial dataset\n",
+ "INFO:etiq_core.pipeline.DataPipeline0493:Completed pipeline\n"
+ ]
+ }
+ ],
+ "source": [
+ "full_cat_vars = ['customerID','MCC', 'flag', 'gender']\n",
+ "debias_param = BiasParams(protected='gender', privileged='Male', unprivileged='Female', \n",
+ " positive_outcome_label='0', negative_outcome_label='1')\n",
+ "\n",
+ "dl_full_strong = DatasetLoader(data=full_data_strong, label='flag', transforms=transforms, bias_params=debias_param,\n",
+ " train_valid_test_splits=[0.8, 0.1, 0.1], cat_col=full_cat_vars,\n",
+ " cont_col=cont_vars, names_col = full_data_strong.columns.values)\n",
+ "# Model\n",
+ "xgb_full_strong = DefaultXGBoostClassifier()\n",
+ "pipeline_full_strong = DataPipeline(dataset_loader=dl_full_strong, model=xgb_full_strong, \n",
+ " metrics=metrics_initial)\n",
+ "pipeline_full_strong.run()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3ceb31a9",
+ "metadata": {},
+ "source": [
+ "We now run the debias pipeline on the full data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "2882e286",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:etiq_core.pipeline.DebiasPipeline0993:Warning: potential data leak. You are evaluating a fitted model. If you use the same dataset that you used to train the model and you pass on just the training/validation/test split without passing on which was your validation dataset when you fitted the model, some observations that were previously in train can now be in test - which can skew your results\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Start Phase IdentifyPipeline0849\n",
+ "INFO:etiq_core.pipeline.IdentifyPipeline0849:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.IdentifyPipeline0849:Completed pipeline\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Completed Phase IdentifyPipeline0849\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Start Phase RepairPipeline0769\n",
+ "INFO:etiq_core.pipeline.RepairPipeline0769:Starting pipeline\n",
+ "INFO:etiq_core.pipeline.RepairPipeline0769:Completed pipeline\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Completed Phase RepairPipeline0769\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Refitting model\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Computed metrics for the repaired dataset\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Compare pipeline predictions\n",
+ "INFO:etiq_core.pipeline.DebiasPipeline0993:Completed pipeline\n"
+ ]
+ }
+ ],
+ "source": [
+ "identify_pipeline_full = IdentifyBiasSources(nr_groups=20, # nr of segments based on using unsupervised learning to group similar rows\n",
+ " train_model_segment=True,\n",
+ " group_def=['unsupervised'],\n",
+ " fit_metrics=[accuracy, equal_opportunity])\n",
+ " \n",
+ "# the DebiasPipeline aims to mitigate sources of bias by applying different types of repair algorithms\n",
+ "# the library offers implementations of repair algorithms described in the academic fairness literature\n",
+ "repair_pipeline_full = RepairResamplePipeline(steps=[ResampleUnbiasedSegmentsStep(ratio_resample=1)], random_seed=4)\n",
+ "\n",
+ "debias_pipeline_full = DebiasPipeline(data_pipeline=pipeline_full_strong, \n",
+ " model=xgb_full_strong,\n",
+ " metrics=metrics_initial,\n",
+ " identify_pipeline=identify_pipeline_full,\n",
+ " repair_pipeline=repair_pipeline_full)\n",
+ "debias_pipeline_full.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "c39436ce",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'DataPipeline0493': [{'accuracy': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.8237082066869301,\n",
+ " 'unprivileged',\n",
+ " 0.7873704982733103)}],\n",
+ " 'DebiasPipeline0993': [{'accuracy': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.8237082066869301,\n",
+ " 'unprivileged',\n",
+ " 0.7873704982733103)}]}"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "debias_pipeline_full.get_protected_metrics()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "1bf06762",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'InferProtectedPipeline0349': [{'accuracy': ('privileged',\n",
+ " 1.0,\n",
+ " 'unprivileged',\n",
+ " 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.7865279841505696,\n",
+ " 'unprivileged',\n",
+ " 0.8244197780020182)}],\n",
+ " 'DebiasPipeline0296': [{'accuracy': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'equal_opportunity': ('privileged', 1.0, 'unprivileged', 1.0)},\n",
+ " {'demographic_parity': ('privileged',\n",
+ " 0.8244197780020182,\n",
+ " 'unprivileged',\n",
+ " 0.7865279841505696)}]}"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "debias_pipeline.get_protected_metrics()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "82f6f917",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}