Class RecurrentBlock

  • All Implemented Interfaces:
    Block
    Direct Known Subclasses:
    GRU, LSTM, RNN

    public abstract class RecurrentBlock
    extends AbstractBlock
    RecurrentBlock is an abstract implementation of recurrent neural networks.

    Recurrent neural networks carry a hidden state from one time step to the next. This makes them a common choice for natural language processing and other tasks that involve sequential data.

    This [article](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) written by Andrej Karpathy provides a detailed explanation of recurrent neural networks.

    Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.

    • Field Detail

      • stateSize

        protected long stateSize
      • dropRate

        protected float dropRate
      • numLayers

        protected int numLayers
      • gates

        protected int gates
      • batchFirst

        protected boolean batchFirst
      • hasBiases

        protected boolean hasBiases
      • bidirectional

        protected boolean bidirectional
      • returnState

        protected boolean returnState
    • Constructor Detail

      • RecurrentBlock

        public RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)
        Creates a RecurrentBlock object.
        Parameters:
        builder - the Builder that has the necessary configurations
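The constructor takes a `BaseBuilder<?>`, which suggests a self-typed builder shared by the concrete subclasses. The following is a minimal, self-contained sketch of that pattern, not the library's code: `BuilderSketch`, `Recurrent`, `Lstm`, and the setter names `setStateSize`, `setNumLayers`, and `optBidirectional` are illustrative assumptions chosen to mirror the fields listed above.

```java
/** Hypothetical miniature of a builder-configured recurrent block (illustrative names only). */
final class BuilderSketch {

    /** Self-typed base builder: setters return the concrete builder type so calls can be chained. */
    abstract static class BaseBuilder<T extends BaseBuilder<T>> {
        long stateSize;
        int numLayers = 1;
        boolean bidirectional;

        protected abstract T self();

        T setStateSize(long stateSize) {
            this.stateSize = stateSize;
            return self();
        }

        T setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return self();
        }

        T optBidirectional(boolean bidirectional) {
            this.bidirectional = bidirectional;
            return self();
        }
    }

    /** Stand-in for the abstract block: copies its configuration out of any builder subtype. */
    abstract static class Recurrent {
        protected final long stateSize;
        protected final int numLayers;
        protected final boolean bidirectional;

        protected Recurrent(BaseBuilder<?> builder) {
            if (builder.stateSize <= 0) {
                throw new IllegalArgumentException("stateSize must be set to a positive value");
            }
            this.stateSize = builder.stateSize;
            this.numLayers = builder.numLayers;
            this.bidirectional = builder.bidirectional;
        }
    }

    /** Stand-in for one concrete subclass with its own builder. */
    static final class Lstm extends Recurrent {
        private Lstm(Builder builder) {
            super(builder);
        }

        static Builder builder() {
            return new Builder();
        }

        static final class Builder extends BaseBuilder<Builder> {
            @Override
            protected Builder self() {
                return this;
            }

            Lstm build() {
                return new Lstm(this);
            }
        }
    }
}
```

With this sketch, `BuilderSketch.Lstm.builder().setStateSize(64).setNumLayers(2).optBidirectional(true).build()` yields a configured instance, and leaving out `setStateSize` fails fast in the base constructor.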
    • Method Detail

      • getOutputShapes

        public Shape[] getOutputShapes(Shape[] inputs)
        Returns the expected output shapes of the block for the specified input shapes.
        Parameters:
        inputs - the shapes of the inputs
        Returns:
        the expected output shapes of the block
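To make the relationship between the fields and the output shapes concrete, here is a small sketch that uses plain `long[]` arrays in place of `Shape`. It assumes the common recurrent-layer convention: the first two input dimensions (batch and time, in the order given by `batchFirst`) are preserved, the feature dimension becomes `stateSize` times the number of directions, and an optional state of shape (`numLayers` times directions, batch, `stateSize`) is appended when `returnState` is set. The class `RecurrentShapeSketch` and its methods are hypothetical, and cells that carry more than one state tensor may differ.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of the library: derives shapes from the documented fields. */
final class RecurrentShapeSketch {

    /** Assumed meaning of getNumDirections(): 2 when bidirectional, otherwise 1. */
    static int numDirections(boolean bidirectional) {
        return bidirectional ? 2 : 1;
    }

    /**
     * Computes the assumed output shapes for a rank-3 input.
     * The input is (batch, time, features) when batchFirst is true, else (time, batch, features).
     */
    static long[][] outputShapes(
            long[] input,
            long stateSize,
            int numLayers,
            boolean batchFirst,
            boolean bidirectional,
            boolean returnState) {
        if (input.length != 3) {
            throw new IllegalArgumentException("expected a rank-3 input shape");
        }
        int directions = numDirections(bidirectional);
        // The leading two dimensions pass through; only the feature dimension changes.
        long[] output = {input[0], input[1], stateSize * directions};
        if (!returnState) {
            return new long[][] {output};
        }
        long batch = input[batchFirst ? 0 : 1];
        long[] state = {(long) numLayers * directions, batch, stateSize};
        return new long[][] {output, state};
    }

    public static void main(String[] args) {
        long[][] shapes = outputShapes(new long[] {32, 10, 128}, 64, 2, true, true, true);
        System.out.println(Arrays.deepToString(shapes)); // prints [[32, 10, 128], [4, 32, 64]]
    }
}
```

In the example, the output feature size of 128 comes from 64 state units times 2 directions and only coincidentally equals the input feature size.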
      • beforeInitialize

        protected void beforeInitialize(Shape... inputShapes)
        Performs any action necessary before initialization, for example storing the input information or verifying the layout.
        Overrides:
        beforeInitialize in class AbstractBaseBlock
        Parameters:
        inputShapes - the expected shapes of the input
      • loadMetadata

        public void loadMetadata(byte loadVersion,
                                 java.io.DataInputStream is)
                          throws java.io.IOException,
                                 MalformedModelException
        Override this to load additional metadata along with the parameter values.

        If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches and throws a MalformedModelException if it does not. It then restores the input shapes.

        Overrides:
        loadMetadata in class AbstractBaseBlock
        Parameters:
        loadVersion - the version used for loading this metadata.
        is - the input stream we are loading from
        Throws:
        java.io.IOException - if loading fails
        MalformedModelException - if the data can be loaded but has the wrong format
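The version check followed by restoring shapes can be illustrated with the standard `java.io` data streams alone. This is a hedged sketch of the general pattern, not the library's binary format: `MetadataSketch`, the `VERSION` constant, the on-disk layout (an `int` rank followed by `long` dimensions), the rank limit, and `BadFormatException` (a stand-in for `MalformedModelException`) are all assumptions made for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/** Hypothetical save/load pair showing a version check before metadata is restored. */
final class MetadataSketch {

    /** Assumed current format version; the real value is not given in this page. */
    static final byte VERSION = 2;

    /** Local stand-in for MalformedModelException. */
    static final class BadFormatException extends Exception {
        BadFormatException(String message) {
            super(message);
        }
    }

    /** Writes one shape as an int rank followed by its dimensions (assumed layout). */
    static byte[] save(long[] shape) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
        return bytes.toByteArray();
    }

    /** Rejects unknown versions first, then restores the shape written by save(). */
    static long[] load(byte loadVersion, DataInputStream is)
            throws IOException, BadFormatException {
        if (loadVersion != VERSION) {
            throw new BadFormatException("Unsupported encoding version: " + loadVersion);
        }
        int rank = is.readInt();
        if (rank < 0 || rank > 8) { // arbitrary sanity limit for this sketch
            throw new BadFormatException("Implausible shape rank: " + rank);
        }
        long[] shape = new long[rank];
        for (int i = 0; i < rank; i++) {
            shape[i] = is.readLong();
        }
        return shape;
    }

    public static void main(String[] args) throws Exception {
        byte[] data = save(new long[] {32, 10, 128});
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(data))) {
            System.out.println(Arrays.toString(load(VERSION, is))); // prints [32, 10, 128]
        }
    }
}
```

As in the documented signature, the version byte is passed in by the caller rather than read from the stream. An override that supports older formats would branch on `loadVersion` instead of rejecting every mismatch.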
      • getNumDirections

        protected int getNumDirections()