Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TRXF pmml scorecard reader #168

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .__version__ import version
from .reader import AbstractReader, TrxfReader
from .reader import AbstractReader, TrxfRuleSetReader
from .serializer import AbstractSerializer, NyokaSerializer
from .pmml_exporter import PmmlExporter
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .predicate import CompoundPredicate
from .predicate import Operator
from .predicate import SimplePredicate
from .predicate import SimpleSetPredicate
from .predicate import TruePredicate
from .rule import SimpleRule
from .rule import DEFAULT_WEIGHT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,7 @@ class ComplexPartialScore:
@dataclass(frozen=True)
class Attribute:
score: typing.Union[str, ComplexPartialScore]
predicate: typing.Union[predicate.SimplePredicate, predicate.CompoundPredicate, predicate.TruePredicate]
predicate: typing.Union[predicate.SimplePredicate,
predicate.CompoundPredicate,
predicate.SimpleSetPredicate,
predicate.TruePredicate]
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,13 @@ class TruePredicate:
class CompoundPredicate:
simplePredicates: typing.List[SimplePredicate] = field()
booleanOperator: BooleanOperator = field()


MembershipOperator = enum.Enum('MembershipOperator', [('isIn', 0), ('isNotIn', 1)])


@dataclass(frozen=True)
class SimpleSetPredicate:
field_: str = field()
membershipOperator: MembershipOperator = field()
values: typing.Set[str] = field()
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from aix360.algorithms.rule_induction.trxf.classifier.ruleset_classifier import RuleSetClassifier
from aix360.algorithms.rule_induction.trxf.pmml_export import AbstractSerializer, AbstractReader


Expand All @@ -7,12 +6,12 @@ def __init__(self, reader: AbstractReader, serializer: AbstractSerializer):
self._serializer = serializer
self._reader = reader

def export(self, trxf_classifier: RuleSetClassifier):
def export(self, trxf_input):
"""
Translate a given TRXF RuleSetClassifier to a PMML string
@param trxf_classifier: A TRXF RuleSetClassifier
Translate a given TRXF RuleSetClassifier or Scorecard to a PMML string
@param trxf_input: A TRXF RuleSetClassifier or Scorecard object
@return: The corresponding PMML string
"""
if self._reader.data_dictionary is None:
raise AssertionError("Missing data dictionary in reader object")
return self._serializer.serialize(self._reader.read(trxf_classifier))
return self._serializer.serialize(self._reader.read(trxf_input))
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .abstract_reader import AbstractReader
from .trxf_reader import TrxfReader
from .trxf_ruleset_reader import TrxfRuleSetReader
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from typing import Dict

import numpy as np
import pandas as pd

from aix360.algorithms.rule_induction.trxf.classifier import ruleset_classifier
from aix360.algorithms.rule_induction.trxf.classifier.ruleset_classifier import RuleSetClassifier
from aix360.algorithms.rule_induction.trxf.core import Conjunction, Relation
from aix360.algorithms.rule_induction.trxf.pmml_export import models
from aix360.algorithms.rule_induction.trxf.pmml_export.models.data_dictionary import Value
from aix360.algorithms.rule_induction.trxf.pmml_export.reader import AbstractReader
from aix360.algorithms.rule_induction.trxf.pmml_export.models import SimplePredicate, Operator, CompoundPredicate, \
BooleanOperator
from aix360.algorithms.rule_induction.trxf.pmml_export.utilities import extract_data_dictionary, trxf_to_pmml_predicate


class TrxfReader(AbstractReader):
class TrxfRuleSetReader(AbstractReader):
def __init__(self, data_dictionary=None):
self._data_dictionary = data_dictionary

