kr.ac.kaist.ir.deep.train

Trainer

class Trainer[IN, OUT] extends Serializable

General trainer implementation.

This class trains a network with the help of a training style and an input operation.

IN

The type of input. Currently, kr.ac.kaist.ir.deep.fn.ScalarMatrix and DAG are supported.

OUT

The type of output. Currently, kr.ac.kaist.ir.deep.fn.ScalarMatrix and Null are supported.

Example:
    val net: Network = ...

    // Define the manipulation type: VectorType, AEType, RAEType, or URAEType.
    val operation = new VectorType(
      corrupt = GaussianCorruption(variance = 0.1)
    )

    // Define the training style: SingleThreadTrainStyle or DistBeliefTrainStyle.
    val style = new SingleThreadTrainStyle(
      net = net,
      algorithm = new StochasticGradientDescent(l2decay = 0.0001),
      make = operation,
      param = SimpleTrainingCriteria(miniBatch = 8))

    // Define the trainer.
    val train = new Trainer(
      style = style,
      stops = StoppingCriteria(maxIter = 100000))

    // Train, where `set` and `valid` are the training and validation sets.
    train.train(set, valid)
Note

To train an autoencoder, you can provide the same set as both the training set and the validation set.

This trainer is a generic class. For implementation details, see the individual training styles (SingleThreadTrainStyle and DistBeliefTrainStyle).
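The `train(set: Seq[Pair])` overload is documented as validating with the given training sequence, which is what the autoencoder note relies on. The sketch below illustrates that delegation pattern with a hypothetical one-weight model trained by plain gradient descent. `SameSetValidationSketch`, `TinyTrainer`, and the `(Double, Double)` stand-in for `Pair` are all assumptions for illustration; this is not the library's `Trainer`.

```scala
// Hypothetical sketch, not the library's Trainer: a single-argument train
// overload that reuses the training set as the validation set.
object SameSetValidationSketch {
  // Hypothetical stand-in for the library's Pair: (input, target).
  type Pair = (Double, Double)

  // Hypothetical one-weight model y = w * x, trained by plain gradient descent.
  final class TinyTrainer(learningRate: Double = 0.1, epochs: Int = 200) {
    var w: Double = 0.0

    // Mean squared error of the current model on `data`.
    private def mse(data: Seq[Pair]): Double =
      data.map { case (x, y) => val e = w * x - y; e * e }.sum / data.size

    // Train on `set`, then return the mean squared error on `validation`.
    def train(set: Seq[Pair], validation: Seq[Pair]): Double = {
      require(set.nonEmpty && validation.nonEmpty, "sets must not be empty")
      for (_ <- 1 to epochs) {
        // Gradient of the mean squared error with respect to w.
        val grad = set.map { case (x, y) => 2 * (w * x - y) * x }.sum / set.size
        w -= learningRate * grad
      }
      mse(validation)
    }

    // Single-argument overload: validate on the training set itself.
    def train(set: Seq[Pair]): Double = train(set, set)
  }
}
```

With targets equal to inputs, as in an autoencoder, `new SameSetValidationSketch.TinyTrainer().train(Seq(0.5, 1.0, 1.5).map(x => (x, x)))` drives `w` toward 1 and the reported loss toward zero.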

Linear Supertypes
Serializable, Serializable, AnyRef, Any

Instance Constructors

  1. new Trainer(style: TrainStyle[IN, OUT], stops: StoppingCriteria = StoppingCriteria(), name: String = "Trainer")

    style

    Training style that supervises how to train. Two styles are available: SingleThreadTrainStyle and DistBeliefTrainStyle.

    stops

    Stopping criteria that control when training stops. (Default: StoppingCriteria())

    name

    Name used for logging.

Value Members

  1. final def !=(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  2. final def ##(): Int

    Definition Classes
    AnyRef → Any
  3. final def ==(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  4. final def asInstanceOf[T0]: T0

    Definition Classes
    Any
  5. var bestIter: Int

    Iteration number with the best loss.

    Attributes
    protected
  6. var bestParam: IndexedSeq[ScalarMatrix]

    Best parameter history.

    Attributes
    protected
  7. def clone(): AnyRef

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  8. final def eq(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  9. def equals(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  10. def finalize(): Unit

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  11. final def getClass(): Class[_]

    Definition Classes
    AnyRef → Any
  12. def hashCode(): Int

    Definition Classes
    AnyRef → Any
  13. final def isInstanceOf[T0]: Boolean

    Definition Classes
    Any
  14. val logger: Logger

    Logger

    Attributes
    protected
  15. val name: String

    Name used for logging.

  16. final def ne(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  17. final def notify(): Unit

    Definition Classes
    AnyRef
  18. final def notifyAll(): Unit

    Definition Classes
    AnyRef
  19. def printValidation(): Unit

    Print the validation result to the logger.

    Attributes
    protected
  20. final def restoreParams(): Unit

    Restore the best parameters.

    Attributes
    protected
  21. final def saveParams(epoch: Int = 0, loss: Scalar = Float.MaxValue, patience: Int = validationPeriod * 5): Unit

    Store the best parameters.

    epoch

    Current epoch, where one epoch spans one validation period.

    loss

    Previous loss.

    patience

    Current patience; training continues at least until this epoch.

    Attributes
    protected
  22. val stops: StoppingCriteria

    Stopping criteria that control when training stops. (Default: StoppingCriteria())

  23. val style: TrainStyle[IN, OUT]

    Training style that supervises how to train. Two styles are available: SingleThreadTrainStyle and DistBeliefTrainStyle.

  24. final def synchronized[T0](arg0: ⇒ T0): T0

    Definition Classes
    AnyRef
  25. def toString(): String

    Definition Classes
    AnyRef → Any
  26. def train(set: RDD[Pair], validation: RDD[Pair]): Scalar

    Train on the given training RDD and validate with the given validation RDD.

    set

    RDD of the training set

    validation

    RDD of the validation set

  27. def train(set: RDD[Pair]): Scalar

    Train using the given RDD.

    set

    RDD of the training set

  28. def train(set: Seq[Pair], validation: Seq[Pair]): Scalar

    Train on the given sequence and validate with another sequence.

    set

    Full sequence of the training set

    validation

    Full sequence of the validation set

    returns

    Training error (loss)

  29. def train(set: Seq[Pair]): Scalar

    Train on the given sequence and validate with the same sequence.

    set

    Full sequence of the training set

    returns

    Training error (loss)

  30. final def trainBatch(epoch: Int = 0, prevloss: Scalar = Float.MaxValue, patience: Int = validationPeriod * 5): Scalar

    Tail-recursive: trains each batch.

    epoch

    Current epoch, where one epoch spans one validation period.

    prevloss

    Previous loss.

    patience

    Current patience; training continues at least until this epoch.

    returns

    Total loss when training finishes.

    Attributes
    protected
    Annotations
    @tailrec()
  31. var validationPeriod: Int

    Validation period.

    Attributes
    protected
  32. final def wait(): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  33. final def wait(arg0: Long, arg1: Int): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  34. final def wait(arg0: Long): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
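Taken together, `trainBatch`, `saveParams`, `restoreParams`, `bestIter`, `bestParam`, and `validationPeriod` describe a tail-recursive, patience-based early-stopping loop. The following is a minimal, self-contained sketch of that idea, not the library's implementation. The class name `EarlyStoppingSketch`, the `step` and `validationLoss` functions, the `maxEpoch` limit (a hypothetical stand-in for `StoppingCriteria`), and the rule that an improvement at epoch `e` extends patience to at least `2 * e` are all assumptions.

```scala
import scala.annotation.tailrec

// Hypothetical sketch of patience-based early stopping; names and the
// patience-extension rule are assumptions, not the library's code.
final class EarlyStoppingSketch(
    step: Vector[Double] => Vector[Double],   // hypothetical: one validation period of training
    validationLoss: Vector[Double] => Double, // hypothetical: loss on the validation set
    maxEpoch: Int,                            // hypothetical stand-in for a StoppingCriteria limit
    validationPeriod: Int = 10) {

  var params: Vector[Double] = Vector(0.0)
  var bestParam: Vector[Double] = params
  var bestIter: Int = 0

  // Snapshot the current parameters as the best seen so far.
  private def saveParams(epoch: Int): Unit = { bestParam = params; bestIter = epoch }

  // Roll back to the best snapshot once training stops.
  private def restoreParams(): Unit = { params = bestParam }

  // Stops when the epoch exceeds the patience or reaches maxEpoch,
  // then restores the best parameters and returns the best loss.
  @tailrec
  final def trainBatch(epoch: Int = 0,
                       prevloss: Double = Double.MaxValue,
                       patience: Int = validationPeriod * 5): Double =
    if (epoch > patience || epoch >= maxEpoch) { restoreParams(); prevloss }
    else {
      params = step(params)
      val loss = validationLoss(params)
      if (loss < prevloss) {
        // Improvement: save a snapshot and extend patience (assumed rule).
        saveParams(epoch)
        trainBatch(epoch + 1, loss, math.max(patience, epoch * 2))
      } else trainBatch(epoch + 1, prevloss, patience)
    }
}
```

Because `Vector` is immutable, storing the reference is enough for a snapshot here; a mutable parameter store would need a copy inside `saveParams`.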
