Decision Tree Algorithm for Multiclass problems using Python

Category: IT Technology · Published: 4 years ago


Understanding the mathematics of Gini Index, Entropy, Information Gain, Feature Importance, and the CCP threshold

Photo by Fatos Bytyqi on Unsplash

Introduction

Decision tree classifiers are supervised learning models that are useful when we care about interpretability. Think of it as breaking down the data by asking a series of questions, one decision at each level. The biggest challenge with decision trees is understanding the underlying algorithm by which a tree grows into branches and sub-branches. In this article, we take a broader look at how different impurity metrics are used to choose the decision variable at each node, how feature importance is determined, and, more importantly, how trees are pruned to prevent overfitting. To get a better understanding of how the decision tree works, you can refer to the link below.

Overview

Decision trees consist of nodes and branches. The nodes can be further classified into a root node (the starting node of the tree), decision nodes (sub-nodes that split based on conditions), and leaf nodes (nodes that don't branch out further). Since the decision tree follows an if-else structure, every node uses one and only one independent variable to split into two or more branches. The independent variable can be categorical or continuous. For categorical variables, the categories are used to decide the split of the node; for continuous variables, the algorithm evaluates multiple candidate threshold values that act as the decision-maker (Raschka, Julian and Hearty, 2016, pp.83, 88, 89).
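To make the threshold search for continuous variables concrete, here is a minimal pure-Python sketch of one common CART-style approach: sort the distinct values, try the midpoint between each adjacent pair as a candidate threshold, and keep the threshold with the lowest weighted Gini impurity (scikit-learn also places its thresholds at such midpoints). The "hours studied" feature and its labels are hypothetical and are not part of the Figure 1 dataset.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_threshold(values, labels):
    """Try midpoints between consecutive distinct values; return (threshold, weighted Gini)."""
    distinct = sorted(set(values))
    best = None
    for low, high in zip(distinct, distinct[1:]):
        t = (low + high) / 2
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        # Weight each child's impurity by its share of the samples
        weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
        if best is None or weighted < best[1]:
            best = (t, weighted)
    return best


# Hypothetical data: hours studied per week and whether the student plays cricket
hours = [1, 2, 3, 6, 7, 8]
plays = ["N", "N", "N", "Y", "Y", "Y"]
print(best_threshold(hours, plays))  # (4.5, 0.0)
```

The threshold 4.5 separates the two classes perfectly, so both child nodes are homogeneous and the weighted impurity is 0.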


Say we are trying to develop a decision tree for the dataset shown in Figure 1 below. The dataset consists of Student IDs, their gender information, their study method, and an attribute that identifies whether they play cricket. We need to separate the dependent (target) variable from the independent variables (predictors or attributes). The dependent variable is the “Plays Cricket” column, and the remaining attributes, excluding “Student ID”, form the set of independent variables.

Figure 1. A quick overview of the table containing student information and necessary attributes. Reference — Developed by the author using PowerPoint

Deciding the Split

As a first step, we need to identify which independent variable should be used to split the root node. Let’s use Gini impurity to decide how to branch students into cricketers and non-cricketers. We will calculate the Gini impurity using both “Gender” and “Study Method” and choose the one with the lowest impurity score. Note that the decision tree algorithm tries to make every node as homogeneous as possible. This matches the outcome we are trying to achieve: correctly predicting which students are cricketers and which are not. A node is homogeneous when all students in it belong to one of the categories, i.e. either they are all cricketers (Plays Cricket = Y) or all non-cricketers (Plays Cricket = N).

Figure 2. Weighted Average Impurity calculated using the Gini Index and “Gender” as the independent variable. Reference — Developed by the author using PowerPoint
Figure 3. Weighted Average Impurity calculated using the Gini Index and “Study Method” as the independent variable. Reference — Developed by the author using PowerPoint

Since impurity measures how mixed a node is (the lower the impurity, the more homogeneous the node), the algorithm will choose the independent variable with the lowest weighted average impurity score as the decision variable for the root node. In this case, “Gender” will be used to create the first split of the decision tree. Also note the figures highlighted in yellow; they are used to calculate the information gain of the split. Information gain is defined as the decrease in impurity from a parent node to its child nodes; in this scenario, we check whether the impurity of the root node is higher than the weighted impurity of the split. If the information gain is greater than 0, the algorithm goes ahead and splits the node using the decision variable.
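The comparison above can be sketched in a few lines of Python: compute the weighted average Gini impurity for each candidate variable, pick the lowest, and report the information gain of that split. This is a minimal illustration assuming each row is a dict; the six student rows (including the "Online"/"Offline" study methods) are hypothetical toy data, not the Figure 1 table, so the numbers differ from Figures 2 and 3.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def weighted_gini(rows, attribute, target):
    """Weighted average Gini impurity of the child nodes created by splitting on `attribute`."""
    groups = {}
    for row in rows:
        groups.setdefault(row[attribute], []).append(row[target])
    return sum(len(g) / len(rows) * gini(g) for g in groups.values())


def information_gain(rows, attribute, target):
    """Parent impurity minus the weighted impurity of the children."""
    parent = gini([row[target] for row in rows])
    return parent - weighted_gini(rows, attribute, target)


# Hypothetical toy rows (NOT the Figure 1 table)
students = [
    {"Gender": "M", "Study Method": "Online", "Plays Cricket": "Y"},
    {"Gender": "M", "Study Method": "Offline", "Plays Cricket": "Y"},
    {"Gender": "M", "Study Method": "Online", "Plays Cricket": "Y"},
    {"Gender": "F", "Study Method": "Offline", "Plays Cricket": "N"},
    {"Gender": "F", "Study Method": "Online", "Plays Cricket": "N"},
    {"Gender": "F", "Study Method": "Offline", "Plays Cricket": "Y"},
]

candidates = ["Gender", "Study Method"]
best = min(candidates, key=lambda a: weighted_gini(students, a, "Plays Cricket"))
print(best, round(information_gain(students, best, "Plays Cricket"), 3))  # Gender 0.222
```

In this toy data, splitting on "Study Method" leaves both children as mixed as the parent, so its information gain is 0 and the algorithm would not prefer it.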

How can we prune the above tree?

Left unconstrained, a decision tree will continue to form branches until every node becomes homogeneous. As a result, the tree fits the training data very well but fails to produce quality output on the test data. Hence the tree should be pruned to prevent overfitting. For the example above, we can fine-tune the decision tree using the factors outlined below.

  1. Criterion: scikit-learn's decision trees work with Gini impurity and entropy. Other algorithms use CHAID (Chi-square Automatic Interaction Detector), misclassification error, etc.
  2. Information gain, i.e. the minimum impurity decrease from a root/parent node to its child nodes
  3. Number of samples in each node (minimum or maximum)
  4. Depth of the tree (how many levels of splits the algorithm may form)
  5. Maximum number of leaves
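In scikit-learn, each factor in the list above corresponds to a DecisionTreeClassifier parameter. The sketch below is illustrative only: the parameter values are arbitrary, and the bundled load_wine data (the 3-class UCI wine recognition set, not the red wine quality file used later) stands in for real data.

```python
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)

clf = DecisionTreeClassifier(
    criterion="entropy",         # 1. impurity criterion ("gini" is the default)
    min_impurity_decrease=0.01,  # 2. a split must reduce weighted impurity by at least this much
    min_samples_split=10,        # 3. minimum samples a node needs before it may be split
    min_samples_leaf=5,          # 3. minimum samples every leaf must keep
    max_depth=3,                 # 4. maximum depth of the tree
    max_leaf_nodes=6,            # 5. maximum number of leaves
    random_state=0,
).fit(X, y)

print("depth:", clf.get_depth(), "leaves:", clf.get_n_leaves())
```

Whichever constraint binds first stops the growth, so the fitted tree respects every limit at once.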

A Real-life Problem using Python

Understanding the Gini Index

The problem statement aims at developing a classification model to predict the quality of red wine. Details about the problem statement can be found here. This is a classic example of a multiclass classification problem. We won’t walk through the code in detail; instead, we will interpret the output of DecisionTreeClassifier() from sklearn.tree.

Reference for the code snippets below: Das, A. (2020). Decision Tree Classifier and Cost Computation Pruning using Python. [online] Medium. Available at: https://towardsdatascience.com/decision-tree-classifier-and-cost-computation-pruning-using-python-b93a0985ea77

# Reading the data
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

wine_df = pd.read_csv('winequality-red.csv', sep=';')
# Note: Figure 4 assumes 'quality' has already been recoded into three classes
# (5, 6 and 7, shown as class 0, 1 and 2); that recoding step is not shown here.

# Splitting data into training and test sets for the independent attributes
X_train, X_test, y_train, y_test = train_test_split(wine_df.drop('quality', axis=1),
                                                    wine_df['quality'], test_size=.3,
                                                    random_state=22)
X_train.shape, X_test.shape

# Developing a model
clf_pruned = DecisionTreeClassifier(criterion="gini", random_state=100,
                                    max_depth=3, min_samples_leaf=5)
clf_pruned.fit(X_train, y_train)

# Visualizing the tree
from io import StringIO
from IPython.display import Image
from pydot import graph_from_dot_data

xvar = wine_df.drop('quality', axis=1)
feature_cols = xvar.columns
dot_data = StringIO()
export_graphviz(clf_pruned, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True, feature_names=list(feature_cols),
                class_names=['0', '1', '2'])
(graph, ) = graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
Figure 4. Decision tree using the Gini index, max_depth=3, and min_samples_leaf=5. Note that to handle class imbalance, we categorized the wines into quality 5, 6, and 7. In the diagram above, 5 corresponds to class=0, 6 to class=1, and 7 to class=2.

When we move to a multiclass problem, the Gini index is calculated the same way, just with one term per class. In the earlier example, P and Q represented the probability of a student playing cricket and not playing cricket, respectively. In the case of a multiclass decision tree, for the root node (alcohol <= 10.25) we perform the following calculation.

  • P represents the probability of a wine being quality 5, which is (514/1119) = 0.46
  • Q represents the probability of a wine being quality 6, which is (448/1119) = 0.40
  • R represents the probability of a wine being quality 7, which is (157/1119) = 0.14
  • Gini Index = P² + Q² + R² = (0.46)² + (0.40)² + (0.14)² = 0.39
  • Gini Impurity = 1 - 0.39 = 0.61, which matches the gini value shown at the root node in Figure 4
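The same arithmetic in Python, as a small sketch that computes the Gini impurity directly from the root-node class counts listed above. Using unrounded proportions gives about 0.609, which rounds to the 0.61 above.

```python
def gini_impurity(counts):
    """1 - (P^2 + Q^2 + R^2 + ...) computed from the class counts in a node."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


# Class counts at the root node (quality 5, 6, 7), taken from the list above
root_counts = [514, 448, 157]
print(round(gini_impurity(root_counts), 3))  # 0.609
```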

The above tree is pruned only so that we can visualize the model more easily. There is no statistical reason for choosing a depth of 3 or a minimum of 5 samples per leaf.

Instead of criterion = “gini”, we can use criterion = “entropy” to obtain the tree diagram in Figure 5. For our three classes, entropy is calculated as -P*log2(P) - Q*log2(Q) - R*log2(R).
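A matching sketch for entropy, again computed from the root-node class counts. It uses base-2 logarithms, which is what scikit-learn's "entropy" criterion uses; for three classes the value ranges from 0 (pure node) to log2(3), roughly 1.585 (equal thirds). The result is not checked against Figure 5 here.

```python
from math import log2


def entropy(counts):
    """-P*log2(P) - Q*log2(Q) - R*log2(R) - ... from class counts; empty classes contribute 0."""
    total = sum(counts)
    return -sum((c / total) * log2(c / total) for c in counts if c > 0)


root_counts = [514, 448, 157]  # quality 5, 6, 7 at the root node
print(round(entropy(root_counts), 3))  # 1.442
```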

Figure 5. Decision tree using entropy, max_depth=3, and min_samples_leaf=5. Note that to handle class imbalance, we categorized the wines into quality 5, 6, and 7. In the diagram above, 5 corresponds to class=0, 6 to class=1, and 7 to class=2.

Understanding Feature Importance

Feature importance refers to a class of techniques for assigning scores to the input features of a predictive model that indicate the relative importance of each feature when making a prediction (Das, 2020).

## Calculating feature importance
import pandas as pd
# Unnormalized weighted impurity decrease per feature (not used below)
feat_importance = clf_pruned.tree_.compute_feature_importances(normalize=False)
feat_imp_dict = dict(zip(feature_cols, clf_pruned.feature_importances_))
feat_imp = pd.DataFrame.from_dict(feat_imp_dict, orient='index')
feat_imp.rename(columns={0: 'FeatureImportance'}, inplace=True)
feat_imp.sort_values(by=['FeatureImportance'], ascending=False).head()
Figure 6. Feature importance of the independent variables.

Let us try and understand the following to get a better picture of how feature importance is decided.

  1. The first split is based on alcohol <= 10.25. Among all variables and thresholds, this split gives the lowest impurity, which is why alcohol has the highest feature importance in the table above (Figure 6).
  2. The next splits are based on sulphates <= 0.555 and sulphates <= 0.685, so sulphates comes second in the order.
  3. At the third level of the tree, there are three contenders: chlorides <= 0.08, total sulfur dioxide <= 88.5, and volatile acidity <= 0.87.

To decide which of these variables is more important, we need to compare the information gained by splitting each parent node into its respective child nodes.

Figure 7. Variables used to decide the split when depth = 3.
Figure 8. Information Gain using volatile acidity≤0.87 as the threshold.
Figure 9. Information Gain using total sulfur dioxide≤88.5 as the threshold.

So, we see that the information gain from splitting the node “volatile acidity <= 0.87” is higher than from splitting “total sulfur dioxide <= 88.5”; hence, in the importance table, volatile acidity is placed above total sulfur dioxide.
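scikit-learn's feature_importances_ is a mean decrease in impurity: for every node that splits on a feature, it adds that node's impurity decrease (parent impurity minus child impurities, each weighted by its share of the training samples), then normalizes the per-feature totals to sum to 1. A split with a larger weighted information gain therefore pushes its variable up the ranking. Below is a minimal pure-Python sketch, assuming the tree is stored as parallel arrays laid out like a fitted tree's tree_ attribute (leaves have children_left == -1 and feature == -2); the five-node tree is hypothetical.

```python
def impurity_importances(children_left, children_right, feature, impurity,
                         weighted_n_node_samples, n_features):
    """Mean-decrease-in-impurity importances from sklearn-style parallel node arrays."""
    raw = [0.0] * n_features
    root_weight = weighted_n_node_samples[0]
    for node, left in enumerate(children_left):
        if left == -1:  # leaf: no split, so no impurity decrease
            continue
        right = children_right[node]
        decrease = (weighted_n_node_samples[node] * impurity[node]
                    - weighted_n_node_samples[left] * impurity[left]
                    - weighted_n_node_samples[right] * impurity[right])
        raw[feature[node]] += decrease / root_weight
    total = sum(raw)
    return [v / total for v in raw] if total > 0 else raw


# Hypothetical 5-node tree: the root splits on feature 0, its left child on feature 1
children_left = [1, 3, -1, -1, -1]
children_right = [2, 4, -1, -1, -1]
feature = [0, 1, -2, -2, -2]            # -2 marks leaves
impurity = [0.5, 0.32, 0.32, 0.0, 0.0]  # Gini impurity of each node
n_samples = [10, 5, 5, 4, 1]

importances = impurity_importances(children_left, children_right, feature,
                                   impurity, n_samples, n_features=2)
print([round(v, 3) for v in importances])  # [0.529, 0.471]
```

With the model fitted earlier, passing clf_pruned.tree_.children_left, clf_pruned.tree_.children_right, clf_pruned.tree_.feature, clf_pruned.tree_.impurity and clf_pruned.tree_.weighted_n_node_samples (with n_features=X_train.shape[1]) should reproduce clf_pruned.feature_importances_ up to floating-point error.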

Understanding cost complexity

Cost complexity pruning is governed by the complexity parameter $\alpha$ (ccp_alpha in scikit-learn), which is used to define the cost complexity measure $R_\alpha(T)$ of a given tree $T$:

$$R_\alpha(T) = R(T) + \alpha |T|$$

where $|T|$ is the number of terminal nodes, $R(T)$ is traditionally the total misclassification rate of the terminal nodes (scikit-learn uses their total sample-weighted impurity instead), and $\alpha$ is the CCP parameter. To summarise, the subtree with the largest cost complexity that is smaller than ccp_alpha is retained. A good choice is the CCP parameter that produces the highest test accuracy (Scikit Learn, n.d.).
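scikit-learn's documentation describes minimal cost complexity pruning as repeatedly removing the "weakest link": the internal node t with the smallest effective alpha, alpha_eff(t) = (R(t) - R(T_t)) / (|T_t| - 1), where R(t) is the risk of t collapsed into a single leaf and T_t is the subtree rooted at t. A minimal sketch of both formulas; the risk values are hypothetical.

```python
def cost_complexity(leaf_risks, alpha):
    """R_alpha(T) = R(T) + alpha * |T|, where R(T) sums the risk of the terminal nodes."""
    return sum(leaf_risks) + alpha * len(leaf_risks)


def effective_alpha(node_risk, subtree_leaf_risks):
    """Alpha at which collapsing a subtree into one leaf costs the same as keeping it."""
    return (node_risk - sum(subtree_leaf_risks)) / (len(subtree_leaf_risks) - 1)


# Hypothetical numbers: an internal node with risk 0.20 if it were a leaf,
# whose subtree ends in three leaves with risks 0.05, 0.03 and 0.02.
node_risk = 0.20
leaves = [0.05, 0.03, 0.02]

print(f"effective alpha = {effective_alpha(node_risk, leaves):.3f}")  # effective alpha = 0.050
for alpha in (0.01, 0.10):
    keep = cost_complexity(leaves, alpha)
    prune = cost_complexity([node_risk], alpha)
    print(alpha, "keep subtree" if keep < prune else "prune to a leaf")
```

Below the effective alpha (0.01 here) the three-leaf subtree has the lower cost and is kept; above it (0.10) collapsing the node into a single leaf is cheaper.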

Figure 10. Variation in train and test accuracy for different levels of alpha/a/CCP value.
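A curve like the one in Figure 10 can be produced with scikit-learn's cost_complexity_pruning_path, which returns the effective alphas at which subtrees get pruned; refitting one tree per alpha gives the train and test accuracies. This is a sketch under stated assumptions: the bundled load_wine data stands in for the red wine quality file, and the split and random_state values are arbitrary. Following the article, alpha is picked by test accuracy; a separate validation set or cross-validation avoids tuning on the test data.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=22, stratify=y)

# Effective alphas of the fully grown tree; the last one prunes it down to the root
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas[:-1]

results = []  # (alpha, train accuracy, test accuracy)
for alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)
    results.append((alpha, clf.score(X_train, y_train), clf.score(X_test, y_test)))

best_alpha, best_train, best_test = max(results, key=lambda r: r[2])
print(f"best ccp_alpha={best_alpha:.4f} train={best_train:.3f} test={best_test:.3f}")
```

Plotting the second and third entries of each tuple against alpha gives the train and test curves; train accuracy typically falls as alpha grows, while test accuracy often peaks at an intermediate value.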

Reference

  1. Raschka, S., Julian, D. and Hearty, J. (2016). Python: deeper insights into machine learning: leverage benefits of machine learning techniques using Python: a course in three modules. Birmingham, UK: Packt Publishing, pp.83, 88, 89.
  2. Pedregosa et al. (2011). Scikit-learn: Machine Learning in Python. JMLR 12, pp.2825–2830.
  3. Scikit Learn (2019). sklearn.tree.DecisionTreeClassifier: scikit-learn 0.22.1 documentation. [online] Scikit-learn.org. Available at: https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html.
  4. Scikit Learn (n.d.). Post pruning decision trees with cost complexity pruning. [online] Available at: https://scikit-learn.org/stable/auto_examples/tree/plot_cost_complexity_pruning.html#sphx-glr-auto-examples-tree-plot-cost-complexity-pruning-py.
  5. Das, A. (2020). Decision Tree Classifier and Cost Computation Pruning using Python. [online] Medium. Available at: https://towardsdatascience.com/decision-tree-classifier-and-cost-computation-pruning-using-python-b93a0985ea77 [Accessed 18 Jul. 2020].
