public class DecisionTree extends CART implements SoftClassifier<smile.data.Tuple>, DataFrameClassifier
The algorithms used for constructing decision trees usually work top-down, choosing at each step the variable that best splits the current set of items. "Best" is defined by how well the variable splits the set into homogeneous subsets that share the same value of the target variable. Different algorithms use different formulae for measuring "best". Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element from the set would be incorrectly labeled if it were labeled randomly according to the distribution of labels in the subset. Gini impurity can be computed by summing, over all items, the probability of each item being chosen times the probability of a mistake in categorizing that item. It reaches its minimum (zero) when all cases in the node fall into a single target category. Information gain, used by the ID3, C4.5, and C5.0 algorithms, is another popular measure; it is based on the concept of entropy from information theory. For categorical variables with different numbers of levels, however, information gain is biased in favor of attributes with more levels. One may instead employ the information gain ratio, which corrects for this bias.
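The two impurity measures can be sketched in plain Java, independent of the Smile API (class and method names here are illustrative, not part of the library):

```java
import java.util.Arrays;

public class ImpurityDemo {
    // Gini impurity: 1 - sum_i p_i^2, where p_i is the fraction of class i.
    static double gini(int[] counts) {
        double n = Arrays.stream(counts).sum();
        double sumSq = 0.0;
        for (int c : counts) {
            double p = c / n;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    // Entropy: -sum_i p_i * log2(p_i), the basis of information gain.
    static double entropy(int[] counts) {
        double n = Arrays.stream(counts).sum();
        double h = 0.0;
        for (int c : counts) {
            if (c == 0) continue;   // lim p->0 of p*log(p) is 0
            double p = c / n;
            h -= p * (Math.log(p) / Math.log(2));
        }
        return h;
    }

    public static void main(String[] args) {
        System.out.println(gini(new int[] {5, 0}));    // pure node -> 0.0
        System.out.println(gini(new int[] {5, 5}));    // 50/50 split -> 0.5
        System.out.println(entropy(new int[] {5, 5})); // 50/50 split -> 1.0
    }
}
```

Note how Gini impurity is exactly zero for the pure node, matching the minimum described above.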
Classification and regression tree (CART) techniques have a number of advantages over many alternative classification techniques.
Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis.
See also: AdaBoost, GradientTreeBoost, RandomForest, Serialized Form

| Constructor and Description |
|---|
| `DecisionTree(smile.data.DataFrame x, int[] y, smile.data.type.StructField response, int k, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order)` Constructor. |
| Modifier and Type | Method and Description |
|---|---|
| `protected java.util.Optional<Split>` | `findBestSplit(LeafNode leaf, int j, double impurity, int lo, int hi)` Finds the best split for given column. |
| `static DecisionTree` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data)` Learns a classification tree. |
| `static DecisionTree` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)` Learns a classification tree. |
| `static DecisionTree` | `fit(smile.data.formula.Formula formula, smile.data.DataFrame data, SplitRule rule, int maxDepth, int maxNodes, int nodeSize)` Learns a classification tree. |
| `smile.data.formula.Formula` | `formula()` Returns null if the tree is part of an ensemble algorithm. |
| `protected double` | `impurity(LeafNode node)` Returns the impurity of node. |
| `protected LeafNode` | `newNode(int[] nodeSamples)` Creates a new leaf node. |
| `int` | `predict(smile.data.Tuple x)` Predicts the class label of an instance. |
| `int` | `predict(smile.data.Tuple x, double[] posteriori)` Predicts the class label of an instance and also calculates a posteriori probabilities. |
| `DecisionTree` | `prune(smile.data.DataFrame test)` Returns a new decision tree by reduced error pruning. |
| `smile.data.type.StructType` | `schema()` Returns the design matrix schema. |
Methods inherited from class CART: `clear, dot, findBestSplit, importance, order, root, size, split, toString`

Methods inherited from class java.lang.Object: `clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait`

Methods inherited from interface Classifier: `applyAsDouble, applyAsInt, f, predict`

Methods inherited from interface DataFrameClassifier: `predict`
public DecisionTree(smile.data.DataFrame x, int[] y, smile.data.type.StructField response, int k, SplitRule rule, int maxDepth, int maxNodes, int nodeSize, int mtry, int[] samples, int[][] order)

Parameters:
- x - the data frame of the explanatory variables.
- y - the response variables.
- response - the metadata of the response variable.
- k - the number of classes.
- rule - the splitting rule.
- maxDepth - the maximum depth of the tree.
- maxNodes - the maximum number of leaf nodes in the tree.
- nodeSize - the minimum size of leaf nodes.
- mtry - the number of input variables to pick to split on at each node. sqrt(p), where p is the number of variables, generally gives good performance.
- samples - the sample set of instances for stochastic learning. samples[i] is the number of samplings of instance i.
- order - the index of training values in ascending order. Note that only numeric attributes need be sorted.

protected double impurity(LeafNode node)
Returns the impurity of node. Specified by: impurity in class CART.

protected LeafNode newNode(int[] nodeSamples)

Creates a new leaf node. Specified by: newNode in class CART.

protected java.util.Optional<Split> findBestSplit(LeafNode leaf, int j, double impurity, int lo, int hi)

Finds the best split for given column. Specified by: findBestSplit in class CART.
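The mechanics of such a split search can be sketched in plain Java. This is a conceptual illustration of what a findBestSplit-style routine does for one numeric column, scanning candidate thresholds in sorted order and keeping the one that minimizes the weighted Gini impurity of the two children; it is not Smile's actual implementation, and all names are hypothetical:

```java
import java.util.Arrays;

public class BestSplitDemo {
    // Gini impurity of a class-count histogram over n samples.
    static double gini(int[] counts, int n) {
        if (n == 0) return 0.0;
        double sumSq = 0.0;
        for (int c : counts) sumSq += (double) c / n * c / n;
        return 1.0 - sumSq;
    }

    /** Returns {bestThreshold, bestScore} for splitting x into (<= threshold, > threshold),
        minimizing the size-weighted Gini impurity of the children. k is the class count. */
    static double[] bestSplit(double[] x, int[] y, int k) {
        Integer[] idx = new Integer[x.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        Arrays.sort(idx, (a, b) -> Double.compare(x[a], x[b]));

        int[] left = new int[k], right = new int[k];
        for (int label : y) right[label]++;     // everything starts on the right

        double bestScore = Double.POSITIVE_INFINITY, bestThreshold = Double.NaN;
        int n = x.length;
        for (int i = 0; i < n - 1; i++) {
            left[y[idx[i]]]++;                  // move one sample left...
            right[y[idx[i]]]--;                 // ...and out of the right child
            if (x[idx[i]] == x[idx[i + 1]]) continue;   // only split between distinct values
            int nl = i + 1, nr = n - nl;
            double score = (nl * gini(left, nl) + nr * gini(right, nr)) / n;
            if (score < bestScore) {
                bestScore = score;
                bestThreshold = (x[idx[i]] + x[idx[i + 1]]) / 2;
            }
        }
        return new double[] {bestThreshold, bestScore};
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0, 10.0, 11.0, 12.0};
        int[] y   = {0,   0,   0,   1,    1,    1};
        double[] best = bestSplit(x, y, 2);
        System.out.println("threshold=" + best[0] + " score=" + best[1]);
    }
}
```

The presorted index array plays the same role as the `order` argument of the constructor above: once a column is sorted, each candidate threshold can be evaluated with an O(1) histogram update instead of a fresh pass over the data.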
public static DecisionTree fit(smile.data.formula.Formula formula, smile.data.DataFrame data)

Learns a classification tree.

Parameters:
- formula - a symbolic description of the model to be fitted.
- data - the data frame of the explanatory and response variables.

public static DecisionTree fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)
Learns a classification tree. Valid properties in prop include:
- smile.cart.split.rule
- smile.cart.node.size
- smile.cart.max.nodes

Parameters:
- formula - a symbolic description of the model to be fitted.
- data - the data frame of the explanatory and response variables.
- prop - the training algorithm hyper-parameters and properties.

public static DecisionTree fit(smile.data.formula.Formula formula, smile.data.DataFrame data, SplitRule rule, int maxDepth, int maxNodes, int nodeSize)
Learns a classification tree.

Parameters:
- formula - a symbolic description of the model to be fitted.
- data - the data frame of the explanatory and response variables.
- rule - the splitting rule.
- maxDepth - the maximum depth of the tree.
- maxNodes - the maximum number of leaf nodes in the tree.
- nodeSize - the minimum size of leaf nodes.

public int predict(smile.data.Tuple x)
Predicts the class label of an instance.

Specified by: predict in interface Classifier<smile.data.Tuple>.
Specified by: predict in interface DataFrameClassifier.

Parameters:
- x - the instance to be classified.

public int predict(smile.data.Tuple x, double[] posteriori)
Predicts the class label of an instance and also calculates a posteriori probabilities.

Specified by: predict in interface SoftClassifier<smile.data.Tuple>.

Parameters:
- x - an instance to be classified.
- posteriori - the array to store a posteriori probabilities on output.

public smile.data.formula.Formula formula()
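A common way for a tree to produce a posteriori probabilities is to normalize the class histogram of the leaf the instance falls into. The following plain-Java sketch illustrates that idea with hypothetical names; it is not Smile's implementation:

```java
public class PosteriorDemo {
    /** Fills posteriori with class probabilities derived from a leaf's label
        counts and returns the index of the most frequent class. */
    static int predict(int[] leafCounts, double[] posteriori) {
        int n = 0;
        for (int c : leafCounts) n += c;
        int best = 0;
        for (int i = 0; i < leafCounts.length; i++) {
            posteriori[i] = (double) leafCounts[i] / n;  // relative frequency
            if (leafCounts[i] > leafCounts[best]) best = i;
        }
        return best;
    }

    public static void main(String[] args) {
        double[] posteriori = new double[2];
        // A leaf that saw 3 training samples of class 0 and 9 of class 1.
        int label = predict(new int[] {3, 9}, posteriori);
        System.out.println(label + " " + posteriori[0] + " " + posteriori[1]);
    }
}
```

The caller supplies the output array, mirroring the `posteriori` parameter convention documented above.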
Returns null if the tree is part of an ensemble algorithm.

Specified by: formula in interface DataFrameClassifier.
public smile.data.type.StructType schema()

Returns the design matrix schema.

Specified by: schema in interface DataFrameClassifier.
public DecisionTree prune(smile.data.DataFrame test)

Returns a new decision tree by reduced error pruning.

Parameters:
- test - the test data set to evaluate the errors of nodes.
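Reduced error pruning can be sketched as a bottom-up pass that collapses a subtree into a leaf whenever doing so does not increase the error on the held-out test set. Below is a minimal plain-Java illustration over a single-feature tree; the structure and names are hypothetical, and unlike Smile's `prune`, which returns a new tree, this sketch mutates in place for brevity:

```java
public class PruneDemo {
    static class Node {
        Node left, right;   // both null for a leaf
        double threshold;   // internal nodes: go left when x <= threshold
        int label;          // majority class label at this node
        boolean isLeaf() { return left == null; }
    }

    static int predict(Node node, double x) {
        if (node.isLeaf()) return node.label;
        return predict(x <= node.threshold ? node.left : node.right, x);
    }

    static int errors(Node node, double[] x, int[] y) {
        int e = 0;
        for (int i = 0; i < x.length; i++) {
            if (predict(node, x[i]) != y[i]) e++;
        }
        return e;
    }

    // Bottom-up reduced error pruning: collapse a subtree into a leaf
    // whenever that does not increase the error on the test set.
    static void prune(Node root, Node node, double[] x, int[] y) {
        if (node.isLeaf()) return;
        prune(root, node.left, x, y);
        prune(root, node.right, x, y);
        int before = errors(root, x, y);
        Node left = node.left, right = node.right;
        node.left = null;                 // tentatively turn the node into a leaf
        node.right = null;
        if (errors(root, x, y) > before) {
            node.left = left;             // pruning hurt accuracy: undo it
            node.right = right;
        }
    }

    // A tiny single-feature tree whose right subtree overfits; pruning collapses it.
    static Node demo() {
        Node leftLeaf = new Node();  leftLeaf.label = 0;
        Node rl = new Node();        rl.label = 1;
        Node rr = new Node();        rr.label = 0;   // spurious deep split
        Node rightSplit = new Node();
        rightSplit.threshold = 8; rightSplit.label = 1;
        rightSplit.left = rl; rightSplit.right = rr;
        Node root = new Node();
        root.threshold = 5; root.label = 1;
        root.left = leftLeaf; root.right = rightSplit;
        prune(root, root, new double[] {1, 6, 9}, new int[] {0, 1, 1});
        return root;
    }

    public static void main(String[] args) {
        Node root = demo();
        System.out.println("right subtree is leaf after pruning: " + root.right.isLeaf());
    }
}
```

On this toy test set the deep right split misclassifies one instance, so the pass replaces it with a single leaf, while the root split (which pruning would make worse) is kept.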