This article describes a decision tree algorithm implemented in Java, shared here for your reference.
The decision tree algorithm is a method for approximating discrete-valued functions. It is a typical classification method: the data is first preprocessed, an induction algorithm then generates readable rules in the form of a decision tree, and the tree is finally used to classify new data. In essence, a decision tree classifies data by applying a series of rules.
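To make "classifying data through a series of rules" concrete, here is a minimal sketch in which the rules are written by hand as nested conditions. The attribute names (AGE, STUDENT, CREDIT_RATING) and values are taken from the sample data used in this article's code; the class name RuleClassifier and the rules themselves are assumptions for illustration, chosen to agree with that sample data, not output of any algorithm.

```java
import java.util.HashMap;
import java.util.Map;

public class RuleClassifier {

    /**
     * Classifies a sample with hand-written rules (illustrative only).
     * Returns "1" or "0", or null when no rule matches.
     */
    static String classify(Map<String, String> sample) {
        String age = sample.get("AGE");
        if ("30-40".equals(age)) {
            return "1";
        }
        if ("<30".equals(age)) {
            // For this age group the STUDENT attribute decides
            return "Yes".equals(sample.get("STUDENT")) ? "1" : "0";
        }
        if (">40".equals(age)) {
            // For this age group the CREDIT_RATING attribute decides
            return "Fair".equals(sample.get("CREDIT_RATING")) ? "1" : "0";
        }
        return null; // attribute value not covered by any rule
    }

    public static void main(String[] args) {
        Map<String, String> sample = new HashMap<>();
        sample.put("AGE", "<30");
        sample.put("STUDENT", "Yes");
        sample.put("CREDIT_RATING", "Fair");
        System.out.println("category = " + classify(sample));
    }
}
```

Each path from the first condition to a return statement corresponds to one root-to-leaf path in a decision tree; a tree-building algorithm derives such conditions from data instead of having them written by hand.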
A decision tree is built in two steps. The first step is generation: the tree is built from a training sample set. In general, the training set is historical data that is reasonably comprehensive for the problem at hand and is used for data analysis and processing. The second step is pruning: the tree produced in the first step is verified, corrected and revised. Pruning mainly uses data from a new sample set (called the test data set) to check the preliminary rules produced during generation, and cuts off the branches that reduce prediction accuracy.
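The pruning step can be sketched as follows. This is reduced-error pruning, one common technique, shown under stated assumptions: the Node type, the "CLASS" key that holds a sample's true category, and the small demo tree and data are all hypothetical and are not part of this article's implementation. Working bottom-up, a subtree is replaced by a leaf labelled with its majority category whenever that leaf makes no more errors on the held-out samples that reach it.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PruneSketch {

    /** Hypothetical node type: a leaf when attribute == null. */
    static class Node {
        String attribute;  // attribute tested here (null for a leaf)
        String label;      // majority category of the training samples at this node
        Map<String, Node> children = new HashMap<>();

        static Node leaf(String label) {
            Node n = new Node();
            n.label = label;
            return n;
        }

        static Node split(String attribute, String majority) {
            Node n = leaf(majority);
            n.attribute = attribute;
            return n;
        }
    }

    static String classify(Node node, Map<String, String> sample) {
        while (node.attribute != null) {
            Node child = node.children.get(sample.get(node.attribute));
            if (child == null) break;  // unseen value: fall back to the majority label
            node = child;
        }
        return node.label;
    }

    /** Number of samples whose true category ("CLASS") differs from the prediction. */
    static int errors(Node node, List<Map<String, String>> samples) {
        int wrong = 0;
        for (Map<String, String> s : samples)
            if (!classify(node, s).equals(s.get("CLASS"))) wrong++;
        return wrong;
    }

    /** Prunes children first, then replaces this subtree by a leaf if that is no worse. */
    static Node prune(Node node, List<Map<String, String>> heldOut) {
        if (node.attribute == null) return node;
        for (Map.Entry<String, Node> e : node.children.entrySet()) {
            List<Map<String, String>> subset = new ArrayList<>();
            for (Map<String, String> s : heldOut)
                if (e.getKey().equals(s.get(node.attribute))) subset.add(s);
            e.setValue(prune(e.getValue(), subset));
        }
        Node leaf = Node.leaf(node.label);
        return errors(leaf, heldOut) <= errors(node, heldOut) ? leaf : node;
    }

    static Map<String, String> sample(String student, String credit, String cls) {
        Map<String, String> s = new HashMap<>();
        s.put("STUDENT", student);
        s.put("CREDIT_RATING", credit);
        s.put("CLASS", cls);
        return s;
    }

    /** Demo tree whose STUDENT = No branch contains a split that does not generalize. */
    static Node demoTree() {
        Node root = Node.split("STUDENT", "1");
        root.children.put("Yes", Node.leaf("1"));
        Node no = Node.split("CREDIT_RATING", "0");
        no.children.put("Fair", Node.leaf("0"));
        no.children.put("Excellent", Node.leaf("1"));
        root.children.put("No", no);
        return root;
    }

    static List<Map<String, String>> demoHeldOut() {
        List<Map<String, String>> data = new ArrayList<>();
        data.add(sample("No", "Excellent", "0"));
        data.add(sample("No", "Excellent", "0"));
        data.add(sample("No", "Fair", "0"));
        data.add(sample("Yes", "Fair", "1"));
        return data;
    }

    public static void main(String[] args) {
        List<Map<String, String>> heldOut = demoHeldOut();
        Node tree = demoTree();
        System.out.println("errors before pruning: " + errors(tree, heldOut));
        Node pruned = prune(tree, heldOut);
        System.out.println("errors after pruning: " + errors(pruned, heldOut));
    }
}
```

Ties are resolved in favour of the leaf, so a branch reached by no held-out sample is also collapsed; whether that is desirable depends on how much held-out data is available.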
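The Java implementation is as follows. It covers the first step, tree generation, and at each node selects the attribute whose branches need the least information to classify a sample (equivalently, the attribute with the largest information gain); pruning is not implemented: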
package demo;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

public class DicisionTree {

    public static void main(String[] args) throws Exception {
        System.out.println("Test result:");
        String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT", "CREDIT_RATING" };
        // Read the sample set
        Map<Object, List<Sample>> samples = readSamples(attrNames);
        // Generate the decision tree
        Object decisionTree = generateDecisionTree(samples, attrNames);
        // Output the decision tree
        outputDecisionTree(decisionTree, 0, null);
    }

    /**
     * Read the classified sample set and return a Map:
     * category -> list of samples belonging to that category
     */
    static Map<Object, List<Sample>> readSamples(String[] attrNames) {
        // Sample attributes and category (the last element of each row is the category)
        Object[][] rawData = new Object[][] {
                { "<30", "High", "No", "Fair", "0" },
                { "<30", "High", "No", "Excellent", "0" },
                { "30-40", "High", "No", "Fair", "1" },
                { ">40", "Medium", "No", "Fair", "1" },
                { ">40", "Low", "Yes", "Fair", "1" },
                { ">40", "Low", "Yes", "Excellent", "0" },
                { "30-40", "Low", "Yes", "Excellent", "1" },
                { "<30", "Medium", "No", "Fair", "0" },
                { "<30", "Low", "Yes", "Fair", "1" },
                { ">40", "Medium", "Yes", "Fair", "1" },
                { "<30", "Medium", "Yes", "Excellent", "1" },
                { "30-40", "Medium", "No", "Excellent", "1" },
                { "30-40", "High", "Yes", "Fair", "1" },
                { ">40", "Medium", "No", "Excellent", "0" } };
        // Build a Sample object for each row and group the samples by category
        Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();
        for (Object[] row : rawData) {
            Sample sample = new Sample();
            int i = 0;
            for (int n = row.length - 1; i < n; i++)
                sample.setAttribute(attrNames[i], row[i]);
            // After the loop, i is the index of the last element: the category
            sample.setCategory(row[i]);
            List<Sample> samples = ret.get(row[i]);
            if (samples == null) {
                samples = new LinkedList<Sample>();
                ret.put(row[i], samples);
            }
            samples.add(sample);
        }
        return ret;
    }

    /**
     * Construct the decision tree
     */
    static Object generateDecisionTree(
            Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
        // If all samples belong to a single category, that category
        // is the classification of a new sample
        if (categoryToSamples.size() == 1)
            return categoryToSamples.keySet().iterator().next();
        // If no attribute is left to decide on, vote: use the category
        // with the most samples as the classification of a new sample
        if (attrNames.length == 0) {
            int max = 0;
            Object maxCategory = null;
            for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
                int cur = entry.getValue().size();
                if (cur > max) {
                    max = cur;
                    maxCategory = entry.getKey();
                }
            }
            return maxCategory;
        }
        // Select the test attribute
        Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);
        // Root node of the (sub)tree; its branch attribute is the selected test attribute
        Tree tree = new Tree(attrNames[(Integer) rst[0]]);
        // An attribute that has been used must not be selected again
        String[] subA = new String[attrNames.length - 1];
        for (int i = 0, j = 0; i < attrNames.length; i++)
            if (i != (Integer) rst[0])
                subA[j++] = attrNames[i];
        // Generate one branch per value of the branch attribute
        @SuppressWarnings("unchecked")
        Map<Object, Map<Object, List<Sample>>> splits =
                (Map<Object, Map<Object, List<Sample>>>) rst[2];
        for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {
            Object attrValue = entry.getKey();
            Map<Object, List<Sample>> split = entry.getValue();
            Object child = generateDecisionTree(split, subA);
            tree.setChild(attrValue, child);
        }
        return tree;
    }

    /**
     * Select the optimal test attribute. Optimal means that, when branching on
     * the selected attribute, the weighted sum of the information needed to
     * classify a new sample within each branch is minimal. This is equivalent
     * to choosing the attribute with the maximum information gain.
     * Returns an array: selected attribute index, sum of information,
     * Map(attribute value -> (category -> sample list))
     */
    static Object[] chooseBestTestAttribute(
            Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {
        int minIndex = -1; // Index of the optimal attribute
        double minValue = Double.MAX_VALUE; // Minimum information
        Map<Object, Map<Object, List<Sample>>> minSplits = null; // Optimal branch scheme
        // For each attribute, compute the sum of the information needed to classify
        // a new sample in each branch when that attribute is the test attribute,
        // and keep the attribute with the smallest sum
        for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {
            int allCount = 0; // Total number of samples
            // Build a Map for the current attribute:
            // attribute value -> (category -> sample list)
            Map<Object, Map<Object, List<Sample>>> curSplits =
                    new HashMap<Object, Map<Object, List<Sample>>>();
            for (Entry<Object, List<Sample>> entry : categoryToSamples.entrySet()) {
                Object category = entry.getKey();
                List<Sample> samples = entry.getValue();
                for (Sample sample : samples) {
                    Object attrValue = sample.getAttribute(attrNames[attrIndex]);
                    Map<Object, List<Sample>> split = curSplits.get(attrValue);
                    if (split == null) {
                        split = new HashMap<Object, List<Sample>>();
                        curSplits.put(attrValue, split);
                    }
                    List<Sample> splitSamples = split.get(category);
                    if (splitSamples == null) {
                        splitSamples = new LinkedList<Sample>();
                        split.put(category, splitSamples);
                    }
                    splitSamples.add(sample);
                }
                allCount += samples.size();
            }
            // Sum of the information needed to classify a new sample in each
            // branch when the current attribute is the test attribute
            double curValue = 0.0; // Accumulator over all branches
            for (Map<Object, List<Sample>> splits : curSplits.values()) {
                double perSplitCount = 0;
                for (List<Sample> list : splits.values())
                    perSplitCount += list.size(); // Number of samples in the current branch
                double perSplitValue = 0.0; // Entropy of the current branch
                for (List<Sample> list : splits.values()) {
                    double p = list.size() / perSplitCount;
                    perSplitValue -= p * (Math.log(p) / Math.log(2));
                }
                curValue += (perSplitCount / allCount) * perSplitValue;
            }
            // Keep the minimum as the optimum
            if (minValue > curValue) {
                minIndex = attrIndex;
                minValue = curValue;
                minSplits = curSplits;
            }
        }
        return new Object[] { minIndex, minValue, minSplits };
    }

    /**
     * Output the decision tree to standard output
     */
    static void outputDecisionTree(Object obj, int level, Object from) {
        for (int i = 0; i < level; i++)
            System.out.print("|------");
        if (from != null)
            System.out.printf("(%s):", from);
        if (obj instanceof Tree) {
            Tree tree = (Tree) obj;
            String attrName = tree.getAttribute();
            System.out.printf("[%s = ?]%n", attrName);
            for (Object attrValue : tree.getAttributeValues()) {
                Object child = tree.getChild(attrValue);
                outputDecisionTree(child, level + 1, attrName + " = " + attrValue);
            }
        } else {
            System.out.printf("[CATEGORY = %s]%n", obj);
        }
    }

    /**
     * Sample: holds several attributes and one category value that
     * specifies the category the sample belongs to
     */
    static class Sample {
        private Map<String, Object> attributes = new HashMap<String, Object>();
        private Object category;

        public Object getAttribute(String name) {
            return attributes.get(name);
        }

        public void setAttribute(String name, Object value) {
            attributes.put(name, value);
        }

        public Object getCategory() {
            return category;
        }

        public void setCategory(Object category) {
            this.category = category;
        }

        public String toString() {
            return attributes.toString();
        }
    }

    /**
     * Decision tree (non-leaf node). Each non-leaf node is the root of a
     * sub-decision tree. It holds a branch attribute and several branches:
     * each value of the branch attribute corresponds to one branch, and
     * each branch leads to a sub-decision tree or a category
     */
    static class Tree {
        private String attribute;
        private Map<Object, Object> children = new HashMap<Object, Object>();

        public Tree(String attribute) {
            this.attribute = attribute;
        }

        public String getAttribute() {
            return attribute;
        }

        public Object getChild(Object attrValue) {
            return children.get(attrValue);
        }

        public void setChild(Object attrValue, Object child) {
            children.put(attrValue, child);
        }

        public Set<Object> getAttributeValues() {
            return children.keySet();
        }
    }
}
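The quantity that chooseBestTestAttribute minimizes can be checked by hand for the AGE attribute. The sketch below is a standalone illustration, not part of the program above: the class InfoGainSketch and its two helper methods are hypothetical names, and the counts were tallied manually from the 14 sample rows (5 samples of category "0" and 9 of category "1" overall).

```java
public class InfoGainSketch {

    /** Entropy in bits of a class distribution given as counts. */
    static double entropy(int... counts) {
        int total = 0;
        for (int c : counts) total += c;
        double h = 0.0;
        for (int c : counts) {
            if (c == 0) continue;  // a term with p = 0 contributes nothing
            double p = (double) c / total;
            h -= p * (Math.log(p) / Math.log(2));
        }
        return h;
    }

    /** Weighted entropy after a split; each row holds the class counts of one branch. */
    static double expectedInfo(int[][] branches) {
        int all = 0;
        for (int[] branch : branches)
            for (int c : branch) all += c;
        double sum = 0.0;
        for (int[] branch : branches) {
            int n = 0;
            for (int c : branch) n += c;
            sum += ((double) n / all) * entropy(branch);
        }
        return sum;
    }

    public static void main(String[] args) {
        // Rows: AGE = "<30", "30-40", ">40"; columns: count of "0", count of "1"
        int[][] byAge = { { 3, 2 }, { 0, 4 }, { 2, 3 } };
        double before = entropy(5, 9);
        double after = expectedInfo(byAge);
        System.out.printf("before=%.3f after=%.3f gain=%.3f%n",
                before, after, before - after);
    }
}
```

For these counts the entropy before the split is about 0.940 bits, the weighted entropy after splitting on AGE is about 0.694 bits, and the gain is about 0.247 bits. Because the entropy before the split is the same for every candidate attribute, minimizing the weighted entropy, as the program does, selects the same attribute as maximizing the gain.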
I hope this article helps with your Java programming.