決策樹是一種非常經典的分類器,它的作用原理有點類似於我們玩的猜謎遊戲。比如猜一個動物:
問:這個動物是陸生動物嗎?
答:是的。
問:這個動物有鰓嗎?
答:沒有。
這樣的兩個問題順序就有些顛倒,因為一般來說陸生動物是沒有鰓的(記得應該是這樣的,如有錯誤歡迎指正)。所以玩這種遊戲,提問的順序很重要,爭取每次都能夠獲得盡可能多的信息量。
以AllElectronics顧客數據庫標記類的訓練元組為例。我們想要以這些樣本為訓練集,訓練我們的決策樹模型,以此來挖掘出顧客是否會購買電腦的決策模式。
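在決策樹ID3算法中,數據集D的期望信息(熵)的計算公式如下,其中$p_i$為D中任一元組屬於第i類的概率:
$$Info(D) = -\sum_{i=1}^m p_i log_2 p_i$$
按特征A將D劃分為v個子集後,期望信息的計算公式如下:
$$Info_A(D) = \sum_{j=1}^v\frac{|D_j|}{|D|} \times Info(D_j)$$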
計算信息增益的公式如下:
$$Gain(A) = Info(D) - Info_A(D)$$
按照公式,在要進行分類的類別變量中,有5個“no”和9個“yes”,因此期望信息為:
$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$
首先計算特征age的期望信息:
$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})=0.694$$
因此,如果按照age進行劃分,則獲得的信息增益為:
$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$
依次計算以income、student和credit_rating來分裂的信息增益,由此選擇能夠帶來最大信息增益的變量,在當前結點以該變量的取值進行分裂。對每個分支遞歸地執行上述過程,即可生成決策樹。更加詳細的內容可以參考:
https://en.wikipedia.org/wiki/Decision_tree
C#代碼的實現如下:
1 using System;
2 using System.Collections.Generic;
3 using System.Linq;
namespace MachineLearning.DecisionTree
{
    /// <summary>
    /// ID3 decision-tree learner over a table of categorical values.
    /// Each row of the table is one training sample; the last column holds the
    /// class (category) label and every other column is a candidate attribute.
    /// </summary>
    /// <typeparam name="T">Type of the table cells; compared with <see cref="IEquatable{T}.Equals(T)"/>.</typeparam>
    public class DecisionTreeID3<T> where T : IEquatable<T>
    {
        T[,] Data;
        string[] Names;
        int Category;
        T[] CategoryLabels;
        DecisionTreeNode<T> Root;

        /// <summary>
        /// Creates a learner for the given training table.
        /// </summary>
        /// <param name="data">Training table, indexed as [row, column]. The class label must be in the last column.</param>
        /// <param name="names">Display name of each column, indexed by column number.</param>
        /// <param name="categoryLabels">All class labels that may occur in the last column.</param>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
        {
            if (data == null) throw new ArgumentNullException("data");
            if (names == null) throw new ArgumentNullException("names");
            if (categoryLabels == null) throw new ArgumentNullException("categoryLabels");

            Data = data;
            Names = names;
            Category = data.GetLength(1) - 1; // the class variable must be in the last column
            CategoryLabels = categoryLabels;
        }

        /// <summary>
        /// Builds the tree from all rows and all columns of the training table,
        /// then prints it to the console via <see cref="DisplayNode"/>.
        /// </summary>
        public void Learn()
        {
            int nRows = Data.GetLength(0);
            int nCols = Data.GetLength(1);
            int[] rows = new int[nRows];
            int[] cols = new int[nCols];
            for (int i = 0; i < nRows; i++) rows[i] = i;
            for (int i = 0; i < nCols; i++) cols[i] = i;

            // The root is a placeholder: label -1 marks "no column", so DisplayNode skips it.
            Root = new DecisionTreeNode<T>(-1, default(T));
            Learn(rows, cols, Root);
            DisplayNode(Root);
        }

        /// <summary>
        /// Prints <paramref name="Node"/> and, recursively, its children to the console.
        /// Nodes whose label is -1 are not printed themselves.
        /// </summary>
        /// <param name="Node">Node to print.</param>
        /// <param name="depth">Nesting depth; each level adds three '-' characters of prefix.</param>
        public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
        {
            if (Node.Label != -1)
            {
                Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
            }
            foreach (var item in Node.Children)
            {
                DisplayNode(item, depth + 1);
            }
        }

        /// <summary>
        /// Recursively grows the subtree below <paramref name="parent"/> using only the
        /// given rows and columns.
        /// </summary>
        /// <param name="pnRows">Indices of the rows that reached this node.</param>
        /// <param name="pnCols">Indices of the columns still available (always includes the class column).</param>
        /// <param name="parent">Node that receives the generated children.</param>
        private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> parent)
        {
            // Nothing to learn from an empty partition.
            if (pnRows.Length == 0) return;

            // Materialize once: the labels are inspected several times below.
            T[] categoryValues = GetAttribute(Data, Category, pnRows).ToArray();

            // Pure partition: every row has the same class, emit a leaf.
            if (categoryValues.Distinct().Count() == 1)
            {
                parent.Children.Add(new DecisionTreeNode<T>(Category, categoryValues[0]));
                return;
            }

            // Only the class column is left: decide by majority vote.
            // OrderByDescending is a stable sort, so ties go to the label seen first.
            if (pnCols.Length == 1)
            {
                var vote = categoryValues
                    .GroupBy(label => label)
                    .OrderByDescending(group => group.Count())
                    .First();
                parent.Children.Add(new DecisionTreeNode<T>(Category, vote.Key));
                return;
            }

            // Split on the attribute with the highest information gain.
            int maxCol = MaxEntropy(pnRows, pnCols);
            int[] cols = pnCols.Where(i => i != maxCol).ToArray();
            foreach (var attr in GetAttribute(Data, maxCol, pnRows).Distinct())
            {
                int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                var node = new DecisionTreeNode<T>(maxCol, attr);
                parent.Children.Add(node);
                Learn(rows, cols, node); // recursively build the subtree for this attribute value
            }
        }

        /// <summary>
        /// Expected information of the class variable after partitioning the given rows
        /// by column <paramref name="attrCol"/>: the sum over each distinct attribute value
        /// of (share of rows with that value) x (class entropy within those rows).
        /// </summary>
        /// <param name="attrCol">Column index of the attribute to partition by.</param>
        /// <param name="pnRows">Indices of the rows to consider.</param>
        /// <returns>The weighted entropy, in bits.</returns>
        /// <exception cref="InvalidOperationException">
        /// A row carries a class label that is not contained in the category labels passed to the constructor.
        /// </exception>
        public double AttributeInfo(int attrCol, int[] pnRows)
        {
            var tuples = AttributeCount(attrCol, pnRows);
            var sum = (double)pnRows.Length;
            double entropy = 0.0;
            foreach (var tuple in tuples)
            {
                // Class histogram of the rows whose attribute equals tuple.Item1.
                int[] count = new int[CategoryLabels.Length];
                foreach (var irow in pnRows)
                {
                    if (!Data[irow, attrCol].Equals(tuple.Item1)) continue;

                    int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
                    if (index < 0)
                    {
                        throw new InvalidOperationException(string.Format(
                            "Row {0} has class label '{1}' that is not listed in categoryLabels.",
                            irow, Data[irow, Category]));
                    }
                    count[index]++;
                }

                double k = 0.0;
                for (int i = 0; i < count.Length; i++)
                {
                    double frequency = count[i] / (double)tuple.Item2;
                    k += -frequency * Log2(frequency);
                }

                double freq = tuple.Item2 / sum;
                entropy += freq * k;
            }
            return entropy;
        }

        /// <summary>
        /// Entropy of the class variable over the given rows.
        /// </summary>
        /// <param name="pnRows">Indices of the rows to consider.</param>
        /// <returns>The entropy, in bits.</returns>
        public double CategoryInfo(int[] pnRows)
        {
            var tuples = AttributeCount(Category, pnRows);
            var sum = (double)pnRows.Length;
            double entropy = 0.0;
            foreach (var tuple in tuples)
            {
                double frequency = tuple.Item2 / sum;
                entropy += -frequency * Log2(frequency);
            }
            return entropy;
        }

        /// <summary>
        /// Lazily yields the values of column <paramref name="col"/> for the given rows, in row order.
        /// </summary>
        private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
        {
            foreach (var irow in pnRows)
            {
                yield return data[irow, col];
            }
        }

        /// <summary>
        /// Base-2 logarithm that returns 0 for an input of 0, so that a zero
        /// frequency contributes nothing to an entropy sum.
        /// </summary>
        private static double Log2(double x)
        {
            return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
        }

        /// <summary>
        /// Picks, among <paramref name="pnCols"/> (the class column is skipped), the column
        /// with the largest information gain over the given rows.
        /// </summary>
        /// <param name="pnRows">Indices of the rows to consider.</param>
        /// <param name="pnCols">Candidate column indices.</param>
        /// <returns>
        /// Index of the best column; the first one wins on ties. Returns 0 when
        /// <paramref name="pnCols"/> contains no attribute column.
        /// </returns>
        public int MaxEntropy(int[] pnRows, int[] pnCols)
        {
            double cateEntropy = CategoryInfo(pnRows);
            int maxAttr = 0;
            double max = double.MinValue;
            foreach (var icol in pnCols)
            {
                if (icol == Category) continue;

                double gain = cateEntropy - AttributeInfo(icol, pnRows);
                if (max < gain)
                {
                    max = gain;
                    maxAttr = icol;
                }
            }
            return maxAttr;
        }

        /// <summary>
        /// Counts how often each distinct value of column <paramref name="col"/> occurs in the given rows.
        /// </summary>
        /// <param name="col">Column index.</param>
        /// <param name="pnRows">Indices of the rows to consider.</param>
        /// <returns>A deferred sequence of (value, occurrence count) pairs.</returns>
        public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
        {
            return GetAttribute(Data, col, pnRows)
                .GroupBy(value => value)
                .Select(group => Tuple.Create(group.Key, group.Count()));
        }
    }
}
調用方法如下:
1 using System;
2 using System.Collections.Generic;
3 using System.Linq;
4 using System.Text;
5 using System.Threading.Tasks;
6 using MachineLearning.DecisionTree;
namespace MachineLearning
{
    /// <summary>
    /// Console demo: trains an ID3 tree on the 14-sample "buys_computer"
    /// table and waits for a key press after the tree has been printed.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            // Column headers; the last column is the class to predict.
            string[] columnNames =
            {
                "age", "income", "student", "credit_rating", "Class: buys_computer"
            };

            // Every value the class column can take.
            string[] classLabels = { "yes", "no" };

            // One row per customer: age, income, student, credit_rating, buys_computer.
            string[,] samples =
            {
                { "youth",       "high",   "no",  "fair",      "no"  },
                { "youth",       "high",   "no",  "excellent", "no"  },
                { "middle_aged", "high",   "no",  "fair",      "yes" },
                { "senior",      "medium", "no",  "fair",      "yes" },
                { "senior",      "low",    "yes", "fair",      "yes" },
                { "senior",      "low",    "yes", "excellent", "no"  },
                { "middle_aged", "low",    "yes", "excellent", "yes" },
                { "youth",       "medium", "no",  "fair",      "no"  },
                { "youth",       "low",    "yes", "fair",      "yes" },
                { "senior",      "medium", "yes", "fair",      "yes" },
                { "youth",       "medium", "yes", "excellent", "yes" },
                { "middle_aged", "medium", "no",  "excellent", "yes" },
                { "middle_aged", "high",   "yes", "fair",      "yes" },
                { "senior",      "medium", "no",  "excellent", "no"  }
            };

            var id3 = new DecisionTreeID3<string>(samples, columnNames, classLabels);
            id3.Learn();

            // Keep the console window open until a key is pressed.
            Console.ReadKey();
        }
    }
}
運行結果:

注:作者本人也在學習中,能力有限,如有錯漏還請不吝指正。轉載請注明作者。