diff --git a/docs/index.html b/docs/index.html index c6b4759..4617259 100644 --- a/docs/index.html +++ b/docs/index.html @@ -633,7 +633,31 @@

Submodules
xai.convert_probs(probs, threshold=0.5)
-

Convert probabilities into classes

+

Converts all the probabilities in the array provided into binary labels +as per the threshold provided which is 0.5 by default.

+

Example

+
probs = np.array([0.1, 0.2, 0.7, 0.8, 0.6])
+labels = xai.convert_probs(probs, threshold=0.65)
+print(labels)
+
+> [0, 0, 1, 1, 0]
+
+
+
+
Parameters
+
    +
  • probs (ndarray) – Numpy array or list containing a list of floats between 0 and 1

  • +
  • threshold (float) – Float that provides the threshold for which probabilities over the +threshold will be converted to 1

  • +
+
+
Returns
+

Numpy array containing the labels based on threshold provided

+
+
Return type
+

np.ndarray

+
+
@@ -691,7 +715,50 @@

Submodules
xai.evaluation_metrics(y_valid, y_pred)
-

+

Calculates model performance metrics (accuracy, precision, recall, etc) +from the actual and predicted labels provided.

+

Example

+
y_actual: np.ndarray
+y_predicted: np.ndarray
+
+metrics = xai.evaluation_metrics(y_actual, y_predicted)
+for k,v in metrics.items():
+    print(f"{k}: {v}")
+
+> precision: 0.8,
+> recall: 0.9,
+> specificity: 0.7,
+> accuracy: 0.8,
+> auc: 0.7,
+> f1: 0.8
+
+
+
+
Parameters
+
    +
  • y_valid – Numpy array with the actual labels for the datapoints

  • +
  • y_pred – Numpy array with the predicted labels for the datapoints

  • +
+
+
Returns
+

Dictionary containing the metrics as follows:

+
return {
+    "precision": precision,
+    "recall": recall,
+    "specificity": specificity,
+    "accuracy": accuracy,
+    "auc": auc,
+    "f1": f1
+}
+
+
+

+
+
Return type
+

Dict[str, float]

+
+
+
@@ -793,7 +860,50 @@

Submodules
xai.metrics_plot(target, predicted, df=Empty DataFrame Columns: [] Index: [], cross_cols=[], categorical_cols=[], bins=6, plot=True, exclude_metrics=[], plot_threshold=0.5)
-

+

Creates a plot that displays statistical metrics including precision, +recall, accuracy, auc, f1 and specificity for each of the groups created +for the columns provided by cross_cols. For example, if the columns passed +are “gender” and “age”, the resulting plot will show the statistical metrics +for Male and Female for each binned group.

+

Example

+
target: np.ndarray
+predicted: np.ndarray
+
+df_metrics = xai.metrics_plot(
+                target,
+                predicted,
+                df=df_data,
+                cross_cols=["gender", "age"],
+                bins=3)
+
+
+
+
Parameters
+
    +
  • target (ndarray) – Numpy array containing the target labels for the datapoints

  • +
  • predicted (ndarray) – Numpy array containing the predicted labels for the datapoints

  • +
  • df (DataFrame) – Pandas dataframe containing all the features for the datapoints. +It can be empty if only looking to calculate global metrics, but +if you would like to compute for categories across columns, the +columns you are grouping by need to be provided

  • +
  • cross_cols (List[str]) – Contains the columns that you would like to use to cross the values

  • +
  • bins (int) – [Default: 6] The number of bins in which you’d like +numerical columns to be split

  • +
  • plot (bool) – [Default: True] If True a plot will be drawn with the results

  • +
  • exclude_metrics (List[str]) – These are the metrics that you can choose to exclude if you only +want specific ones (for example, excluding “f1”, “specificity”, etc)

  • +
  • plot_threshold (float) – The percentage that will be used to draw the threshold line in the plot +which would provide guidance on what is the ideal metrics to achieve.

  • +
+
+
Returns
+

Pandas Dataframe containing all the metrics for the groups provided

+
+
Return type
+

pd.DataFrame

