Home
If you are not familiar with decision trees, use the following resources:
- An Introduction to Statistical Learning, by G. James, D. Witten, T. Hastie and R. Tibshirani (Springer Texts in Statistics)
- Elements of Statistical Learning, by T. Hastie, R. Tibshirani and J. Friedman (Springer Series in Statistics)
Next, a short overview of how to use the library. For those of you interested in some of the internals, follow this link to download our technical report.
Treelib helps you build Classification or Regression Trees by executing model training in parallel on the Apache Spark execution engine. In addition to model building, Treelib provides utilities for Tree Pruning and Random Forests. Treelib implements state-of-the-art algorithms, including binary tree models (using the CART algorithm) and multi-way tree models (using ID3).
Clone this repo and compile the library using `sbt compile`. You will find the library jars in the `target` folder of the project (look for the specific Scala version you are using). Copy the Treelib jar into the `lib` folder of your Spark application, then simply import Treelib in your Spark application code.
Every tree builder (Regression Tree, CART Classification Tree, ID3 Classification Tree) has the same prototype; how to declare a tree, then build and validate tree models, is shown next.
- Regression Tree: import the required package with `import fr.eurecom.dsg.treelib.cart._`, then create a tree builder: `val tree = new RegressionTree()`
- CART Classification Tree: import the required package with `import fr.eurecom.dsg.treelib.cart._`, then create a tree builder: `val tree = new ClassificationTree()`
- ID3 Classification Tree: import the required package with `import fr.eurecom.dsg.treelib.id3._`, then create a tree builder: `val tree = new ID3TreeBuilder()`
The first step to build a tree model is to set the training data using `tree.setDataset(data)`, where `data` is the RDD storing the training dataset. Note: Treelib only supports data WITHOUT a header. However, it is possible to manually label features using `tree.setFeatureNames(<Array_string_name_of_features>)`. If not otherwise specified, features will be named "Column1", "Column2", ...
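As an illustration of the default naming convention, the following is a hypothetical plain-Scala sketch (not Treelib's internal code) of how names like "Column1", "Column2", ... correspond to the fields of a header-less record:

```scala
// Hypothetical sketch of the default naming convention ("Column1", "Column2", ...)
// used when no feature names are set. This is NOT Treelib's internal code.
object DefaultNames {
  def names(numFields: Int): Seq[String] =
    (1 to numFields).map(i => s"Column$i")

  def main(args: Array[String]): Unit = {
    // A record with three fields gets three default names
    println(names(3).mkString(", "))  // Column1, Column2, Column3
  }
}
```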
To start building a tree, use the method `tree.buildTree(yFeature, SetOfXFeatures)`, where `yFeature` is the name of the target feature and `SetOfXFeatures` is the set of predictors. These two parameters are optional: calling `tree.buildTree()` with no arguments implicitly assumes that the last feature of the dataset is the target, while the preceding features are predictors.
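The implicit convention can be sketched in plain Scala (a hypothetical illustration, not Treelib's internal code): given the ordered feature names, the last one becomes the target and the rest become predictors.

```scala
// Hypothetical sketch of the implicit convention of buildTree() with no arguments:
// the last feature is the target, all preceding features are predictors.
// This is NOT Treelib's internal code.
object DefaultSelection {
  def split(features: Array[String]): (String, Set[String]) =
    (features.last, features.init.toSet)

  def main(args: Array[String]): Unit = {
    val header = Array("Temperature", "Humidity", "Outlook", "PlayGolf")
    val (target, predictors) = split(header)
    println(target)           // PlayGolf
    println(predictors.size)  // 3
  }
}
```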
Assume that we have training data with schema "Temperature", "Humidity", "Outlook", "Money", "Month", "DayOfWeek", "PlayGolf", where:
- "Temperature" is a float indicating the temperature in Celsius degrees;
- "Humidity" can take a value in the set {High, Normal};
- "Outlook" can take a value in the set {Overcast, Rainy, Sunny};
- "Money" is a float indicating the wealth of a person;
- "Month" has value range [1, 12];
- "DayOfWeek" has the value set {0, 1, 2, 3, 4, 5, 6}, corresponding to {Saturday, Sunday, Monday, Tuesday, Wednesday, Thursday, Friday} respectively;
- "PlayGolf" has two categorical values: "Yes" and "No".
Our goal is to predict whether a person will play golf or not, based on all the features described above (hence, the last feature is our target). We build and train our Classification tree as follows:
```scala
// read the training data
val playgolf_training_data = context.textFile("data/playgolf.csv", 1)
// create a Classification Tree
val tree = new ClassificationTree()
// set the training data
tree.setDataset(playgolf_training_data)
// set the header for the training data
tree.setFeatureNames(Array[String]("Temperature", "Humidity", "Outlook", "Money", "Month", "DayOfWeek", "PlayGolf"))
// build the tree: the target is the last feature (PlayGolf), the rest are predictors
tree.buildTree()
```
Alternatively, if we want to omit the predictor "Month", we can call:
```scala
tree.buildTree("PlayGolf", Set[Any]("Temperature", "Humidity", "Outlook", "Money", "DayOfWeek"))
```
If you don't specify the type of the predictors, the algorithm will detect it automatically by considering their values. For instance, with the above statement, "Temperature", "Money", "Month" and "DayOfWeek" will be treated as numerical features, and the rest as categorical features. In case you want to make the type of a feature explicit, you can use the `as` object, similarly to R. For example:
```scala
tree.buildTree("PlayGolf", Set[Any]("Temperature", "Humidity", "Outlook", as.Number("Money"), as.String("DayOfWeek")))
```
Some additional tuning parameters can also easily be set:
```scala
tree.setDelimiter("\t")  // set the field delimiter to tab; default is ","
tree.setMinSplit(1)      // only grow a node if it has more than 1 record
tree.setThreshold(0.3)   // only grow a node if the coefficient of variation of Y > 0.3; CV = Deviation(Y)/E(Y)
```
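To make the stopping criterion concrete, here is a minimal plain-Scala sketch (independent of Treelib) of the coefficient of variation, CV = Deviation(Y)/E(Y), that the threshold is compared against:

```scala
// Minimal sketch (NOT Treelib code) of the coefficient of variation used
// by the setThreshold stopping criterion: CV = Deviation(Y) / E(Y).
object CV {
  def coefficientOfVariation(y: Seq[Double]): Double = {
    val mean = y.sum / y.length
    val variance = y.map(v => (v - mean) * (v - mean)).sum / y.length
    math.sqrt(variance) / mean
  }

  def main(args: Array[String]): Unit = {
    // A node whose target values barely vary has a CV well below 0.3,
    // so with setThreshold(0.3) it would not be grown further.
    println(coefficientOfVariation(Seq(10.0, 10.1, 9.9)))
  }
}
```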
Once you have grown a tree, you can use it to make predictions as follows:
```scala
predictOneInstance(record: Array[String], ignoreBranchIDs: Set[BigInt] = Set[BigInt]())
```
where `record` is an array of predictor values and `ignoreBranchIDs` is the set of IDs of nodes which we want to force to make predictions as leaf nodes, instead of using their children (this parameter is optional). For example:
```scala
tree.predictOneInstance(Array[String]("cool", "sunny", "normal", "false", "30"))
```
Finally, you can evaluate the quality of the tree model using a test dataset:
```scala
val predictRDDOfTheFullTree = tree.predict(testingData)
val actualValueRDD = testingData.map(line => line.split(',').last)  // the last feature is the target
Evaluation.evaluate(predictRDDOfTheFullTree, actualValueRDD)
```
The metrics we use for evaluation are the mean, standard deviation and squared error of the difference between the predicted values and the ground-truth values. For example, the evaluation result of a regression tree may look like:
```
Mean of different:-0.000000
Deviation of different:2.596211
SE of different:0.036566
```
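As an illustration of these metrics, the following plain-Scala sketch (not Treelib's `Evaluation` implementation) computes the mean and standard deviation of the prediction error:

```scala
// Plain-Scala sketch (NOT Treelib's Evaluation object) of the reported metrics:
// mean and standard deviation of the difference between predictions and ground truth.
object ErrorMetrics {
  def meanAndDeviation(predicted: Seq[Double], actual: Seq[Double]): (Double, Double) = {
    val diffs = predicted.zip(actual).map { case (p, a) => p - a }
    val mean = diffs.sum / diffs.length
    val variance = diffs.map(d => (d - mean) * (d - mean)).sum / diffs.length
    (mean, math.sqrt(variance))
  }

  def main(args: Array[String]): Unit = {
    val (mean, dev) = meanAndDeviation(Seq(11.0, 17.0, 22.0), Seq(10.0, 18.0, 22.0))
    println(f"Mean of difference: $mean%f")
    println(f"Deviation of difference: $dev%f")
  }
}
```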
Treelib allows you to save and load tree models as well:
```scala
// tree was created by 'val tree = new RegressionTree()'
// write the model to a file
tree.writeModelToFile("my_tree.model")
// create a new regression tree
val new_tree = new RegressionTree()
// load the tree model from a file
new_tree.loadModelFromFile("my_tree.model")
```
Let's proceed with a new example: we want to build a Regression tree using the 'bodyfat' dataset.
```scala
val context = new SparkContext("local", "SparkContext")
val trainingData = context.textFile("data/bodyfat.csv", 1)
val tree = new RegressionTree()
tree.setDataset(trainingData)
tree.setFeatureNames(Array[String]("", "age", "DEXfat", "waistcirc", "hipcirc", "elbowbreadth", "kneebreadth", "anthro3a", "anthro3b", "anthro3c", "anthro4"))
tree.setMinSplit(10)
var stime = System.nanoTime()
println(tree.buildTree("DEXfat", Set("age", "waistcirc", "hipcirc", "elbowbreadth", "kneebreadth")))
println("Build tree in %f second(s)".format((System.nanoTime() - stime) / 1e9))
```
The tree model will be written to `stdout` as follows:
```
waistcirc( < 88.400000)
|-(yes)-hipcirc( < 96.250000)
| |-(yes)--age( < 59.500000)
| | |-(yes)---waistcirc( < 67.300000)
| | | |-(yes)----11.21
| | | |-(no)----17.015555555555558
| | |-(no)---22.328333333333333
| |-(no)--waistcirc( < 80.750000)
| | |-(yes)---kneebreadth( < 8.550000)
| | | |-(yes)----20.30333333333333
| | | |-(no)----25.279000000000003
| | |-(no)---29.372000000000003
|-(no)-kneebreadth( < 11.300000)
| |-(yes)--hipcirc( < 109.900000)
| | |-(yes)---35.27846153846154
| | |-(no)---42.95437500000001
| |-(no)--61.370000000000005
Build tree in 2.120787 second(s)
```