
Data Mining: Decision Tree ID3 Algorithm (C# Implementation)


A decision tree is a very classic classifier. The way it works is a bit like the guessing games we play. For example, guessing an animal:

Q: Is this animal a land animal?

A: Yes.

Q: Does this animal have gills?

A: No.

The order of these two questions is somewhat backwards, because land animals generally do not have gills (as far as I recall; corrections welcome if I am wrong). So in this kind of game the order of the questions matters a great deal: each question should try to extract as much information as possible.

Class-labeled training tuples from the AllElectronics customer database:

RID  age          income  student  credit_rating  Class: buys_computer
1    youth        high    no       fair           no
2    youth        high    no       excellent      no
3    middle_aged  high    no       fair           yes
4    senior       medium  no       fair           yes
5    senior       low     yes      fair           yes
6    senior       low     yes      excellent      no
7    middle_aged  low     yes      excellent      yes
8    youth        medium  no       fair           no
9    youth        low     yes      fair           yes
10   senior       medium  yes      fair           yes
11   youth        medium  yes      excellent      yes
12   middle_aged  medium  no       excellent      yes
13   middle_aged  high    yes      fair           yes
14   senior       medium  no       excellent      no

Taking the class-labeled training tuples from the AllElectronics customer database above as an example, we want to use these samples as the training set for our decision tree model, in order to mine the decision pattern of whether a customer will buy a computer.

In the decision tree ID3 algorithm, the expected information (entropy) of the class attribute over a data set D, where $p_i$ is the proportion of tuples belonging to class $i$, is:

$$Info(D) = -\sum_{i=1}^{m} p_i log_2(p_i)$$

The expected information still required after partitioning D on an attribute A is:

$$Info_A(D) = \sum_{j=1}^{v}\frac{|D_j|}{|D|} \times Info(D_j)$$

The information gain is then computed as:

$$Gain(A) = Info(D) - Info_A(D)$$

According to these formulas, the class attribute we want to predict contains 5 "no" and 9 "yes" tuples, so the expected information is:

$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$
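As a quick sanity check (this snippet is only an illustration and is not part of the original post; the class name EntropyCheck is made up), the same value can be computed directly in C#:

using System;

class EntropyCheck
{
    static void Main()
    {
        // Info(D) for 9 "yes" and 5 "no" tuples out of 14
        double infoD = -(9.0 / 14) * Math.Log(9.0 / 14, 2)
                       - (5.0 / 14) * Math.Log(5.0 / 14, 2);
        Console.WriteLine(infoD); // prints roughly 0.940
    }
}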

First, compute the expected information for the attribute age:

$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})=0.694$$

Therefore, partitioning on age yields an information gain of:

$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$
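The arithmetic for Info_age(D) and Gain(age) can also be checked in code. This is purely an illustrative sketch (the class GainCheck and the helper H are made up for this check; the per-bucket class counts are read off the table above):

using System;

class GainCheck
{
    // entropy of a two-class split, treating 0 * log2(0) as 0
    static double H(double yes, double no)
    {
        double total = yes + no;
        return -P(yes / total) - P(no / total);
    }

    static double P(double p)
    {
        return p == 0.0 ? 0.0 : p * Math.Log(p, 2);
    }

    static void Main()
    {
        double infoD = H(9, 5);                      // ≈ 0.940
        // age buckets: youth (2 yes, 3 no), middle_aged (4, 0), senior (3, 2)
        double infoAge = 5.0 / 14 * H(2, 3)
                       + 4.0 / 14 * H(4, 0)
                       + 5.0 / 14 * H(3, 2);         // ≈ 0.694
        Console.WriteLine(infoD - infoAge);          // Gain(age) ≈ 0.246
    }
}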

Computing the information gain for splitting on income, student, and credit_rating in the same way gives roughly 0.029, 0.151, and 0.048, so age, the attribute that yields the largest information gain, is chosen to split on at the current node. Applying this procedure recursively generates the decision tree. For more details, see:

https://en.wikipedia.org/wiki/Decision_tree

The C# implementation is as follows:

using System;
using System.Collections.Generic;
using System.Linq;
namespace MachineLearning.DecisionTree
{
    public class DecisionTreeID3<T> where T : IEquatable<T>
    {
        T[,] Data;
        string[] Names;
        int Category;
        T[] CategoryLabels;
        DecisionTreeNode<T> Root;
        public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels)
        {
            Data = data;
            Names = names;
            Category = data.GetLength(1) - 1; // the class attribute must be the last column
            CategoryLabels = categoryLabels;
        }
        public void Learn()
        {
            int nRows = Data.GetLength(0);
            int nCols = Data.GetLength(1);
            int[] rows = new int[nRows];
            int[] cols = new int[nCols];
            for (int i = 0; i < nRows; i++) rows[i] = i;
            for (int i = 0; i < nCols; i++) cols[i] = i;
            Root = new DecisionTreeNode<T>(-1, default(T));
            Learn(rows, cols, Root);
            DisplayNode(Root);
        }
        public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0)
        {
            if (Node.Label != -1)
                Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value);
            foreach (var item in Node.Children)
                DisplayNode(item, depth + 1);
        }
        private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root)
        {
            var categoryValues = GetAttribute(Data, Category, pnRows);
            var categoryCount = categoryValues.Distinct().Count();
            if (categoryCount == 1)
            {
                // all remaining tuples belong to the same class: create a leaf
                var node = new DecisionTreeNode<T>(Category, categoryValues.First());
                Root.Children.Add(node);
            }
            else
            {
                if (pnRows.Length == 0) return;
                else if (pnCols.Length == 1)
                {
                    // no attributes left: label the leaf by majority vote
                    var Vote = categoryValues.GroupBy(i => i).OrderByDescending(i => i.Count()).First();
                    var node = new DecisionTreeNode<T>(Category, Vote.First());
                    Root.Children.Add(node);
                }
                else
                {
                    // split on the attribute with the largest information gain
                    var maxCol = MaxEntropy(pnRows, pnCols);
                    var attributes = GetAttribute(Data, maxCol, pnRows).Distinct();
                    foreach (var attr in attributes)
                    {
                        int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray();
                        int[] cols = pnCols.Where(i => i != maxCol).ToArray();
                        var node = new DecisionTreeNode<T>(maxCol, attr);
                        Root.Children.Add(node);
                        Learn(rows, cols, node); // recursively grow the decision tree
                    }
                }
            }
        }
        // Info_A(D): the expected information required after splitting on attribute attrCol
        public double AttributeInfo(int attrCol, int[] pnRows)
        {
            var tuples = AttributeCount(attrCol, pnRows);
            var sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (var tuple in tuples)
            {
                int[] count = new int[CategoryLabels.Length];
                foreach (var irow in pnRows)
                    if (Data[irow, attrCol].Equals(tuple.Item1))
                    {
                        int index = Array.IndexOf(CategoryLabels, Data[irow, Category]);
                        count[index]++;
                    }
                double k = 0.0;
                for (int i = 0; i < count.Length; i++)
                {
                    double frequency = count[i] / (double)tuple.Item2;
                    double t = -frequency * Log2(frequency);
                    k += t;
                }
                double freq = tuple.Item2 / sum;
                Entropy += freq * k;
            }
            return Entropy;
        }
        // Info(D): the entropy of the class attribute over the given rows
        public double CategoryInfo(int[] pnRows)
        {
            var tuples = AttributeCount(Category, pnRows);
            var sum = (double)pnRows.Length;
            double Entropy = 0.0;
            foreach (var tuple in tuples)
            {
                double frequency = tuple.Item2 / sum;
                double t = -frequency * Log2(frequency);
                Entropy += t;
            }
            return Entropy;
        }
        private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows)
        {
            foreach (var irow in pnRows)
                yield return data[irow, col];
        }
        private static double Log2(double x)
        {
            return x == 0.0 ? 0.0 : Math.Log(x, 2.0);
        }
        // returns the column index with the largest gain Gain(A) = Info(D) - Info_A(D)
        public int MaxEntropy(int[] pnRows, int[] pnCols)
        {
            double cateEntropy = CategoryInfo(pnRows);
            int maxAttr = 0;
            double max = double.MinValue;
            foreach (var icol in pnCols)
                if (icol != Category)
                {
                    double Gain = cateEntropy - AttributeInfo(icol, pnRows);
                    if (max < Gain)
                    {
                        max = Gain;
                        maxAttr = icol;
                    }
                }
            return maxAttr;
        }
        // counts how many rows take each distinct value in the given column
        public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows)
        {
            var tuples = from n in GetAttribute(Data, col, pnRows)
                         group n by n into i
                         select Tuple.Create(i.First(), i.Count());
            return tuples;
        }
    }
}
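The DecisionTreeNode<T> class used throughout the code above is not shown in the original post. Below is a minimal sketch reconstructed purely from how it is used (a constructor taking an attribute column index and a value, plus Label, Value, and Children members); the original class may well differ:

using System.Collections.Generic;

namespace MachineLearning.DecisionTree
{
    // Hypothetical reconstruction: one node of the decision tree.
    // Label is the column index of the attribute tested at this node
    // (-1 for the root placeholder), Value is the attribute value on the
    // incoming branch, and Children holds the subtrees.
    public class DecisionTreeNode<T>
    {
        public int Label { get; private set; }
        public T Value { get; private set; }
        public List<DecisionTreeNode<T>> Children { get; private set; }

        public DecisionTreeNode(int label, T value)
        {
            Label = label;
            Value = value;
            Children = new List<DecisionTreeNode<T>>();
        }
    }
}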

It can be called as follows:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using MachineLearning.DecisionTree;
namespace MachineLearning
{
    class Program
    {
        static void Main(string[] args)
        {
            var data = new string[,]
            {
                {"youth","high","no","fair","no"},
                {"youth","high","no","excellent","no"},
                {"middle_aged","high","no","fair","yes"},
                {"senior","medium","no","fair","yes"},
                {"senior","low","yes","fair","yes"},
                {"senior","low","yes","excellent","no"},
                {"middle_aged","low","yes","excellent","yes"},
                {"youth","medium","no","fair","no"},
                {"youth","low","yes","fair","yes"},
                {"senior","medium","yes","fair","yes"},
                {"youth","medium","yes","excellent","yes"},
                {"middle_aged","medium","no","excellent","yes"},
                {"middle_aged","high","yes","fair","yes"},
                {"senior","medium","no","excellent","no"}
            };
            var names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" };
            var tree = new DecisionTreeID3<string>(data, names, new string[] { "yes", "no" });
            tree.Learn();
            Console.ReadKey();
        }
    }
}

 

Running result:

[The original post shows a console screenshot here. With this data, the printed tree should split on age at the root: middle_aged tuples are classified as yes, youth tuples branch on student, and senior tuples branch on credit_rating.]

Note: The author is also still learning and of limited ability; if there are any errors or omissions, corrections are very welcome. Please credit the author when reposting.