+
+
+
diff --git a/docs/searchindex.js b/docs/searchindex.js index 284dd8d..31c67fc 100644 --- a/docs/searchindex.js +++ b/docs/searchindex.js @@ -1 +1 @@ -Search.setIndex({docnames:["gettingstarted","index","xai"],envversion:{"sphinx.domains.c":1,"sphinx.domains.changeset":1,"sphinx.domains.cpp":1,"sphinx.domains.javascript":1,"sphinx.domains.math":2,"sphinx.domains.python":1,"sphinx.domains.rst":1,"sphinx.domains.std":1,sphinx:56},filenames:["gettingstarted.rst","index.rst","xai.rst"],objects:{"":{xai:[1,0,0,"-"]},"xai.data":{load_census:[1,1,1,""]},xai:{balance:[1,1,1,""],balanced_train_test_split:[1,1,1,""],confusion_matrix_plot:[1,1,1,""],convert_categories:[1,1,1,""],convert_probs:[1,1,1,""],correlations:[1,1,1,""],data:[1,0,0,"-"],evaluation_metrics:[1,1,1,""],feature_importance:[1,1,1,""],group_by_columns:[1,1,1,""],imbalance_plot:[1,1,1,""],metrics_plot:[1,1,1,""],normalize_numeric:[1,1,1,""],pr_plot:[1,1,1,""],roc_plot:[1,1,1,""],smile_imbalance:[1,1,1,""]}},objnames:{"0":["py","module","Python module"],"1":["py","function","Python 
function"]},objtypes:{"0":"py:module","1":"py:function"},terms:{"boolean":1,"default":1,"float":1,"function":1,"int":1,"null":1,"return":1,"throw":1,"true":1,For:1,One:1,The:1,Will:1,With:1,abl:1,abov:1,across:1,actual:1,actual_label:1,adult:1,age:1,algorithm:1,also:1,altern:1,analys:1,ani:1,approv:1,arg:1,argument:1,arrai:1,astyp:1,autom:1,avail:1,axi:1,bal_df:1,balance_on:1,balanced_train_test_split:1,bare:1,base:1,batch_siz:1,behaviour:1,best:1,between:1,bin:1,binari:1,binary_target_label:1,bool:1,broadli:1,build_model:1,built:[],can:1,cat_df:1,categor:1,categori:1,categorical_col:1,censu:1,challeng:1,chart:1,check:1,class_count:1,classifi:1,clone:1,code:1,col:1,com:1,combin:1,comput:1,conceiv:1,confus:1,confusion_matrix_plot:1,construct:1,contain:1,convert:1,convert_categori:1,convert_prob:1,copi:1,core:1,correl:1,cross:1,cross_col:1,cross_cross:[],current:1,datafram:1,dataframegroupbi:1,datapoint:1,dataset:1,deep:1,def:1,definit:1,dendogram:1,design:1,detail:1,develop:1,deviat:1,df_group:1,df_test:1,diagram:1,discrep:1,displai:1,display_breakdown:1,divid:1,document:[0,1],domain:1,drop:1,dtype:1,due:1,dure:1,each:1,earli:1,element:1,empow:1,empti:1,enabl:1,end:1,engin:1,enough:1,epoch:1,equal:1,error:1,ethic:1,ethicalml:1,ethnic:1,evaluation_metr:1,everyth:0,exampl:1,except:1,exclude_metr:1,expect:1,expert:1,explicitli:1,f_in:1,fallback:1,fallback_typ:1,fals:1,fast:1,feature_import:1,femal:1,find:[0,1],first:1,fit:1,folder:1,frame:[],frequenc:1,from:1,full:1,func:1,gender:1,get:1,get_avg:1,github:1,group_by_column:1,groupbi:1,grouped_df:1,guid:1,half:1,head:1,henc:1,here:[0,1],higher:1,highest:1,html:1,http:1,idea:1,ignor:1,imag:1,imbalance_plot:1,imp:1,ims:1,include_categor:1,index:[0,1],infer:1,input:1,insight:1,instanc:1,instead:1,institut:1,interact:1,intern:1,involv:1,its:1,just:1,knowledg:1,label:1,label_x_neg:1,label_x_po:1,label_y_neg:1,label_y_po:1,layer:1,less:1,librari:1,like:1,list:1,load:1,load_censu:1,loan:1,london:1,lower:1,mai:1,main:1,maintain:1
,male:1,manual_review:1,matrix:1,max_per_group:1,maximum:1,meantim:1,mechan:1,mention:1,metrics_imbal:1,metrics_plot:1,min_per_class:1,min_per_group:1,mind:1,minimum:1,monitor:1,more:1,multiarrai:[],name:1,ndarrai:1,neg:1,none:1,norm_df:1,normal:1,normalize_numer:1,noth:1,number:1,numer:1,numerci:1,numerical_col:1,numpi:1,object:1,optim:1,option:1,other_numeric_attribut:1,otherwis:1,our:1,out:1,overview:1,packag:1,page:0,panda:1,param:1,paramet:1,part:1,pass:1,percentag:1,perform:1,pip:1,pleas:1,plot:1,plot_threshold:1,plot_typ:1,plt_kwarg:1,posit:1,pr_imbal:1,pr_plot:1,practic:1,pred:1,predict:1,predicted_label:1,predictedd:1,principl:1,print:1,prob:1,proc_df:1,process:1,product:1,protected_col:1,provid:1,pypi:1,quick:1,quit:1,random:1,random_st:1,reject:1,rel:1,relev:1,repeat:1,replac:1,repo:1,repositori:1,request:1,requir:1,respons:1,result:1,return_xi:1,roc_imbal:1,roc_plot:1,rtype:1,run:1,sampl:1,scale:1,scienc:1,search:0,see:1,seed:1,set:1,setup:1,show:1,show_imbal:1,singl:1,smile_imbal:1,solut:1,sourc:1,specif:1,stage:1,standard:1,start:1,step:1,str:1,string:1,sub:1,subcategori:1,substract:1,take:1,talk:1,target:1,tensorflow:1,test_idx:1,than:1,thi:1,those:1,three:1,through:1,titl:0,tool:1,total:1,toward:1,train_idx:1,trane:1,treat:1,trigger:1,tupl:1,type:1,under:1,union:1,unstabl:1,updat:1,use:1,used:1,uses:1,valu:1,variou:1,verbos:1,visual:1,welcom:0,when:1,where:1,whether:1,which:1,within:1,would:1,x_test:1,x_train:1,xai:0,y_pred:1,y_test:1,y_train:1,y_valid:1,you:[0,1]},titles:["Getting started guide","Welcome to the XAI docs - eXplainable machine learning","<no 
title>"],titleterms:{"class":1,"import":1,about:1,accuraci:1,adding:1,against:1,all:1,alpha:1,analysi:1,anoth:1,balanc:1,benefit:1,bucket:1,column:1,content:1,creat:1,curv:1,data:1,doc:1,docstr:1,done:1,downsampl:1,evalu:1,explain:1,featur:1,get:0,group:1,guid:0,identifi:1,imbal:1,indic:0,instal:1,intersect:1,learn:1,machin:1,manual:1,mean:1,metric:1,model:1,modul:1,one:1,permut:1,pre:1,precis:1,probabl:1,protect:1,python:1,quickstart:1,recal:1,review:1,roc:1,should:1,split:1,start:0,statist:1,submodul:1,tabl:0,test:1,threshold:1,train:1,upsampl:1,usag:1,using:1,version:1,view:1,visualis:1,welcom:1,what:1,xai:1}}) \ No newline at end of file +Search.setIndex({docnames:["gettingstarted","index","xai"],envversion:{"sphinx.domains.c":1,"sphinx.domains.changeset":1,"sphinx.domains.cpp":1,"sphinx.domains.javascript":1,"sphinx.domains.math":2,"sphinx.domains.python":1,"sphinx.domains.rst":1,"sphinx.domains.std":1,sphinx:56},filenames:["gettingstarted.rst","index.rst","xai.rst"],objects:{"":{xai:[1,0,0,"-"]},"xai.data":{load_census:[1,1,1,""]},xai:{balance:[1,1,1,""],balanced_train_test_split:[1,1,1,""],confusion_matrix_plot:[1,1,1,""],convert_categories:[1,1,1,""],convert_probs:[1,1,1,""],correlations:[1,1,1,""],data:[1,0,0,"-"],evaluation_metrics:[1,1,1,""],feature_importance:[1,1,1,""],group_by_columns:[1,1,1,""],imbalance_plot:[1,1,1,""],metrics_plot:[1,1,1,""],normalize_numeric:[1,1,1,""],pr_plot:[1,1,1,""],roc_plot:[1,1,1,""],smile_imbalance:[1,1,1,""]}},objnames:{"0":["py","module","Python module"],"1":["py","function","Python 
function"]},objtypes:{"0":"py:module","1":"py:function"},terms:{"boolean":1,"default":1,"float":1,"function":1,"int":1,"null":1,"return":1,"throw":1,"true":1,For:1,One:1,The:1,These:1,Will:1,With:1,abl:1,abov:1,achiev:1,across:1,actual:1,actual_label:1,adult:1,age:1,algorithm:1,also:1,altern:1,analys:1,ani:1,approv:1,arg:1,argument:1,arrai:1,astyp:1,auc:1,autom:1,avail:1,axi:1,bal_df:1,balance_on:1,balanced_train_test_split:1,bare:1,base:1,batch_siz:1,behaviour:1,best:1,between:1,bin:1,binari:1,binary_target_label:1,bool:1,broadli:1,build_model:1,built:[],calcul:1,can:1,cat_df:1,categor:1,categori:1,categorical_col:1,censu:1,challeng:1,chart:1,check:1,choos:1,class_count:1,classifi:1,clone:1,code:1,col:1,com:1,combin:1,comput:1,conceiv:1,confus:1,confusion_matrix_plot:1,construct:1,contain:1,convert:1,convert_categori:1,convert_prob:1,copi:1,core:1,correl:1,cross:1,cross_col:1,cross_cross:[],current:1,datafram:1,dataframegroupbi:1,datapoint:1,dataset:1,deep:1,def:1,definit:1,dendogram:1,design:1,detail:1,develop:1,deviat:1,df_data:1,df_group:1,df_metric:1,df_test:1,diagram:1,dict:1,dictionari:1,discrep:1,displai:1,display_breakdown:1,divid:1,document:[0,1],domain:1,draw:1,drawn:1,drop:1,dtype:1,due:1,dure:1,each:1,earli:1,element:1,empow:1,empti:1,enabl:1,end:1,engin:1,enough:1,epoch:1,equal:1,error:1,etc:1,ethic:1,ethicalml:1,ethnic:1,evaluation_metr:1,everyth:0,exampl:1,except:1,exclud:1,exclude_metr:1,expect:1,expert:1,explicitli:1,f_in:1,fallback:1,fallback_typ:1,fals:1,fast:1,feature_import:1,femal:1,find:[0,1],first:1,fit:1,folder:1,follow:1,frame:[],frequenc:1,from:1,full:1,func:1,gender:1,get:1,get_avg:1,github:1,global:1,group_by_column:1,groupbi:1,grouped_df:1,guid:1,guidanc:1,half:1,head:1,henc:1,here:[0,1],higher:1,highest:1,html:1,http:1,idea:1,ideal:1,ignor:1,imag:1,imbalance_plot:1,imp:1,ims:1,includ:1,include_categor:1,index:[0,1],infer:1,input:1,insight:1,instanc:1,instead:1,institut:1,interact:1,intern:1,involv:1,item:1,its:1,just:1,knowledg:1,labe
l:1,label_x_neg:1,label_x_po:1,label_y_neg:1,label_y_po:1,labl:1,layer:1,less:1,librari:1,like:1,line:1,list:1,load:1,load_censu:1,loan:1,london:1,look:1,lower:1,mai:1,main:1,maintain:1,male:1,manual_review:1,matrix:1,max_per_group:1,maximum:1,meantim:1,mechan:1,mention:1,metrics_imbal:1,metrics_plot:1,min_per_class:1,min_per_group:1,mind:1,minimum:1,monitor:1,more:1,multiarrai:[],name:1,ndarrai:1,need:1,neg:1,none:1,norm_df:1,normal:1,normalize_numer:1,noth:1,number:1,numer:1,numerci:1,numerical_col:1,numpi:1,object:1,ones:1,onli:1,optim:1,option:1,other_numeric_attribut:1,otherwis:1,our:1,out:1,over:1,overview:1,packag:1,page:0,panda:1,param:1,paramet:1,part:1,pass:1,per:1,percentag:1,perform:1,pip:1,pleas:1,plot:1,plot_threshold:1,plot_typ:1,plt_kwarg:1,posit:1,pr_imbal:1,pr_plot:1,practic:1,pred:1,predict:1,predicted_label:1,predictedd:1,principl:1,print:1,prob:1,proc_df:1,process:1,product:1,protected_col:1,provid:1,pypi:1,quick:1,quit:1,random:1,random_st:1,reject:1,rel:1,relev:1,repeat:1,replac:1,repo:1,repositori:1,request:1,requir:1,respons:1,result:1,return_xi:1,roc_imbal:1,roc_plot:1,rtype:1,run:1,sampl:1,scale:1,scienc:1,search:0,see:1,seed:1,set:1,setup:1,show:1,show_imbal:1,singl:1,smile_imbal:1,solut:1,sourc:1,specif:1,stage:1,standard:1,start:1,step:1,str:1,string:1,sub:1,subcategori:1,substract:1,take:1,talk:1,target:1,tensorflow:1,test_idx:1,than:1,thi:1,those:1,three:1,through:1,titl:0,tool:1,total:1,toward:1,train_idx:1,trane:1,treat:1,trigger:1,tupl:1,type:1,under:1,union:1,unstabl:1,updat:1,use:1,used:1,uses:1,valu:1,variou:1,verbos:1,visual:1,want:1,welcom:0,when:1,where:1,whether:1,which:1,within:1,would:1,x_test:1,x_train:1,xai:0,y_actual:1,y_pred:1,y_predict:1,y_test:1,y_train:1,y_valid:1,you:[0,1]},titles:["Getting started guide","Welcome to the XAI docs - eXplainable machine learning","<no 
title>"],titleterms:{"class":1,"import":1,about:1,accuraci:1,adding:1,against:1,all:1,alpha:1,analysi:1,anoth:1,balanc:1,benefit:1,bucket:1,column:1,content:1,creat:1,curv:1,data:1,doc:1,docstr:1,done:1,downsampl:1,evalu:1,explain:1,featur:1,get:0,group:1,guid:0,identifi:1,imbal:1,indic:0,instal:1,intersect:1,learn:1,machin:1,manual:1,mean:1,metric:1,model:1,modul:1,one:1,permut:1,pre:1,precis:1,probabl:1,protect:1,python:1,quickstart:1,recal:1,review:1,roc:1,should:1,split:1,start:0,statist:1,submodul:1,tabl:0,test:1,threshold:1,train:1,upsampl:1,usag:1,using:1,version:1,view:1,visualis:1,welcom:1,what:1,xai:1}}) \ No newline at end of file diff --git a/xai/__init__.py b/xai/__init__.py index 98a72a6..0312313 100644 --- a/xai/__init__.py +++ b/xai/__init__.py @@ -3,7 +3,7 @@ import numpy as np from scipy.stats import spearmanr as sr from scipy.cluster import hierarchy as hc -from typing import List, Any, Union, Tuple, Optional +from typing import List, Any, Union, Tuple, Optional, Dict import random, math # TODO: Remove Dependencies, starting with Sklearn from sklearn.metrics import roc_curve, \ @@ -662,12 +662,97 @@ def resample(x): return x_train, y_train, x_test, y_test, train_idx, test_idx -def convert_probs(probs, threshold=0.5): - """Convert probabilities into classes""" - # TODO: Enable for multiclass +def convert_probs( + probs: np.ndarray, + threshold: float = 0.5 + ) -> np.ndarray: + """ + Converts all the probabilities in the array provided into binary labels + as per the threshold provided which is 0.5 by default. + + Example + --------- + + .. 
code-block:: python + + probs = np.array([0.1, 0.2, 0.7, 0.8, 0.6]) + labels = xai.convert_probs(probs, threshold=0.65) + print(labels) + + > [0, 0, 1, 1, 0] + + Args + ------- + + probs : + Numpy array or list containing a list of floats between 0 and 1 + threshold : + Float that provides the threshold for which probabilities over the + threshold will be converted to 1 + + Returns + ---------- + + : np.ndarray + Numpy array containing the labels based on threshold provided + + """ + return (probs >= threshold).astype(int) -def evaluation_metrics(y_valid, y_pred): +def evaluation_metrics( + y_valid, + y_pred + ) -> Dict[str, float]: + """ + Calculates model performance metrics (accuracy, precision, recall, etc) + from the actual and predicted lables provided. + + Example + --------- + + .. code-block:: python + + y_actual: np.ndarray + y_predicted: np.ndarray + + metrics = xai.evaluation_metrics(y_actual, y_predicted) + for k,v in metrics.items(): + print(f"{k}: {v}") + + > precision: 0.8, + > recall: 0.9, + > specificity: 0.7, + > accuracy: 0.8, + > auc: 0.7, + > f1: 0.8 + + Args + ------- + + y_valid : + Numpy array with the actual labels for the datapoints + y_pred : + Numpy array with the predicted labels for the datapoints + + Returns + ---------- + + : Dict[str, float] + Dictionary containing the metrics as follows: + + .. 
code-block:: python + + return { + "precision": precision, + "recall": recall, + "specificity": specificity, + "accuracy": accuracy, + "auc": auc, + "f1": f1 + } + + """ TP = np.sum( y_pred[y_valid==1] ) TN = np.sum( y_pred[y_valid==0] == 0 ) @@ -695,15 +780,71 @@ def evaluation_metrics(y_valid, y_pred): } def metrics_plot( - target, - predicted, - df=pd.DataFrame(), - cross_cols=[], - categorical_cols=[], - bins=6, - plot=True, - exclude_metrics=[], - plot_threshold=0.5): + target: np.ndarray, + predicted: np.ndarray, + df: pd.DataFrame = pd.DataFrame(), + cross_cols: List[str] = [], + categorical_cols: List[str] = [], + bins: int = 6, + plot: bool = True, + exclude_metrics: List[str] = [], + plot_threshold: float = 0.5 + ) -> pd.DataFrame: + """ + Creates a plot that displays statistical metrics including precision, + recall, accuracy, auc, f1 and specificity for each of the groups created + for the columns provided by cross_cols. For example, if the columns passed + are "gender" and "age", the resulting plot will show the statistical metrics + for Male and Female for each binned group. + + Example + --------- + + .. code-block:: python + + target: np.ndarray + predicted: np.ndarray + + df_metrics = xai.metrics_plot( + target, + predicted, + df=df_data, + cross_cols=["gender", "age"], + bins=3 + + Args + ------- + + target: + Numpy array containing the target labels for the datapoints + predicted : + Numpy array containing the predicted labels for the datapoints + df : + Pandas dataframe containing all the features for the datapoints. 
+ It can be empty if only looking to calculate global metrics, but + if you would like to compute for categories across columns, the + columns you are grouping by need to be provided + cross_cols : + Contains the columns that you would like to use to cross the values + bins : + [Default: 6] The number of bins in which you'd like + numerical columns to be split + plot : + [Default: True] If True a plot will be drawn with the results + exclude_metrics : + These are the metrics that you can choose to exclude if you only + want specific ones (for example, excluding "f1", "specificity", etc) + plot_threshold: + The percentage that will be used to draw the threshold line in the plot + which would provide guidance on what is the ideal metrics to achieve. + + Returns + ---------- + + : pd.DataFrame + Pandas Dataframe containing all the metrics for the groups provided + + """ grouped = _group_metrics( target, @@ -717,7 +858,7 @@ def metrics_plot( prfs = [] classes = [] for group, group_df in grouped: - group_valid = group_df["target"].values + group_valid = group_df['target'].values group_pred = group_df["predicted"].values metrics_dict = \ evaluation_metrics(group_valid, group_pred)