Python | Regresión del árbol de decisión usando sklearn

Decision Tree es una herramienta de toma de decisiones que utiliza una estructura de árbol similar a un diagrama de flujo o es un modelo de decisiones y todos sus posibles resultados, incluidos los resultados, los costos de entrada y la utilidad.
El algoritmo del árbol de decisiones cae dentro de la categoría de algoritmos de aprendizaje supervisado. Funciona tanto para variables de salida continuas como categóricas.

Machine-Learning-Course

Las ramas/aristas representan el resultado del Node y los Nodes tienen: 

  1. Condiciones [Nodes de decisión]
  2. Resultado [Nodes finales]

Las ramas/bordes representan la verdad/falsedad de la declaración y toma una decisión basada en el ejemplo a continuación que muestra un árbol de decisión que evalúa el menor de tres números:  


Regresión del 
árbol de decisión: la regresión del árbol de decisión observa las características de un objeto y entrena un modelo en la estructura de un árbol para predecir datos en el futuro para producir una salida continua significativa. Salida continua significa que la salida/resultado no es discreto, es decir, no está representado simplemente por un conjunto discreto y conocido de números o valores.

Ejemplo de salida discreta: un modelo de predicción meteorológica que predice si lloverá o no en un día determinado. 
Ejemplo de salida continua: un modelo de predicción de ganancias que establece la ganancia probable que se puede generar a partir de la venta de un producto.
Aquí, los valores continuos se predicen con la ayuda de un modelo de regresión de árbol de decisión.

Veamos la implementación paso a paso – 

  • Paso 1: Importe las bibliotecas requeridas. 

Python3

# import numpy package for arrays and stuff
import numpy as np 
  
# import matplotlib.pyplot for plotting our result
import matplotlib.pyplot as plt
  
# import pandas for importing csv files 
import pandas as pd 
  • Paso 2: inicialice e imprima el conjunto de datos.

Python3

# import dataset
# dataset = pd.read_csv('Data.csv') 
# alternatively open up .csv file to read data
  
dataset = np.array(
[['Asset Flip', 100, 1000],
['Text Based', 500, 3000],
['Visual Novel', 1500, 5000],
['2D Pixel Art', 3500, 8000],
['2D Vector Art', 5000, 6500],
['Strategy', 6000, 7000],
['First Person Shooter', 8000, 15000],
['Simulator', 9500, 20000],
['Racing', 12000, 21000],
['RPG', 14000, 25000],
['Sandbox', 15500, 27000],
['Open-World', 16500, 30000],
['MMOFPS', 25000, 52000],
['MMORPG', 30000, 80000]
])
  
# print the dataset
print(dataset) 

Producción:

[['Asset Flip' '100' '1000']
 ['Text Based' '500' '3000']
 ['Visual Novel' '1500' '5000']
 ['2D Pixel Art' '3500' '8000']
 ['2D Vector Art' '5000' '6500']
 ['Strategy' '6000' '7000']
 ['First Person Shooter' '8000' '15000']
 ['Simulator' '9500' '20000']
 ['Racing' '12000' '21000']
 ['RPG' '14000' '25000']
 ['Sandbox' '15500' '27000']
 ['Open-World' '16500' '30000']
 ['MMOFPS' '25000' '52000']
 ['MMORPG' '30000' '80000']]
  • Paso 3: seleccione todas las filas y la columna 1 del conjunto de datos a «X».

Python3

# select all rows by : and column 1
# by 1:2 representing features
X = dataset[:, 1:2].astype(int) 
  
# print X
print(X)

Producción:

[[  100]
 [  500]
 [ 1500]
 [ 3500]
 [ 5000]
 [ 6000]
 [ 8000]
 [ 9500]
 [12000]
 [14000]
 [15500]
 [16500]
 [25000]
 [30000]]
  • Paso 4: seleccione todas las filas y la columna 2 del conjunto de datos a «y».

Python3

# select all rows by : and column 2
# by 2 to Y representing labels
y = dataset[:, 2].astype(int) 
  
# print y
print(y)

Producción:

[ 1000  3000  5000  8000  6500  7000 15000 20000 21000 25000 27000 30000 52000 80000]
  • Paso 5: ajuste el regresor del árbol de decisión al conjunto de datos

Python3

# import the regressor
from sklearn.tree import DecisionTreeRegressor 
  
# create a regressor object
regressor = DecisionTreeRegressor(random_state = 0) 
  
# fit the regressor with X and Y data
regressor.fit(X, y)

Producción:

DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=None,
                      max_features=None, max_leaf_nodes=None,
                      min_impurity_decrease=0.0, min_impurity_split=None,
                      min_samples_leaf=1, min_samples_split=2,
                      min_weight_fraction_leaf=0.0, presort='deprecated',
                      random_state=0, splitter='best')
  • Paso 6: Predecir un nuevo valor

Python3

# predicting a new value
  
# test the output by changing values, like 3750
y_pred = regressor.predict([[3750]])
  
# print the predicted price
print("Predicted price: % d\n"% y_pred) 

Producción:

Predicted price:  8000
  • Paso 7: Visualización del resultado

Python3

# arange for creating a range of values 
# from min value of X to max value of X 
# with a difference of 0.01 between two
# consecutive values
X_grid = np.arange(min(X), max(X), 0.01)
  
# reshape for reshaping the data into 
# a len(X_grid)*1 array, i.e. to make
# a column out of the X_grid values
X_grid = X_grid.reshape((len(X_grid), 1)) 
  
# scatter plot for original data
plt.scatter(X, y, color = 'red')
  
# plot predicted data
plt.plot(X_grid, regressor.predict(X_grid), color = 'blue') 
  
# specify title
plt.title('Profit to Production Cost (Decision Tree Regression)') 
  
# specify X axis label
plt.xlabel('Production Cost')
  
# specify Y axis label
plt.ylabel('Profit')
  
# show the plot
plt.show()

  • Paso 8: el árbol finalmente se exporta y se muestra en la ESTRUCTURA DEL ÁRBOL a continuación, visualizada usando http://www.webgraphviz.com/ copiando los datos del archivo ‘tree.dot’.

Python3

# import export_graphviz
from sklearn.tree import export_graphviz 
  
# export the decision tree to a tree.dot file
# for visualizing the plot easily anywhere
export_graphviz(regressor, out_file ='tree.dot',
               feature_names =['Production Cost']) 

Salida (árbol de decisión): 


Publicación traducida automáticamente

Artículo escrito por AnkanDas22 y traducido por Barcelona Geeks. The original can be accessed here. Licence: CCBY-SA

Deja una respuesta

Tu dirección de correo electrónico no será publicada. Los campos obligatorios están marcados con *