
Building and Visualizing a Decision Tree with Python

2020-11-06 01:28:05 Artificial intelligence meets pioneer

Author | Nikhil Adithyan   Compiled by | VK   Source | Towards Data Science

Decision tree

Decision trees are among the most powerful supervised learning methods in use today. A decision tree is essentially a flowchart shaped like a binary tree, in which each node splits a set of observations according to a feature variable.

The goal of a decision tree is to divide the data into groups such that every element in a group belongs to the same category. Decision trees can also be used to approximate continuous target variables. In that case, the tree is split so that the mean squared error within each group is minimized.
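To make the mean-squared-error idea concrete, here is a minimal plain-Python sketch. The helper names and the target values are hypothetical and only illustrate the principle; this is not how scikit-learn implements it internally. A split that separates low targets from high targets should leave much less error than a split that cuts off a single value.

```python
def mse(values):
    """Mean squared error of the values around their own mean."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def weighted_split_mse(left, right):
    """Size-weighted average MSE of the two groups produced by a split."""
    n = len(left) + len(right)
    return (len(left) * mse(left) + len(right) * mse(right)) / n


# Hypothetical continuous targets, already ordered by some feature.
targets = [1.0, 1.2, 0.8, 5.0, 5.4, 4.6]

parent_error = mse(targets)                                  # no split at all
good_split = weighted_split_mse(targets[:3], targets[3:])    # low values vs. high values
poor_split = weighted_split_mse(targets[:1], targets[1:])    # cuts off one value only

print(f"no split  : {parent_error:.3f}")
print(f"good split: {good_split:.3f}")
print(f"poor split: {poor_split:.3f}")
```

A regression tree would prefer the split with the smallest weighted error, here the one that separates the two clusters of values.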

An important property of decision trees is that they are easy to interpret. You do not need to be familiar with machine learning techniques to understand what a decision tree is doing.

Pros and cons

The advantages of decision trees are:

  • Decision trees can generate understandable rules.

  • Decision trees perform classification without much computation.

  • Decision trees can handle both continuous and categorical variables.

  • Decision trees give a clear indication of which fields are most important.
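Two of these points, readable rules and field importance, can be inspected directly in scikit-learn. The sketch below uses a made-up toy dataset (the feature names "a" and "b" are hypothetical) in which the label simply copies feature "a", so the expected outcome is easy to reason about.

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical toy data: the label copies feature "a"; feature "b" is noise.
X = [[0, 0], [0, 1], [1, 0], [1, 1], [0, 0], [1, 1]]
y = [0, 0, 1, 1, 0, 1]

toy_tree = DecisionTreeClassifier(random_state=0).fit(X, y)

# Human-readable if/else rules learned by the tree.
rules = export_text(toy_tree, feature_names=["a", "b"])
print(rules)

# One importance score per feature; the scores sum to 1.
importances = dict(zip(["a", "b"], toy_tree.feature_importances_))
print(importances)
```

Because a single split on "a" separates the classes perfectly, all of the importance goes to "a" and none to "b".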

The disadvantages of decision trees are:

  • Decision trees are less suitable for estimation tasks where the goal is to predict the value of a continuous attribute.

  • Decision trees are prone to errors in classification problems with many classes and relatively few training samples.

  • Training a decision tree can be computationally expensive. At each node, every candidate splitting field must be sorted before its best split can be found. Some algorithms use combinations of fields, in which case the best combination weights must be searched for as well. Pruning algorithms can also be expensive, because many candidate subtrees have to be formed and compared.
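The cost of finding a split can be seen in a small sketch. The helper names below are hypothetical, and Gini impurity is used as the score purely for illustration; the feature and label values are the Na_to_K and Drug columns of the five sample rows printed in Step 2. The function sorts one field and then tries every midpoint between neighbouring values, and a real tree repeats this for every field at every node.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_threshold(values, labels):
    """Try every midpoint between sorted values; return (threshold, weighted impurity)."""
    pairs = sorted(zip(values, labels))           # sorting is needed for every field
    best = (None, gini(labels))                   # start from the unsplit impurity
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue                              # no threshold between equal values
        threshold = (pairs[i - 1][0] + pairs[i][0]) / 2
        # Recomputed from scratch for clarity; real implementations update counts incrementally.
        left = [label for _, label in pairs[:i]]
        right = [label for _, label in pairs[i:]]
        weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(pairs)
        if weighted < best[1]:
            best = (threshold, weighted)
    return best


na_to_k = [25.355, 13.093, 10.114, 7.798, 18.043]
drug = ["drugY", "drugC", "drugC", "drugX", "drugY"]

threshold, impurity = best_threshold(na_to_k, drug)
print(threshold, impurity)
```

On these five rows the best cut falls between 13.093 and 18.043, which isolates the two drugY rows.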

Decision trees in Python

Python is a general-purpose programming language that offers data scientists powerful machine learning packages and tools. In this article, we will build a decision tree model with scikit-learn, Python's best-known machine learning package. We will create the model with the "DecisionTreeClassifier" algorithm provided by scikit-learn and then visualize it with the "plot_tree" function.

Step 1: Import packages

The main packages we use to build the model are pandas, scikit-learn, and NumPy. Follow the code below to import the required packages in Python.

import pandas as pd #  Data processing 
import numpy as np #  Using arrays 
import matplotlib.pyplot as plt #  visualization 
from matplotlib import rcParams #  Picture size 
from termcolor import colored as cl #  Text customization 

from sklearn.tree import DecisionTreeClassifier as dtc #  Tree algorithm 
from sklearn.model_selection import train_test_split #  Split data 
from sklearn.metrics import accuracy_score #  Model accuracy 
from sklearn.tree import plot_tree #  Tree diagram 

rcParams['figure.figsize'] = (25, 20)

After importing all the packages needed to build our model, it is time to import the data and do some EDA (exploratory data analysis) on it.

Step 2: Import data and EDA

In this step, we will use the pandas package to import the data and do some EDA on it. The dataset we will build our decision tree model on is a drug dataset, in which a drug is prescribed to each patient based on certain criteria. Let's import the data in Python!

Python implementation:
df = pd.read_csv('drug.csv')
df.drop('Unnamed: 0', axis = 1, inplace = True)

print(cl(df.head(), attrs = ['bold']))

Output:

   Age Sex      BP Cholesterol  Na_to_K   Drug
0   23   F    HIGH        HIGH   25.355  drugY
1   47   M     LOW        HIGH   13.093  drugC
2   47   M     LOW        HIGH   10.114  drugC
3   28   F  NORMAL        HIGH    7.798  drugX
4   61   F     LOW        HIGH   18.043  drugY

Now we have a clear idea of what the dataset looks like. After importing the data, let's use the "info" function to get some basic information about it. The information this function provides includes the number of entries, the column index, the column names, the non-null counts, and the dtype of each column.

Python implementation:
df.info()

Output:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 200 entries, 0 to 199
Data columns (total 6 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   Age          200 non-null    int64  
 1   Sex          200 non-null    object 
 2   BP           200 non-null    object 
 3   Cholesterol  200 non-null    object 
 4   Na_to_K      200 non-null    float64
 5   Drug         200 non-null    object 
dtypes: float64(1), int64(1), object(4)
memory usage: 9.5+ KB
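Before encoding anything, it can also help to list which values each text column actually contains. The sketch below rebuilds only the five rows printed by df.head() above (the full drug.csv is not reproduced here, so the counts are for this sample only) and counts the distinct values per column.

```python
import pandas as pd

# The five sample rows printed by df.head() above.
sample = pd.DataFrame({
    "Age": [23, 47, 47, 28, 61],
    "Sex": ["F", "M", "M", "F", "F"],
    "BP": ["HIGH", "LOW", "LOW", "NORMAL", "LOW"],
    "Cholesterol": ["HIGH", "HIGH", "HIGH", "HIGH", "HIGH"],
    "Na_to_K": [25.355, 13.093, 10.114, 7.798, 18.043],
    "Drug": ["drugY", "drugC", "drugC", "drugX", "drugY"],
})

# Count the distinct values of every text column before deciding how to encode it.
counts = {}
for column in ["Sex", "BP", "Cholesterol", "Drug"]:
    counts[column] = sample[column].value_counts().to_dict()
    print(column, counts[column])
```

Running the same loop on the full DataFrame shows every category that an encoding has to cover.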

Step 3: Data processing

We can see that attributes such as Sex, BP, and Cholesterol are categorical and stored with the "object" dtype. The problem is that the decision tree algorithm in scikit-learn does not accept X variables (features) of "object" type. We therefore need to convert these "object" values into numeric codes. Let's do that in Python.

Python implementation:
# Encode each text column in one pass, without looping over its values
df['Sex'] = np.where(df['Sex'] == 'M', 0, 1) #  M -> 0, any other value -> 1

df['BP'] = df['BP'].map({'LOW': 0, 'NORMAL': 1, 'HIGH': 2}) #  Ordered codes

df['Cholesterol'] = np.where(df['Cholesterol'] == 'LOW', 0, 1) #  LOW -> 0, any other value -> 1

print(cl(df, attrs = ['bold']))

Output:

     Age  Sex  BP  Cholesterol  Na_to_K   Drug
0     23    1   2            1   25.355  drugY
1     47    1   0            1   13.093  drugC
2     47    1   0            1   10.114  drugC
3     28    1   1            1    7.798  drugX
4     61    1   0            1   18.043  drugY
..   ...  ...  ..          ...      ...    ...
195   56    1   0            1   11.567  drugC
196   16    1   0            1   12.006  drugC
197   52    1   1            1    9.894  drugX
198   23    1   1            1   14.020  drugX
199   40    1   0            1   11.349  drugX

[200 rows x 6 columns]

All "object" values have now been converted to numeric codes that represent the categories. In the Cholesterol column, for example, "LOW" is encoded as 0 and every other value, such as "HIGH", as 1. Note that the listing above shows Sex as 1 on every row, although rows 1 and 2 are "M" in the raw data and should be 0. The listing was most likely produced by calling replace() inside a loop over the column itself, where codes assigned earlier can be overwritten by later iterations; the single-pass encoding in the code above avoids that. We are now ready to create the dependent and independent variables from the data.
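One way to guard against silent encoding mistakes is to map each column through an explicit code table and fail loudly when a value is missing from it. The sketch below uses the Sex and BP values of the five sample rows shown in Step 2 and the codes used in this step; it assumes that "M" and "F" are the only Sex values, which the sample cannot confirm for the full file.

```python
import pandas as pd

# Sex and BP values from the five sample rows shown in Step 2.
sample = pd.DataFrame({
    "Sex": ["F", "M", "M", "F", "F"],
    "BP": ["HIGH", "LOW", "LOW", "NORMAL", "LOW"],
})

# One explicit code table per column (assumes no other categories occur).
codes = {
    "Sex": {"M": 0, "F": 1},
    "BP": {"LOW": 0, "NORMAL": 1, "HIGH": 2},
}

encoded = sample.copy()
for column, table in codes.items():
    mapped = sample[column].map(table)
    # map() yields NaN for any value missing from the table, so check for gaps.
    if mapped.isna().any():
        raise ValueError(f"unmapped values in column {column!r}")
    encoded[column] = mapped.astype(int)

print(encoded)
```

Each value is looked up exactly once, so a code that has already been assigned can never be replaced again.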

Step 4: Split data

With the data in the right structure, we now define the "X" variable (the independent variables) and the "Y" variable (the dependent variable). Let's do that in Python.

Python implementation:
X_var = df[['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K']].values #  The independent variables 
y_var = df['Drug'].values #  The dependent variable 

print(cl('X variable samples : {}'.format(X_var[:5]), attrs = ['bold']))
print(cl('Y variable samples : {}'.format(y_var[:5]), attrs = ['bold']))

Output:

X variable samples : [[ 1.     2.    23.     1.    25.355]
 [ 1.     0.    47.     1.    13.093]
 [ 1.     0.    47.     1.    10.114]
 [ 1.     1.    28.     1.     7.798]
 [ 1.     0.    61.     1.    18.043]]
Y variable samples : ['drugY' 'drugC' 'drugC' 'drugX' 'drugY']

We can now use the "train_test_split" function in scikit-learn to divide the X and Y variables we defined into a training set and a test set. Follow the code below to split the data in Python.

Python implementation:
X_train, X_test, y_train, y_test = train_test_split(X_var, y_var, test_size = 0.2, random_state = 0)

print(cl('X_train shape : {}'.format(X_train.shape), attrs = ['bold'], color = 'black'))
print(cl('X_test shape : {}'.format(X_test.shape), attrs = ['bold'], color = 'black'))
print(cl('y_train shape : {}'.format(y_train.shape), attrs = ['bold'], color = 'black'))
print(cl('y_test shape : {}'.format(y_test.shape), attrs = ['bold'], color = 'black'))

Output:

X_train shape : (160, 5)
X_test shape : (40, 5)
y_train shape : (160,)
y_test shape : (40,)
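When the classes are imbalanced, a purely random split can leave a rare class under-represented in the test set. train_test_split accepts a stratify argument that keeps the class proportions the same in both parts. The sketch below uses made-up labels (150 of class "a" and 50 of class "b"), not the drug data.

```python
from collections import Counter

from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 150 samples of class "a", 50 of class "b".
labels = ["a"] * 150 + ["b"] * 50
features = [[i] for i in range(len(labels))]

X_tr, X_te, y_tr, y_te = train_test_split(
    features, labels,
    test_size=0.2,      # 20% of 200 rows -> 40 test rows
    random_state=0,     # fixed seed so the split is reproducible
    stratify=labels,    # keep the 3:1 class ratio in both parts
)

train_counts = Counter(y_tr)
test_counts = Counter(y_te)
print(train_counts, test_counts)
```

Whether stratification helps on the drug dataset depends on how unevenly its five drug classes are distributed, which is worth checking before relying on a single accuracy figure.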

Now we have all the components for building a decision tree model. So let's go ahead and build our model in Python.

Step 5: Build the model and make predictions

Building a decision tree is straightforward with the "DecisionTreeClassifier" algorithm provided by the scikit-learn package. Afterwards, we can use the trained model to make predictions on our data. Finally, the quality of those predictions can be measured with the "accuracy" evaluation metric. Let's carry out the process in Python!

Python implementation:
model = dtc(criterion = 'entropy', max_depth = 4)
model.fit(X_train, y_train)

pred_model = model.predict(X_test)

print(cl('Accuracy of the model is {:.0%}'.format(accuracy_score(y_test, pred_model)), attrs = ['bold']))

Output:

Accuracy of the model is 88%
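Accuracy is simply the number of correct predictions divided by the number of test rows. The sketch below uses a hypothetical accuracy helper and made-up label lists (not the model's real predictions) to show what the printed figure means for a test set of 40 rows.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


# Hypothetical labels: 40 test rows, 35 of them predicted correctly.
y_true = ["drugX"] * 40
y_pred = ["drugX"] * 35 + ["drugY"] * 5

score = accuracy(y_true, y_pred)
print("Accuracy of the model is {:.0%}".format(score))

# Which counts of correct predictions out of 40 are printed as "88%"?
matches = [k for k in range(41) if "{:.0%}".format(k / 40) == "88%"]
print(matches)
```

With 40 test rows, only 35 correct predictions (87.5%) are displayed as 88% by the {:.0%} format, so the reported figure corresponds to 5 misclassified rows.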

In the first lines of the code, we define a variable called "model", in which we store the DecisionTreeClassifier model. Next, we fit and train the model on our training set. After that, we define a variable called "pred_model", in which we store the values the model predicts for the test data. Finally, we compare the predicted values with the actual values and compute the accuracy, which is 88% in the run shown above.
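The criterion = 'entropy' setting tells the tree to score candidate splits by how much they reduce entropy, a quantity usually called information gain. The plain-Python sketch below illustrates the calculation with hypothetical helper names; it is not scikit-learn's implementation. The labels are the Drug values of the five sample rows from Step 2, ordered by Na_to_K.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    n = len(labels)
    return sum(-(c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(parent, left, right):
    """Entropy of the parent minus the size-weighted entropy of the children."""
    n = len(parent)
    children = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - children


# Drug labels of the five sample rows, ordered by increasing Na_to_K.
parent = ["drugX", "drugC", "drugC", "drugY", "drugY"]

# Candidate split: the three lowest Na_to_K rows versus the two highest.
gain = information_gain(parent, parent[:3], parent[3:])
print(f"parent entropy  : {entropy(parent):.3f} bits")
print(f"information gain: {gain:.3f} bits")
```

The right-hand group contains only drugY, so its entropy is zero and most of the parent's uncertainty is removed by this one split.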

Step 6: Visualize the model

Now that we have a decision tree model, let's visualize it with the "plot_tree" function provided by the scikit-learn package. Follow the code below to produce a tree diagram of the decision tree model in Python.

Python implementation:
feature_names = ['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K'] #  Same column order as X_var
target_names = model.classes_.tolist() #  Class order used by the fitted model

plot_tree(model, 
          feature_names = feature_names, 
          class_names = target_names, 
          filled = True, 
          rounded = True)

plt.savefig('tree_visualization.png') 

Output: [figure: the fitted decision tree, saved as tree_visualization.png]
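One detail matters for the labels in the diagram: plot_tree matches class_names to classes by position, in the sorted order that the fitted model stores in classes_. That order can differ from the order in which the labels first appear in the data. The sketch below fits a small tree on the Na_to_K and Drug values of the five sample rows from Step 2 to show the difference.

```python
from sklearn.tree import DecisionTreeClassifier

# Na_to_K and Drug values from the five sample rows shown in Step 2.
X = [[25.355], [13.093], [10.114], [7.798], [18.043]]
y = ["drugY", "drugC", "drugC", "drugX", "drugY"]

small_tree = DecisionTreeClassifier(random_state=0).fit(X, y)

# Order in which the labels first appear in the data.
appearance_order = list(dict.fromkeys(y))
# Sorted order stored by the fitted model.
model_order = small_tree.classes_.tolist()

print(appearance_order)  # ['drugY', 'drugC', 'drugX']
print(model_order)       # ['drugC', 'drugX', 'drugY']
```

Passing the names in the model's order keeps each leaf in the diagram labelled with the class it actually predicts. The same care applies to feature_names, which must follow the column order of the array the model was trained on.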

Conclusion

There are many techniques and additional algorithms for optimizing decision trees and avoiding overfitting, such as pruning. Decision trees are generally unstable, meaning that a small change in the data can lead to a large change in the structure of the optimal tree, but their simplicity still makes them strong candidates for a wide range of applications. Before neural networks became popular, decision trees were the state-of-the-art algorithm in machine learning. Ensemble models, such as the random forest, are generally more powerful than a single decision tree.
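As one illustration of pruning, scikit-learn supports cost-complexity pruning through the ccp_alpha parameter. The sketch below runs on synthetic data from make_classification (a stand-in, not the drug dataset) and shows that larger values of ccp_alpha leave fewer nodes in the tree. In practice the value would be chosen by validation rather than picked by hand.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Hypothetical synthetic data standing in for a real training set.
X, y = make_classification(n_samples=200, n_features=5, n_informative=3,
                           n_redundant=0, random_state=0)

# Candidate pruning strengths computed from the fully grown tree.
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X, y)

# Refit with a handful of alphas: larger alpha -> more pruning -> fewer nodes.
step = max(1, len(path.ccp_alphas) // 5)
node_counts = []
for alpha in path.ccp_alphas[::step]:
    alpha = max(float(alpha), 0.0)   # guard against tiny negative rounding errors
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X, y)
    node_counts.append(pruned.tree_.node_count)
    print(f"ccp_alpha={alpha:.4f}  nodes={pruned.tree_.node_count}")
```

A smaller tree is less likely to memorise noise in the training set, at the price of a coarser fit.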

Decision trees are very powerful because of their simplicity and interpretability. Decision tree and random forest models are widely used in areas such as user signup modeling, credit scoring, failure prediction, and medical diagnosis. The complete code for this article is provided below.

Complete code:

import pandas as pd #  Data processing 
import numpy as np #  Using arrays 
import matplotlib.pyplot as plt #  visualization 
from matplotlib import rcParams #  Picture size 
from termcolor import colored as cl #  Text customization 

from sklearn.tree import DecisionTreeClassifier as dtc #  Tree algorithm 
from sklearn.model_selection import train_test_split #  Split data 
from sklearn.metrics import accuracy_score #  Model accuracy 
from sklearn.tree import plot_tree #  Tree diagram 

rcParams['figure.figsize'] = (25, 20)

df = pd.read_csv('drug.csv')
df.drop('Unnamed: 0', axis = 1, inplace = True)

print(cl(df.head(), attrs = ['bold']))

df.info()

# Encode each text column in one pass, without looping over its values
df['Sex'] = np.where(df['Sex'] == 'M', 0, 1) #  M -> 0, any other value -> 1
df['BP'] = df['BP'].map({'LOW': 0, 'NORMAL': 1, 'HIGH': 2}) #  Ordered codes
df['Cholesterol'] = np.where(df['Cholesterol'] == 'LOW', 0, 1) #  LOW -> 0, any other value -> 1

print(cl(df, attrs = ['bold']))

X_var = df[['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K']].values #  The independent variables 
y_var = df['Drug'].values #  The dependent variable 

print(cl('X variable samples : {}'.format(X_var[:5]), attrs = ['bold']))
print(cl('Y variable samples : {}'.format(y_var[:5]), attrs = ['bold']))

X_train, X_test, y_train, y_test = train_test_split(X_var, y_var, test_size = 0.2, random_state = 0)

print(cl('X_train shape : {}'.format(X_train.shape), attrs = ['bold'], color = 'red'))
print(cl('X_test shape : {}'.format(X_test.shape), attrs = ['bold'], color = 'red'))
print(cl('y_train shape : {}'.format(y_train.shape), attrs = ['bold'], color = 'green'))
print(cl('y_test shape : {}'.format(y_test.shape), attrs = ['bold'], color = 'green'))

model = dtc(criterion = 'entropy', max_depth = 4)
model.fit(X_train, y_train)

pred_model = model.predict(X_test)

print(cl('Accuracy of the model is {:.0%}'.format(accuracy_score(y_test, pred_model)), attrs = ['bold']))

feature_names = ['Sex', 'BP', 'Age', 'Cholesterol', 'Na_to_K'] #  Same column order as X_var
target_names = model.classes_.tolist() #  Class order used by the fitted model

plot_tree(model, 
          feature_names = feature_names, 
          class_names = target_names, 
          filled = True, 
          rounded = True)

plt.savefig('tree_visualization.png')

Original article: https://towardsdatascience.com/building-and-visualizing-decision-tree-in-python-2cfaafd8e1bb

Copyright notice
This article was written by [Artificial intelligence meets pioneer]. Please include a link to the original when reposting. Thank you.