Expand Down Expand Up @@ -48,35 +44,13 @@ def load_data_dictionary(self, X: pd.DataFrame, values: Dict = None):
@param X: Input dataframe
@param values: A dict mapping column name to a list of possible categorical values. It will be inferred from X if not provided.
"""
dtypes = X.dtypes
data_fields = []
for index, value in dtypes.items():
vals = None
if np.issubdtype(value, np.integer):
data_type = models.DataType.integer
op_type = models.OpType.ordinal
elif np.issubdtype(value, np.double):
data_type = models.DataType.double
op_type = models.OpType.continuous
elif np.issubdtype(value, np.floating):
data_type = models.DataType.float
op_type = models.OpType.continuous
elif np.issubdtype(value, np.bool_):
data_type = models.DataType.boolean
op_type = models.OpType.categorical
else:
data_type = models.DataType.string
op_type = models.OpType.categorical
vals = values[index] if values is not None and index in values else list(X[index].unique())
wrapped_vals = list(map(lambda v: Value(v), vals)) if vals is not None else vals
data_fields.append(models.DataField(name=str(index), optype=op_type, dataType=data_type, values=wrapped_vals))
self._data_dictionary = models.DataDictionary(data_fields)
self._data_dictionary = extract_data_dictionary(X, values)


def _convert_to_simple_rules(trxf_rules):
simple_rules = []
for rule in trxf_rules:
predicate = _convert_to_pmml_predicate(rule.conjunction)
predicate = trxf_to_pmml_predicate(rule.conjunction)
confidence = rule.confidence if rule.confidence is not None else models.DEFAULT_CONFIDENCE
weight = rule.weight if rule.weight is not None else models.DEFAULT_WEIGHT
simple_rule = models.SimpleRule(predicate=predicate, score=str(rule.label), id=str(rule.conjunction),
Expand All @@ -86,22 +60,6 @@ def _convert_to_simple_rules(trxf_rules):
return simple_rules


def _convert_to_pmml_predicate(trxf_conjunction: Conjunction):
trxf_to_pmml_op = {
Relation.EQ: Operator.equal,
Relation.NEQ: Operator.notEqual,
Relation.LT: Operator.lessThan,
Relation.LE: Operator.lessOrEqual,
Relation.GT: Operator.greaterThan,
Relation.GE: Operator.greaterOrEqual
}
simple_predicates = [SimplePredicate(operator=trxf_to_pmml_op[trxf_predicate.relation],
value=str(trxf_predicate.value),
field=str(trxf_predicate.feature.variable_names[0]))
for trxf_predicate in trxf_conjunction.predicates]
return CompoundPredicate(simplePredicates=simple_predicates, booleanOperator=BooleanOperator.and_)


def _extract_mining_schema(trxf_rules):
mining_fields = {}
for rule in trxf_rules:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pandas as pd

from aix360.algorithms.rule_induction.trxf import scorecard
from aix360.algorithms.rule_induction.trxf.pmml_export import models
from aix360.algorithms.rule_induction.trxf.pmml_export.models import SimpleSetPredicate
from aix360.algorithms.rule_induction.trxf.pmml_export.models.predicate import MembershipOperator
from aix360.algorithms.rule_induction.trxf.pmml_export.reader import AbstractReader
from aix360.algorithms.rule_induction.trxf.pmml_export.utilities import extract_data_dictionary, trxf_to_pmml_predicate


class TrxfScorecardReader(AbstractReader):
def __init__(self, data_dictionary=None):
self._data_dictionary = data_dictionary

@property
def data_dictionary(self):
return self._data_dictionary

def read(self, trxf_scorecard: scorecard.Scorecard) -> models.Scorecard:
"""
Translate a TRXF Scorecard to an internal Scorecard
"""
mining_schema = _extract_mining_schema(trxf_scorecard.features)
output = models.Output([models.OutputField(name='RawResult',
feature='predictedValue',
dataType=models.DataType.double,
optype=models.OpType.continuous)])
characteristics = _extract_characteristics(trxf_scorecard)

assert self._data_dictionary is not None
return models.Scorecard(dataDictionary=self._data_dictionary,
miningSchema=mining_schema,
output=output,
characteristics=characteristics,
initialScore=str(trxf_scorecard.bias))

def load_data_dictionary(self, X: pd.DataFrame, values=None):
"""
Extract the data dictionary from a feature dataframe, and store it
"""
self._data_dictionary = extract_data_dictionary(X, values)


def _extract_mining_schema(scorecard_features):
mining_fields = {}
for feature in scorecard_features:
name = feature.variable_names[0]
if name not in mining_fields:
mining_field = models.MiningField(name=name)
mining_fields[name] = mining_field
return models.MiningSchema(miningFields=list(mining_fields.values()))


def _extract_characteristics(trxf_scorecard):
characteristics = []
for partition in trxf_scorecard.partitions:
feature_name = partition.feature.variable_names[0]
attributes = []
for bin in partition.bins:
if isinstance(bin, scorecard.IntervalBin):
conjunction = bin.to_conjunction()
predicate = trxf_to_pmml_predicate(conjunction)
elif isinstance(bin, scorecard.SetBin):
predicate = SimpleSetPredicate(field_=feature_name,
membershipOperator=MembershipOperator.isIn,
values=bin.values)
else:
raise ValueError('Unsupported Bin type {} for feature {}'.format(type(bin), feature_name))
score = str(bin.sub_score)
attribute = models.Attribute(score=score, predicate=predicate)
attributes.append(attribute)
characteristic = models.Characteristic(name=feature_name, attributes=attributes)
characteristics.append(characteristic)
return models.Characteristics(characteristics)


Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,10 @@ def _nyoka_pmml_attributes(self, attribute: models.Attribute) -> nyoka_pmml.Attr
SimplePredicate=[
nyoka_pmml.SimplePredicate(field=sp.field, operator=sp.operator.name, value=sp.value)
for sp in attribute.predicate.simplePredicates]),
SimpleSetPredicate=None if (attribute.predicate is None or not isinstance(
attribute.predicate, models.SimpleSetPredicate)) else nyoka_pmml.SimpleSetPredicate(
field=attribute.predicate.field_,
booleanOperator=attribute.predicate.membershipOperator.name,
Array=nyoka_pmml.ArrayType(list(attribute.predicate.values))),
True_=None if (attribute.predicate is None or not isinstance(
attribute.predicate, models.TruePredicate)) else nyoka_pmml.True_())
56 changes: 56 additions & 0 deletions aix360/algorithms/rule_induction/trxf/pmml_export/utilities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Dict

import numpy as np
import pandas as pd
from aix360.algorithms.rule_induction.trxf.pmml_export.models import Operator, SimplePredicate, CompoundPredicate, \
BooleanOperator, Value

from aix360.algorithms.rule_induction.trxf.core import Conjunction, Relation

from aix360.algorithms.rule_induction.trxf.pmml_export import models


def extract_data_dictionary(X: pd.DataFrame, values: Dict):
"""
Extract the data dictionary from a feature dataframe
"""
dtypes = X.dtypes
data_fields = []
for index, value in dtypes.items():
vals = None
if np.issubdtype(value, np.integer):
data_type = models.DataType.integer
op_type = models.OpType.ordinal
elif np.issubdtype(value, np.double):
data_type = models.DataType.double
op_type = models.OpType.continuous
elif np.issubdtype(value, np.floating):
data_type = models.DataType.float
op_type = models.OpType.continuous
elif np.issubdtype(value, np.bool_):
data_type = models.DataType.boolean
op_type = models.OpType.categorical
else:
data_type = models.DataType.string
op_type = models.OpType.categorical
vals = values[index] if values is not None and index in values else list(X[index].unique())
wrapped_vals = list(map(lambda v: Value(v), vals)) if vals is not None else vals
data_fields.append(models.DataField(name=str(index), optype=op_type, dataType=data_type, values=wrapped_vals))

return models.DataDictionary(data_fields)


def trxf_to_pmml_predicate(trxf_conjunction: Conjunction):
trxf_to_pmml_op = {
Relation.EQ: Operator.equal,
Relation.NEQ: Operator.notEqual,
Relation.LT: Operator.lessThan,
Relation.LE: Operator.lessOrEqual,
Relation.GT: Operator.greaterThan,
Relation.GE: Operator.greaterOrEqual
}
simple_predicates = [SimplePredicate(operator=trxf_to_pmml_op[trxf_predicate.relation],
value=str(trxf_predicate.value),
field=str(trxf_predicate.feature.variable_names[0]))
for trxf_predicate in trxf_conjunction.predicates]
return CompoundPredicate(simplePredicates=simple_predicates, booleanOperator=BooleanOperator.and_)
13 changes: 12 additions & 1 deletion aix360/algorithms/rule_induction/trxf/scorecard/bins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from numbers import Real
from typing import Dict, Set, Any, Optional
from aix360.algorithms.rule_induction.trxf.core import Feature
from aix360.algorithms.rule_induction.trxf.core import Feature, Predicate, Relation, Conjunction


class Bin(abc.ABC):
Expand Down Expand Up @@ -121,6 +121,17 @@ def overlaps(self, other: 'LinearIntervalBin') -> bool:
'is an instance of "{}"'.format(str(other.__class__)))
return (self.left_end < other.right_end) and (self.right_end > other.left_end)

def to_conjunction(self):
"""
Converts bin to trxf.Conjunction
"""
left = Predicate(feature=self.feature, relation=Relation.GE, value=self.left_end) if \
self.left_end > float('-inf') else None
right = Predicate(feature=self.feature, relation=Relation.LT, value=self.right_end) if \
self.right_end < float('inf') else None
predicates = [p for p in [left, right] if p is not None]
return Conjunction(predicates)

def _get_feature_value(self, assignment: Dict[str, Any]) -> Real:
"""
Evaluates the value of the feature for the specified variable assignment. Raises ValueError if the feature
Expand Down
11 changes: 6 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
author_email='[email protected]',
packages=setuptools.find_packages(),
license='Apache License 2.0',
long_description=open('README.md', 'r', encoding='utf-8').read(),
long_description_content_type='text/markdown',
long_description=open('README.md', 'r', encoding='utf-8').read(),
long_description_content_type='text/markdown',
install_requires=[
'joblib>=0.11',
'scikit-learn>=0.21.2',
Expand All @@ -31,16 +31,17 @@
'pandas',
'scipy>=0.17',
'xport',
'scikit-image',
'scikit-image',
'requests',
'xgboost==1.1.0',
'xgboost==1.1.0',
'bleach>=2.1.0',
'docutils>=0.13.1',
'Pygments',
'osqp',
'osqp',
'lime==0.1.1.37',
'shap==0.34.0',
'nyoka==5.2.0',
'pypmml',
'xmltodict==0.12.0',
'numba',
'tqdm',
Expand Down
Loading