michiard edited this page Sep 16, 2014 · 9 revisions

Welcome to the Treelib wiki!

If you are not familiar with decision trees, use the following resources:

  • An Introduction to Statistical Learning, by G. James, D. Witten, T. Hastie and R. Tibshirani (Springer Texts in Statistics)

  • The Elements of Statistical Learning, by T. Hastie, R. Tibshirani and J. Friedman (Springer Series in Statistics)

Next comes a short overview of how to use the library. If you are interested in its internals, follow this link [todo].

Features:

Treelib helps you build Classification and Regression Trees by running model training in parallel on the Apache Spark execution engine. Beyond model building, it provides utilities for Tree Pruning and Random Forests. Treelib implements standard algorithms: Binary Tree models (using the CART algorithm) and Multi-way Tree models (using ID3).
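At each node, ID3 picks the categorical feature whose multi-way split (one child per distinct value) gives the largest information gain, i.e. the largest reduction in entropy of the target. The Scala sketch below shows that textbook criterion on plain Scala collections; `InfoGain`, `entropy` and `gain` are illustrative names, not part of Treelib, and Treelib's distributed implementation on Spark may compute it differently.

```scala
// Illustrative sketch of the textbook ID3 criterion; not Treelib code.
object InfoGain {
  // Shannon entropy (in bits) of a collection of class labels
  def entropy(labels: Seq[String]): Double = {
    val n = labels.size.toDouble
    labels.groupBy(identity).values.map { group =>
      val p = group.size / n
      -p * math.log(p) / math.log(2)
    }.sum
  }

  // Information gain of a multi-way split: one child per distinct feature value
  def gain(values: Seq[String], labels: Seq[String]): Double = {
    val n = labels.size.toDouble
    val children = values.zip(labels).groupBy(_._1).values.map(_.map(_._2))
    val weightedChildEntropy = children.map(c => c.size / n * entropy(c)).sum
    entropy(labels) - weightedChildEntropy
  }
}
```

A feature that separates the classes perfectly recovers the whole parent entropy as gain, while a constant feature gains nothing.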

How to use Treelib:

Clone this repo and build the library with sbt package. You will find the library jar in the target folder of the project, under the subfolder for the Scala version you are using. Copy the Treelib jar into the lib folder of your Spark application, then import Treelib in your Spark application code.

Every tree builder (RegressionTree, CART Classification Tree, ID3 Classification Tree) exposes the same interface. The following sections show how to declare a tree builder, and how to build and validate tree models.

How to declare a tree builder

Regression Trees

First, import the required package and then create a tree builder:

    import fr.eurecom.dsg.treelib.cart._
    val tree = new RegressionTree()

Classification Tree with CART

First, import the required package and then create a tree builder:

    import fr.eurecom.dsg.treelib.cart._
    val tree = new ClassificationTree()

Multi-way Tree with ID3

First, import the required package and then create a tree builder:

    import fr.eurecom.dsg.treelib.id3._
    val tree = new ID3TreeBuilder()

Using a tree builder

The first step in building a tree model is to set the training data with tree.setDataset(data), where data is the RDD holding the training dataset. Note: Treelib only supports data WITHOUT a header row. You can, however, name the features manually with tree.setFeatureNames(names), where names is an Array[String] with one name per column. If not otherwise specified, features are named "Column1", "Column2", ...

To build a tree, call tree.buildTree(yFeature, SetOfXFeatures), where yFeature is the name of the target feature and SetOfXFeatures is the set of predictors. Both parameters are optional: calling tree.buildTree() with no arguments assumes that the last feature of the dataset is the target and that all preceding features are predictors.
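The default naming and the default choice of target and predictors described above can be sketched as follows. `Defaults`, `defaultNames` and `defaultTargetAndPredictors` are hypothetical helpers written only to illustrate the documented behaviour; they are not part of Treelib's API.

```scala
// Hypothetical helpers illustrating the documented defaults; not Treelib code.
object Defaults {
  // Names used when setFeatureNames is not called: "Column1", "Column2", ...
  def defaultNames(numColumns: Int): Array[String] =
    (1 to numColumns).map(i => s"Column$i").toArray

  // buildTree() with no arguments: the last feature is the target,
  // all preceding features are predictors. Returns (target, predictors).
  def defaultTargetAndPredictors(names: Array[String]): (String, Set[String]) =
    (names.last, names.init.toSet)
}
```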

Example:

Assume that we have a training dataset with the schema "Temperature", "Humidity", "Outlook", "Money", "Month", "DayOfWeek", "PlayGolf", where:

  • "Temperature" is a float: the temperature in degrees Celsius;

  • "Humidity" takes a value in the set {High, Normal};

  • "Outlook" takes a value in the set {Overcast, Rainy, Sunny};

  • "Money" is a float indicating the wealth of a person;

  • "Month" takes a value in the range [1, 12];

  • "DayOfWeek" takes a value in the set {0, 1, 2, 3, 4, 5, 6}, corresponding to {Saturday, Sunday, Monday, Tuesday, Wednesday, Thursday, Friday} respectively;

  • "PlayGolf" has two categorical values: "Yes" and "No".

Our goal is to predict whether a person will play golf, based on all the features described above (hence, the last feature is our target). We build and train our Classification Tree as follows:

    import org.apache.spark.SparkContext
    import fr.eurecom.dsg.treelib.cart._

    val context = new SparkContext("local", "PlayGolf")
    // read training data (no header row)
    val playgolf_training_data = context.textFile("data/playgolf.csv", 1)
    // create a Classification Tree (CART)
    val tree = new ClassificationTree()
    // set the training data
    tree.setDataset(playgolf_training_data)
    // name the features, since the data has no header
    tree.setFeatureNames(Array[String]("Temperature", "Humidity", "Outlook", "Money", "Month", "DayOfWeek", "PlayGolf"))
    // build the tree: the target is the last feature (PlayGolf), the rest are predictors
    tree.buildTree()

Alternatively, to leave out the predictor "Month", we can call:

    tree.buildTree("PlayGolf", Set[Any]("Temperature", "Humidity", "Outlook", "Money", "DayOfWeek"))

If you don't specify the type of each predictor, the algorithm detects it automatically from its values. For instance, "Temperature", "Money", "Month" and "DayOfWeek" hold numbers and are treated as numerical features, while the others are treated as categorical features. To set the type of a feature explicitly, use the as object, similarly to R. For example:

    tree.buildTree("PlayGolf", Set[Any]("Temperature", "Humidity", "Outlook", as.Number("Money"), as.String("DayOfWeek")))
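Treelib's exact detection rule is not described on this page. One plausible heuristic, assumed here purely for illustration, is to treat a column as numerical when every value parses as a number. The hypothetical `TypeDetection.detect` below follows that rule, and it shows why "DayOfWeek" (values 0 to 6) would be detected as numerical unless you force it with as.String.

```scala
import scala.util.Try

// Hypothetical type-detection heuristic (an assumption, not Treelib's rule).
object TypeDetection {
  sealed trait FeatureType
  case object Numerical extends FeatureType
  case object Categorical extends FeatureType

  // A column is numerical if every one of its values parses as a Double
  def detect(values: Seq[String]): FeatureType =
    if (values.forall(v => Try(v.trim.toDouble).isSuccess)) Numerical
    else Categorical
}
```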

Some additional tuning parameters can also be set:

    tree.setDelimiter("\t")    // fields are separated by tabs (default: ",")
    tree.setMinSplit(1)        // only grow a node if it has more than 1 record
    tree.setThreshold(0.3)     // only grow a node if the coefficient of variation of Y is > 0.3, where CV = Deviation(Y)/E(Y)
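The two growth conditions controlled by setMinSplit and setThreshold can be sketched as follows, assuming that Deviation(Y) is the population standard deviation of the target values in the node and E(Y) is their mean. `StoppingRule` is a hypothetical helper, not Treelib code. Note that the CV is undefined (infinite or NaN) when the mean of Y is 0.

```scala
// Hypothetical sketch of the node-growth conditions; not Treelib code.
object StoppingRule {
  // CV = Deviation(Y) / E(Y), using the population standard deviation (assumption)
  def coefficientOfVariation(y: Seq[Double]): Double = {
    val mean = y.sum / y.size
    val variance = y.map(v => (v - mean) * (v - mean)).sum / y.size
    math.sqrt(variance) / mean
  }

  // Grow a node only if it has more than minSplit records and CV(Y) > threshold
  def shouldGrow(y: Seq[Double], minSplit: Int, threshold: Double): Boolean =
    y.size > minSplit && coefficientOfVariation(y) > threshold
}
```

For example, a node with Y = (1, 3) has mean 2 and deviation 1, so CV = 0.5: it grows with setMinSplit(1) and setThreshold(0.3), but not with setMinSplit(2).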

Prediction

Once you have grown a tree, you can use it to make predictions with:

    predictOneInstance(record: Array[String], ignoreBranchIDs: Set[BigInt] = Set[BigInt]())

where record is an array of predictor values, and ignoreBranchIDs (optional) is the set of IDs of nodes that should make their prediction as if they were leaves, instead of passing the record down to their children. For example:

    tree.predictOneInstance(Array[String]("cool","sunny","normal","false","30"))
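To illustrate what ignoreBranchIDs does, here is a toy binary tree in which an ignored internal node returns its own stored prediction instead of routing the record to a child. Everything in it is an assumption for illustration: the `ToyPrediction` names, numeric records (Array[Double] rather than Array[String]), and the node IDs used in the example, which follow a heap-style numbering (root 1, children 2*id and 2*id + 1).

```scala
// Toy model of prediction with ignored branches; not Treelib's data structures.
object ToyPrediction {
  sealed trait Node { def id: BigInt; def value: Double }
  // value: the prediction stored at the node (e.g. the mean of Y for regression)
  case class Leaf(id: BigInt, value: Double) extends Node
  // Internal node testing "record(feature) < threshold"
  case class Split(id: BigInt, value: Double, feature: Int, threshold: Double,
                   left: Node, right: Node) extends Node

  def predict(node: Node, record: Array[Double],
              ignoreBranchIDs: Set[BigInt] = Set[BigInt]()): Double = node match {
    case Leaf(_, v) => v
    // An ignored internal node answers as if it were a leaf
    case s: Split if ignoreBranchIDs.contains(s.id) => s.value
    case s: Split =>
      val child = if (record(s.feature) < s.threshold) s.left else s.right
      predict(child, record, ignoreBranchIDs)
  }
}
```

With a root (ID 1) splitting on feature 0 and a right child (ID 3) splitting on feature 1, passing Set(BigInt(3)) makes every record routed right receive node 3's stored value, which is how a subtree can be pruned away at prediction time without rebuilding the tree.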

Evaluation

Finally, you can evaluate the quality of the tree model using a test dataset:

    // testingData: RDD of test records, in the same format as the training data
    val predictRDDOfTheFullTree = tree.predict(testingData)
    val actualValueRDD = testingData.map(line => line.split(',').last) // the last feature is the target
    Evaluation.evaluate(predictRDDOfTheFullTree, actualValueRDD)

The metrics used for evaluation are the mean, the standard deviation and the standard error (SE) of the difference between the predicted values and the ground-truth values. For example, the evaluation of a regression tree may print:

    Mean of different:-0.000000
    Deviation of different:2.596211
    SE of different:0.036566
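The three statistics can be sketched on plain Scala collections (instead of RDDs) with their textbook definitions: the mean and the population standard deviation of the differences, and the standard error computed as deviation / sqrt(n). `DiffStats.evaluate` is a hypothetical helper; Treelib's own Evaluation.evaluate may use different conventions (for instance sample rather than population deviation), so its numbers need not match this sketch exactly.

```scala
// Hypothetical sketch of difference statistics; not Treelib's Evaluation object.
object DiffStats {
  // Returns (mean, standard deviation, standard error) of predicted - actual
  def evaluate(predicted: Seq[Double], actual: Seq[Double]): (Double, Double, Double) = {
    require(predicted.size == actual.size && predicted.nonEmpty)
    val diffs = predicted.zip(actual).map { case (p, a) => p - a }
    val n = diffs.size
    val mean = diffs.sum / n
    val sd = math.sqrt(diffs.map(d => (d - mean) * (d - mean)).sum / n)
    (mean, sd, sd / math.sqrt(n))
  }
}
```

For predictions (3, 5, 7, 9) against ground truth (2, 6, 6, 10), the differences are (1, -1, 1, -1), giving mean 0, deviation 1 and SE 0.5.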

Writing a tree model to a file and reading it back

Treelib allows you to save and load tree models as well:

    import fr.eurecom.dsg.treelib.cart._

    // tree was created with 'val tree = new RegressionTree()'

    // write the model to a file
    tree.writeModelToFile("my_tree.model")

    // create a new regression tree
    val new_tree = new RegressionTree()

    // load the tree model from the file
    new_tree.loadModelFromFile("my_tree.model")

Viewing a model in the console

Let's proceed with a new example: we want to build a Regression Tree on the 'bodyfat' dataset.

    import org.apache.spark.SparkContext
    import fr.eurecom.dsg.treelib.cart._

    val context = new SparkContext("local", "SparkContext")

    val trainingData = context.textFile("data/bodyfat.csv", 1)

    val tree = new RegressionTree()
    tree.setDataset(trainingData)
    // the first column is left unnamed and is not used as a predictor below
    tree.setFeatureNames(Array[String]("","age","DEXfat","waistcirc","hipcirc","elbowbreadth","kneebreadth","anthro3a","anthro3b","anthro3c","anthro4"))

    tree.setMinSplit(10)

    val stime = System.nanoTime()

    // build the tree (target: DEXfat) and print it
    println(tree.buildTree("DEXfat", Set("age", "waistcirc","hipcirc","elbowbreadth","kneebreadth")))

    println("Build tree in %f second(s)".format((System.nanoTime() - stime)/1e9))

The tree model is printed to stdout as follows:

    waistcirc( < 88.400000)
    |-(yes)-hipcirc( < 96.250000)
    |    |-(yes)--age( < 59.500000)
    |    |    |-(yes)---waistcirc( < 67.300000)
    |    |    |    |-(yes)----11.21
    |    |    |    |-(no)----17.015555555555558
    |    |    |-(no)---22.328333333333333
    |    |-(no)--waistcirc( < 80.750000)
    |    |    |-(yes)---kneebreadth( < 8.550000)
    |    |    |    |-(yes)----20.30333333333333
    |    |    |    |-(no)----25.279000000000003
    |    |    |-(no)---29.372000000000003
    |-(no)-kneebreadth( < 11.300000)
    |    |-(yes)--hipcirc( < 109.900000)
    |    |    |-(yes)---35.27846153846154
    |    |    |-(no)---42.95437500000001
    |    |-(no)--61.370000000000005
    Build tree in 2.120787 second(s)
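Each internal node above tests one numerical predictor against a threshold, such as waistcirc < 88.4. A common way to choose such a split in CART regression trees, assumed here for illustration, is to try the midpoints between consecutive distinct values of the feature and keep the one that minimizes the total sum of squared errors (SSE) of the two children. The `BestSplit` object is a hypothetical single-machine sketch, not Treelib's distributed implementation.

```scala
// Hypothetical CART-style regression split search on one numerical feature.
object BestSplit {
  // Sum of squared deviations from the mean (assumes ys is non-empty)
  private def sse(ys: Seq[Double]): Double = {
    val mean = ys.sum / ys.size
    ys.map(y => (y - mean) * (y - mean)).sum
  }

  // Returns Some((threshold, totalSSE)) for the best split "x < threshold",
  // or None if the feature has fewer than two distinct values
  def find(xs: Seq[Double], ys: Seq[Double]): Option[(Double, Double)] = {
    val pairs = xs.zip(ys)
    val distinct = xs.distinct.sorted
    if (distinct.size < 2) None
    else {
      // Candidate thresholds: midpoints between consecutive distinct values,
      // so both children are always non-empty
      val candidates = distinct.zip(distinct.tail).map { case (a, b) => (a + b) / 2 }
      Some(candidates.map { t =>
        val (left, right) = pairs.partition(_._1 < t)
        (t, sse(left.map(_._2)) + sse(right.map(_._2)))
      }.minBy(_._2))
    }
  }
}
```

Applied recursively to each child, with the stopping conditions set by setMinSplit and setThreshold, this kind of search produces a binary tree like the one printed above.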