public class RandomForest extends java.lang.Object implements Regression<smile.data.Tuple>, DataFrameRegression, TreeSHAP
Each tree is constructed using the following algorithm:
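As a rough illustration of the randomness involved in building one tree, the sketch below draws a row sample and a set of candidate variables in the way the `subsample` and `mtry` parameter notes for `fit` describe them. It is a minimal sketch under stated assumptions, not Smile's implementation: the class `TreeSamplingSketch` and its methods are hypothetical, the lower bound of 1 on `mtry` is an assumption, and rounding `n * subsample` to the nearest integer is an assumption.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper for illustration only; not part of Smile. */
final class TreeSamplingSketch {
    /** About p/3 variables per split, as the mtry note suggests (floor of 1 is an assumption). */
    static int defaultMtry(int p) {
        return Math.max(p / 3, 1);
    }

    /**
     * Row indices used to train one tree.
     * subsample == 1.0: n draws with replacement (a bootstrap sample).
     * subsample <  1.0: round(n * subsample) distinct rows, drawn without replacement.
     */
    static int[] sampleRows(int n, double subsample, Random rng) {
        if (subsample <= 0.0 || subsample > 1.0) {
            throw new IllegalArgumentException("subsample must be in (0, 1]: " + subsample);
        }
        if (subsample == 1.0) {
            int[] rows = new int[n];
            for (int i = 0; i < n; i++) {
                rows[i] = rng.nextInt(n);
            }
            return rows;
        }
        int size = (int) Math.round(n * subsample);
        return firstOfShuffled(n, size, rng);
    }

    /** mtry distinct candidate variables (out of p) considered at one node. */
    static int[] sampleVariables(int p, int mtry, Random rng) {
        if (mtry < 1 || mtry > p) {
            throw new IllegalArgumentException("mtry must be in [1, p]: " + mtry);
        }
        return firstOfShuffled(p, mtry, rng);
    }

    /** Shuffles 0..n-1 and returns the first k entries, i.e. k distinct indices. */
    private static int[] firstOfShuffled(int n, int k, Random rng) {
        List<Integer> all = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            all.add(i);
        }
        Collections.shuffle(all, rng);
        int[] picked = new int[k];
        for (int i = 0; i < k; i++) {
            picked[i] = all.get(i);
        }
        return picked;
    }
}
```

Passing a seeded `Random` to each call makes the draws repeatable, which is the role the optional `seeds` argument of `fit` appears to play for each tree.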
Constructor and Description |
---|
RandomForest(smile.data.formula.Formula formula, RegressionTree[] trees, double error, double[] importance) Constructor. |
Modifier and Type | Method and Description |
---|---|
double | error() Returns the out-of-bag estimation of RMSE. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data) Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample) Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.stream.LongStream seeds) Learns a random forest for regression. |
static RandomForest | fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop) Learns a random forest for regression. |
smile.data.formula.Formula | formula() Returns the formula associated with the model. |
double[] | importance() Returns the variable importance. |
RandomForest | merge(RandomForest other) Merges two random forests and returns a new forest consisting of trees from both input forests. |
double | predict(smile.data.Tuple x) Predicts the dependent variable of an instance. |
smile.data.type.StructType | schema() Returns the schema of predictors. |
int | size() Returns the number of trees in the model. |
double[][] | test(smile.data.DataFrame data) Tests the model on a validation dataset. |
RegressionTree[] | trees() Returns the regression trees. |
void | trim(int ntrees) Trims the tree model set to a smaller size in case of over-fitting. |
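Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface Regression: applyAsDouble, predict
Methods inherited from interface DataFrameRegression: predict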
public RandomForest(smile.data.formula.Formula formula, RegressionTree[] trees, double error, double[] importance)
formula - a symbolic description of the model to be fitted.
trees - forest of regression trees.
error - out-of-bag estimation of RMSE.
importance - variable importance.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, java.util.Properties prop)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables used to determine the decision at a node of the tree. p/3 generally gives good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split. nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.

public static RandomForest fit(smile.data.formula.Formula formula, smile.data.DataFrame data, int ntrees, int mtry, int maxDepth, int maxNodes, int nodeSize, double subsample, java.util.stream.LongStream seeds)
formula - a symbolic description of the model to be fitted.
data - the data frame of the explanatory and response variables.
ntrees - the number of trees.
mtry - the number of input variables used to determine the decision at a node of the tree. p/3 generally gives good performance, where p is the number of variables.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split. nodeSize = 5 generally gives good results.
subsample - the sampling rate for training each tree. 1.0 means sampling with replacement; < 1.0 means sampling without replacement.
seeds - optional RNG seeds for each regression tree.

public RandomForest merge(RandomForest other)
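The method summary says `merge` returns a new forest containing the trees of both inputs, and `trim` shrinks the tree set to a smaller size. The sketch below shows those two operations on plain arrays. It is illustrative only: `EnsembleSizeSketch` is a hypothetical class, keeping the first `ntrees` entries when trimming is an assumption, and how the merged forest's error and importance values are combined is not stated here and is left out.

```java
import java.util.Arrays;

/** Hypothetical helper for illustration only; not part of Smile. */
final class EnsembleSizeSketch {
    /** Returns a new array holding the trees of a followed by the trees of b; inputs are not modified. */
    static <T> T[] merge(T[] a, T[] b) {
        T[] merged = Arrays.copyOf(a, a.length + b.length);
        System.arraycopy(b, 0, merged, a.length, b.length);
        return merged;
    }

    /** Returns a copy holding only the first ntrees entries (assumed trimming rule). */
    static <T> T[] trim(T[] trees, int ntrees) {
        if (ntrees <= 0 || ntrees > trees.length) {
            throw new IllegalArgumentException("ntrees must be in [1, " + trees.length + "]: " + ntrees);
        }
        return Arrays.copyOf(trees, ntrees);
    }
}
```

Because `merge` leaves both inputs untouched, forests trained separately (for example on different machines or with different seeds) can be combined without retraining either one.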
public smile.data.formula.Formula formula()
Returns the formula associated with the model.
Specified by: formula in interface TreeSHAP
Specified by: formula in interface DataFrameRegression
public smile.data.type.StructType schema()
Returns the schema of predictors.
Specified by: schema in interface DataFrameRegression
public double error()
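`error()` reports an out-of-bag (OOB) estimate of RMSE. A common way to compute such an estimate, sketched below, is to predict each training sample using only the trees that did not see it during training, average those predictions, and take the root mean squared residual. This is a generic sketch, not Smile's code: `OobRmseSketch`, the `treePredictions` matrix, and the `inBag` mask are hypothetical names, and skipping samples that every tree saw is an assumption.

```java
/** Hypothetical helper for illustration only; not part of Smile. */
final class OobRmseSketch {
    /**
     * @param y               observed responses, one per training sample
     * @param treePredictions treePredictions[t][i] is tree t's prediction for sample i
     * @param inBag           inBag[t][i] is true if sample i was used to train tree t
     * @return RMSE over samples that are out-of-bag for at least one tree, or NaN if there are none
     */
    static double oobRmse(double[] y, double[][] treePredictions, boolean[][] inBag) {
        double sumSquaredError = 0.0;
        int counted = 0;
        for (int i = 0; i < y.length; i++) {
            double sum = 0.0;
            int votes = 0;
            for (int t = 0; t < treePredictions.length; t++) {
                if (!inBag[t][i]) {
                    // Only trees that never saw sample i contribute to its OOB prediction.
                    sum += treePredictions[t][i];
                    votes++;
                }
            }
            if (votes == 0) {
                continue; // sample i was in every tree's training sample
            }
            double residual = y[i] - sum / votes;
            sumSquaredError += residual * residual;
            counted++;
        }
        return counted == 0 ? Double.NaN : Math.sqrt(sumSquaredError / counted);
    }
}
```

Since each OOB prediction comes from trees that did not train on that sample, the estimate behaves like a validation error without requiring a separate held-out set.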
public double[] importance()
public int size()
public RegressionTree[] trees()
Returns the regression trees.
Specified by: trees in interface TreeSHAP
public void trim(int ntrees)
ntrees - the new (smaller) size of tree model set.

public double predict(smile.data.Tuple x)
Predicts the dependent variable of an instance.
Specified by: predict in interface DataFrameRegression
Specified by: predict in interface Regression<smile.data.Tuple>
x - an instance.

public double[][] test(smile.data.DataFrame data)
data - the test data set.