ML之DT：利用决策树分类器（DecisionTreeClassifier，DTC）对iris（鸢尾花）数据集进行分类，并可视化决策树结构

ML之DT:利用DT(DTC)实现对iris(鸢尾花)数据集进行分类并可视化DT结构


输出结果

实现代码

# 1. Load the iris dataset and print its feature names, a sample of the
#    feature matrix, the class names and the label vector.
from sklearn.datasets import load_iris  # was missing: load_iris is used below

iris = load_iris()

iris_feature_name = iris.feature_names   # column names of the feature matrix
iris_features = iris.data                # feature matrix (rows = samples)
iris_target_name = iris.target_names     # human-readable class names
iris_target = iris.target                # class label for each sample

print('iris_feature_name','\n',iris_feature_name)
print('iris_features前5','\n',iris_features[:5,:],iris_features.shape)
print('iris_target_name','\n',iris_target_name)
print('iris_target','\n',iris_target)

# 2. Fit a decision-tree classifier whose depth is capped at 4 levels.
from sklearn import tree  # was missing: `tree` is used below

# scikit-learn randomly permutes features at every split, so ties can produce
# different trees between runs; fixing random_state makes the fitted tree
# (and the diagram rendered from it) reproducible.
clf = tree.DecisionTreeClassifier(max_depth=4, random_state=0)
clf = clf.fit(iris_features, iris_target)  # fit() returns the fitted estimator

# 3. Render the fitted tree with Graphviz (via pydotplus) and display it,
#    both inline (IPython/Jupyter) and through a matplotlib window.
#    NOTE(review): create_png() needs the Graphviz `dot` binary on PATH.
import matplotlib.pyplot as plt
import pydotplus
from IPython.display import Image, display
from sklearn import tree  # was missing: export_graphviz lives in sklearn.tree

dot_data = tree.export_graphviz(
    clf,
    out_file=None,                    # return the DOT source as a string
    feature_names=iris_feature_name,
    class_names=iris_target_name,
    filled=True,                      # paint nodes by majority class
    rounded=True,
)

graph = pydotplus.graph_from_dot_data(dot_data)
png_bytes = graph.create_png()        # render once and reuse the bytes below

# Inline display; only has a visible effect inside IPython/Jupyter.
display(Image(png_bytes))

# Persist the rendering (previously commented out, so DT.png never existed).
img_path = 'DT.png'
with open(img_path, 'wb') as png_file:
    png_file.write(png_bytes)

# plt.imshow expects pixel data, not a file path, so load the PNG first.
plt.imshow(plt.imread(img_path))
plt.axis('off')                       # a tree diagram has no meaningful axes
plt.show()
(0)

相关推荐