本文實例講述了Java實現的決策樹算法。分享給大家供大家參考,具體如下:
決策樹算法是一種逼近離散函數值的方法。它是一種典型的分類方法,首先對數據進行處理,利用歸納算法生成可讀的規則和決策樹,然後使用決策樹對新數據進行分析。本質上決策樹是通過一系列規則對數據進行分類的過程。
決策樹構造可以分兩步進行。第一步,決策樹的生成:由訓練樣本集生成決策樹的過程。一般情況下,訓練樣本數據集是根據實際需要有歷史的、有一定綜合程度的,用於數據分析處理的數據集。第二步,決策樹的剪枝:決策樹的剪枝是對上一階段生成的決策樹進行檢驗、校正和修剪的過程,主要是用新的樣本數據集(稱為測試數據集)中的數據校驗決策樹生成過程中產生的初步規則,將那些影響預測準確性的分枝剪除。
java實現代碼如下:
package demo;import java.util.HashMap;import java.util.LinkedList;import java.util.List;import java.util.Map;import java.util.Map.Entry;import java.util.Set;public class DicisionTree { public static void main(String[] args) throws Exception { System.out.print("武林網測試結果:"); String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT", "CREDIT_RATING" }; // 讀取樣本集Map<Object, List<Sample>> samples = readSamples(attrNames); // 生成決策樹Object decisionTree = generateDecisionTree(samples, attrNames); // 輸出決策樹outputDecisionTree(decisionTree, 0, null); } /** * 讀取已分類的樣本集,返回Map:分類-> 屬於該分類的樣本的列表*/ static Map<Object, List<Sample>> readSamples(String[] attrNames) { // 樣本屬性及其所屬分類(數組中的最後一個元素為樣本所屬分類) Object[][] rawData = new Object[][] { { "<30 ", "High ", "No ", "Fair ", "0" }, { "<30 ", "High ", "No ", "Excellent", "0" }, { "30-40", "High ", "No ", "Fair ", "1" }, { ">40 ", "Medium", "No ", "Fair ", "1" }, { ">40 ", "Low ", "Yes", "Fair ", "1" }, { ">40 ", "Low ", "Yes", "Excellent", "0" }, { "30-40", "Low ", "Yes", "Excellent", "1" }, { "<30 ", "Medium", "No ", "Fair ", "0" }, { "<30 ", "Low ", "Yes", "Fair ", "1" }, { ">40 ", "Medium", "Yes", "Fair ", "1" }, { "<30 ", "Medium", "Yes", "Excellent", "1" }, { "30-40", "Medium", "No ", "Excellent", "1" }, { "30-40", "High ", "Yes", "Fair ", "1" }, { ">40 ", "Medium", "No ", "Excellent", "0" } }; // 讀取樣本屬性及其所屬分類,構造表示樣本的Sample對象,並按分類劃分樣本集Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>(); for (Object[] row : rawData) { Sample sample = new Sample(); int i = 0; for (int n = row.length - 1; i < n; i++) sample.setAttribute(attrNames[i], row[i]); sample.setCategory(row[i]); List<Sample> samples = ret.get(row[i]); if (samples == null) { samples = new LinkedList<Sample>(); ret.put(row[i], samples); } samples.add(sample); } return ret; } /** * 構造決策樹*/ static Object generateDecisionTree( Map<Object, List<Sample>> categoryToSamples, String[] attrNames) { // 如果只有一個樣本,將該樣本所屬分類作為新樣本的分類if (categoryToSamples.size() == 1) return 
categoryToSamples.keySet().iterator().next(); // 如果沒有供決策的屬性,則將樣本集中具有最多樣本的分類作為新樣本的分類,即投票選舉出分類if (attrNames.length == 0) { int max = 0; Object maxCategory = null; for (Entry<Object, List<Sample>> entry : categoryToSamples .entrySet()) { int cur = entry.getValue().size(); if (cur > max) { max = cur; maxCategory = entry.getKey(); } } return maxCategory; } // 選取測試屬性Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames); // 決策樹根結點,分支屬性為選取的測試屬性Tree tree = new Tree(attrNames[(Integer) rst[0]]); // 已用過的測試屬性不應再次被選為測試屬性String[] subA = new String[attrNames.length - 1]; for (int i = 0, j = 0; i < attrNames.length; i++) if (i != (Integer) rst[0]) subA[j++] = attrNames[i]; // 根據分支屬性生成分支@SuppressWarnings("unchecked") Map<Object, Map<Object, List<Sample>>> splits = /* NEW LINE */(Map<Object, Map<Object, List<Sample>>>) rst[2]; for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) { Object attrValue = entry.getKey(); Map<Object, List<Sample>> split = entry.getValue(); Object child = generateDecisionTree(split, subA); tree.setChild(attrValue, child); } return tree; } /** * 選取最優測試屬性。最優是指如果根據選取的測試屬性分支,則從各分支確定新樣本* 的分類需要的信息量之和最小,這等價於確定新樣本的測試屬性獲得的信息增益最大* 返回數組:選取的屬性下標、信息量之和、Map(屬性值->(分類->樣本列表)) */ static Object[] chooseBestTestAttribute( Map<Object, List<Sample>> categoryToSamples, String[] attrNames) { int minIndex = -1; // 最優屬性下標double minValue = Double.MAX_VALUE; // 最小信息量Map<Object, Map<Object, List<Sample>>> minSplits = null; // 最優分支方案// 對每一個屬性,計算將其作為測試屬性的情況下在各分支確定新樣本的分類需要的信息量之和,選取最小為最優for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) { int allCount = 0; // 統計樣本總數的計數器// 按當前屬性構建Map:屬性值->(分類->樣本列表) Map<Object, Map<Object, List<Sample>>> curSplits = /* NEW LINE */new HashMap<Object, Map<Object, List<Sample>>>(); for (Entry<Object, List<Sample>> entry : categoryToSamples .entrySet()) { Object category = entry.getKey(); List<Sample> samples = entry.getValue(); for (Sample sample : samples) { Object attrValue = sample 
.getAttribute(attrNames[attrIndex]); Map<Object, List<Sample>> split = curSplits.get(attrValue); if (split == null) { split = new HashMap<Object, List<Sample>>(); curSplits.put(attrValue, split); } List<Sample> splitSamples = split.get(category); if (splitSamples == null) { splitSamples = new LinkedList<Sample>(); split.put(category, splitSamples); } splitSamples.add(sample); } allCount += samples.size(); } // 計算將當前屬性作為測試屬性的情況下在各分支確定新樣本的分類需要的信息量之和double curValue = 0.0; // 計數器:累加各分支for (Map<Object, List<Sample>> splits : curSplits.values()) { double perSplitCount = 0; for (List<Sample> list : splits.values()) perSplitCount += list.size(); // 累計當前分支樣本數double perSplitValue = 0.0; // 計數器:當前分支for (List<Sample> list : splits.values()) { double p = list.size() / perSplitCount; perSplitValue -= p * (Math.log(p) / Math.log(2)); } curValue += (perSplitCount / allCount) * perSplitValue; } // 選取最小為最優if (minValue > curValue) { minIndex = attrIndex; minValue = curValue; minSplits = curSplits; } } return new Object[] { minIndex, minValue, minSplits }; } /** * 將決策樹輸出到標準輸出*/ static void outputDecisionTree(Object obj, int level, Object from) { for (int i = 0; i < level; i++) System.out.print("|-----"); if (from != null) System.out.printf("(%s):", from); if (obj instanceof Tree) { Tree tree = (Tree) obj; String attrName = tree.getAttribute(); System.out.printf("[%s = ?]/n", attrName); for (Object attrValue : tree.getAttributeValues()) { Object child = tree.getChild(attrValue); outputDecisionTree(child, level + 1, attrName + " = " + attrValue); } } else { System.out.printf("[CATEGORY = %s]/n", obj); } } /** * 樣本,包含多個屬性和一個指明樣本所屬分類的分類值*/ static class Sample { private Map<String, Object> attributes = new HashMap<String, Object>(); private Object category; public Object getAttribute(String name) { return attributes.get(name); } public void setAttribute(String name, Object value) { attributes.put(name, value); } public Object getCategory() { return category; } public void 
setCategory(Object category) { this.category = category; } public String toString() { return attributes.toString(); } } /** * 決策樹(非葉結點),決策樹中的每個非葉結點都引導了一棵決策樹* 每個非葉結點包含一個分支屬性和多個分支,分支屬性的每個值對應一個分支,該分支引導了一棵子決策樹*/ static class Tree { private String attribute; private Map<Object, Object> children = new HashMap<Object, Object>(); public Tree(String attribute) { this.attribute = attribute; } public String getAttribute() { return attribute; } public Object getChild(Object attrValue) { return children.get(attrValue); } public void setChild(Object attrValue, Object child) { children.put(attrValue, child); } public Set<Object> getAttributeValues() { return children.keySet(); } }}運行結果:
更多關於java算法相關內容感興趣的讀者可查看本站專題:《Java數據結構與算法教程》、《Java操作DOM節點技巧總結》、《Java文件與目錄操作技巧匯總》和《Java緩存操作技巧匯總》
希望本文所述對大家java程序設計有所幫助。