Package ai.djl.nn

Class AbstractBlock

    • Field Detail

      • inputShapes

        protected Shape[] inputShapes
        The shapes of the inputs to this block, set during initialization.
      • inputNames

        protected java.util.List<java.lang.String> inputNames
        List of names for the inputs. Named inputs should be set manually in subclasses.
      • version

        protected byte version
        The model version of this block, used for checking if parameters are still valid during parameter loading.
      • children

        protected BlockList children
        All direct children of this Block. Keys are names of the blocks.

        Use the addChildBlock(String, Block) method to add children. All children in this map are automatically loaded / saved.

      • parameters

        protected java.util.LinkedHashMap<java.lang.String,Parameter> parameters
        All direct parameters of this Block. Keys are names of the parameters.

        Use the addParameter(Parameter) method to add parameters. All parameters in this map are automatically loaded / saved.

    • Constructor Detail

      • AbstractBlock

        public AbstractBlock()
        Constructs a new AbstractBlock instance.
      • AbstractBlock

        public AbstractBlock(byte version)
        Builds an empty block with the given version for parameter serialization.
        Parameters:
        version - the version to use for parameter serialization.
    • Method Detail

      • forward

        public final NDList forward(ParameterStore parameterStore,
                                    NDList inputs,
                                    boolean training,
                                    ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
        Applies the operating function of the block once. This method should be called only on blocks that are initialized.
        Specified by:
        forward in interface Block
        Parameters:
        parameterStore - the parameter store
        inputs - the input NDList
        training - true for a training forward pass
        params - optional parameters
        Returns:
        the output of the forward pass
      • forward

        public NDList forward(ParameterStore parameterStore,
                              NDList data,
                              NDList labels,
                              ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
        A forward call using both training data and labels.

        Within this forward call, it can be assumed that training is true.

        Specified by:
        forward in interface Block
        Parameters:
        parameterStore - the parameter store
        data - the input data NDList
        labels - the input labels NDList
        params - optional parameters
        Returns:
        the output of the forward pass
        See Also:
        Block.forward(ParameterStore, NDList, boolean, PairList)
      • forwardInternal

        protected abstract NDList forwardInternal(ParameterStore parameterStore,
                                                  NDList inputs,
                                                  boolean training,
                                                  ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
        Parameters:
        parameterStore - the parameter store
        inputs - the input NDList
        training - true for a training forward pass
        params - optional parameters
        Returns:
        the output of the forward pass
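Because forward is final while forwardInternal is abstract with the same parameters, the class appears to follow the template-method pattern: the public entry point handles shared concerns and delegates the block-specific computation to the subclass. The Java sketch below illustrates that pattern under stated assumptions. MiniBlock, ScaleBlock, and ForwardDemo are hypothetical stand-ins, double[] replaces NDList, and ParameterStore and PairList are omitted. The initialization check mirrors the rule that forward should only be called on initialized blocks, and the labels overload simply runs a training pass, since training can be assumed to be true there. This is not DJL's actual implementation.

```java
import java.util.Arrays;

/** Hypothetical, simplified stand-in for AbstractBlock; double[] replaces NDList. */
abstract class MiniBlock {
    private boolean initialized;

    /** Stand-in for initialize(...): marks the block as ready for forward passes. */
    void initialize() {
        initialized = true;
    }

    boolean isInitialized() {
        return initialized;
    }

    /** Final entry point: shared checks live here, the computation lives in the subclass. */
    final double[] forward(double[] inputs, boolean training) {
        if (!initialized) {
            throw new IllegalStateException("Call initialize() before forward()");
        }
        return forwardInternal(inputs, training);
    }

    /** Labels overload: training is assumed true; labels are unused in this simplified sketch. */
    double[] forward(double[] data, double[] labels) {
        return forward(data, true);
    }

    /** Hook that each concrete block implements. */
    protected abstract double[] forwardInternal(double[] inputs, boolean training);
}

/** Example subclass: multiplies every input element by a fixed factor. */
class ScaleBlock extends MiniBlock {
    private final double factor;

    ScaleBlock(double factor) {
        this.factor = factor;
    }

    @Override
    protected double[] forwardInternal(double[] inputs, boolean training) {
        double[] out = new double[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            out[i] = inputs[i] * factor;
        }
        return out;
    }
}

class ForwardDemo {
    public static void main(String[] args) {
        ScaleBlock block = new ScaleBlock(2.0);
        block.initialize();
        System.out.println(Arrays.toString(block.forward(new double[] {1, 2, 3}, false)));
    }
}
```

Running ForwardDemo prints [2.0, 4.0, 6.0]. Calling forward before initialize throws an IllegalStateException in this sketch.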
      • addChildBlock

        protected final <B extends Block> B addChildBlock(java.lang.String name,
                                                          B block)
        Adds a child block to this block.
        Type Parameters:
        B - the type of the block
        Parameters:
        name - the name of the block; must not be null and must be unique, otherwise existing children with this name are removed.
        block - the block to add; must not be null.
        Returns:
        the block given as a parameter, so that the block can be created and assigned to a member variable in one statement.
      • addParameter

        protected final <P extends Parameter> P addParameter(P parameter)
        Adds a parameter to this block. If parameters are added with this method, initialization of the parameter works out of the box.
        Type Parameters:
        P - the specific parameter subclass
        Parameters:
        parameter - the parameter to add, not null
        Returns:
        the parameter passed as an argument, to make it easier to create and assign parameters in one line
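The following sketch shows why addChildBlock and addParameter return their argument: a subclass constructor can register a child or parameter and assign it to a field in a single statement. RegistryBlock, TwoLayerBlock, and Param are hypothetical stand-ins rather than DJL classes, and both registries are modeled as insertion-ordered LinkedHashMaps (the real children field is a BlockList). Treating a duplicate child name as "remove the old entry, then append the new one" is one reading of the documented behavior, not a confirmed detail.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/** Hypothetical stand-in for Parameter: only a name. */
class Param {
    final String name;

    Param(String name) {
        this.name = Objects.requireNonNull(name);
    }
}

/** Hypothetical stand-in for AbstractBlock's child and parameter bookkeeping. */
class RegistryBlock {
    // Insertion-ordered maps, so children and parameters keep a stable order.
    protected final Map<String, RegistryBlock> children = new LinkedHashMap<>();
    protected final Map<String, Param> parameters = new LinkedHashMap<>();

    /** Registers a child; an existing child with the same name is removed first. */
    protected final <B extends RegistryBlock> B addChildBlock(String name, B block) {
        Objects.requireNonNull(name, "name");
        Objects.requireNonNull(block, "block");
        children.remove(name); // so the new child is appended at the end
        children.put(name, block);
        return block; // returned so callers can register and assign in one statement
    }

    /** Registers a parameter under its own name and returns it. */
    protected final <P extends Param> P addParameter(P parameter) {
        Objects.requireNonNull(parameter, "parameter");
        parameters.put(parameter.name, parameter);
        return parameter;
    }

    List<String> childNames() {
        return new ArrayList<>(children.keySet());
    }
}

/** Example subclass: each registration and field assignment fits on one line. */
class TwoLayerBlock extends RegistryBlock {
    final RegistryBlock hidden;
    final RegistryBlock output;
    final Param bias;

    TwoLayerBlock() {
        hidden = addChildBlock("hidden", new RegistryBlock());
        output = addChildBlock("output", new RegistryBlock());
        bias = addParameter(new Param("bias"));
    }
}
```

Returning the argument keeps the field and the registry entry pointing at the same object, so the subclass never has to look its own children up by name.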
      • getChildren

        public BlockList getChildren()
        Returns a list of all the children of the block.
        Specified by:
        getChildren in interface Block
        Returns:
        the list of child blocks
      • describeInput

        public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
        Returns a PairList of input names and shapes.
        Specified by:
        describeInput in interface Block
        Returns:
        the PairList of input names and shapes
      • setInitializer

        public void setInitializer(Initializer initializer,
                                   Parameter.Type params)
        Sets an Initializer on all the parameters in the block that match the given parameter type.
        Specified by:
        setInitializer in interface Block
        Parameters:
        initializer - the initializer to set
        params - the Parameter.Type of the parameters whose initializer should be set
      • setInitializer

        public void setInitializer(Initializer initializer,
                                   java.lang.String paramName)
        Sets an Initializer on the specified direct parameter of the block, overriding the parameter's initializer if one is already set.
        Specified by:
        setInitializer in interface Block
        Parameters:
        initializer - the initializer to be set
        paramName - the name of the parameter
      • setInitializer

        public void setInitializer(Initializer initializer,
                                   java.util.function.Predicate<Parameter> predicate)
        Sets an Initializer on all the parameters in the block that match the given Predicate.
        Specified by:
        setInitializer in interface Block
        Parameters:
        initializer - the initializer to be set
        predicate - the predicate that selects the parameters to set
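The three setInitializer overloads can be read as variations of one operation: the type-based and name-based versions are special cases of the predicate-based one. The sketch below expresses them that way, assuming hypothetical stand-ins (ParamType for Parameter.Type, InitParam for Parameter, and a String label in place of an Initializer). It considers only the block's direct parameters, not those of child blocks.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

/** Hypothetical stand-in for Parameter.Type. */
enum ParamType { WEIGHT, BIAS }

/** Hypothetical stand-in for Parameter; the initializer is a plain String label here. */
class InitParam {
    final String name;
    final ParamType type;
    String initializer; // stand-in for an Initializer instance

    InitParam(String name, ParamType type) {
        this.name = name;
        this.type = type;
    }
}

/** Hypothetical block holding only direct parameters. */
class InitBlock {
    final List<InitParam> parameters = new ArrayList<>();

    /** General form: set the initializer on every parameter the predicate accepts. */
    void setInitializer(String initializer, Predicate<InitParam> predicate) {
        for (InitParam p : parameters) {
            if (predicate.test(p)) {
                p.initializer = initializer;
            }
        }
    }

    /** By type: delegates to the predicate form. */
    void setInitializer(String initializer, ParamType type) {
        setInitializer(initializer, p -> p.type == type);
    }

    /** By name: overrides any initializer already set on that parameter. */
    void setInitializer(String initializer, String paramName) {
        setInitializer(initializer, p -> p.name.equals(paramName));
    }
}
```

Because later calls overwrite earlier ones, a common pattern is to set a broad initializer by type first and then override individual parameters by name.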
      • initialize

        public void initialize(NDManager manager,
                               DataType dataType,
                               Shape... inputShapes)
        Initializes the parameters of the block. This method must be called before calling forward.
        Specified by:
        initialize in interface Block
        Parameters:
        manager - the NDManager used to initialize the parameters
        dataType - the data type of the parameters
        inputShapes - the shapes of the inputs to the block
      • beforeInitialize

        protected void beforeInitialize(Shape... inputShapes)
        Performs any action necessary before initialization, for example storing the input information or verifying the layout.
        Parameters:
        inputShapes - the expected shapes of the input
      • initializeChildBlocks

        protected void initializeChildBlocks(NDManager manager,
                                             DataType dataType,
                                             Shape... inputShapes)
        Initializes the child blocks of this block. Override this method if your subclass has child blocks. It is used to determine the correct input shapes for the child blocks based on the requested input shapes for this block.
        Parameters:
        manager - the manager to use for initialization
        dataType - the requested data type
        inputShapes - the expected input shapes for this block
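The following sketch illustrates how initializeChildBlocks can derive each child's input shape from the shape requested for the parent, here for a block that runs its children in sequence. ShapeBlock, LinearBlock, and ChainBlock are hypothetical; shapes are plain long[] arrays instead of Shape, the NDManager and DataType arguments are dropped, and the order beforeInitialize, prepare, initializeChildBlocks inside initialize is an assumption, not DJL's documented call sequence.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a block; shapes are long[] instead of Shape. */
abstract class ShapeBlock {
    long[] inputShape;
    boolean initialized;

    /** Assumed order (not confirmed by the docs): store inputs, size parameters, init children. */
    final void initialize(long[] inputShape) {
        beforeInitialize(inputShape);
        prepare(inputShape);
        initializeChildBlocks(inputShape);
        initialized = true;
    }

    protected void beforeInitialize(long[] inputShape) {
        this.inputShape = inputShape.clone();
    }

    /** Sizes this block's own parameters; no-op by default. */
    protected void prepare(long[] inputShape) {}

    /** Blocks without children have nothing to do here. */
    protected void initializeChildBlocks(long[] inputShape) {}

    /** Output shape for a given input shape; needed to chain children. */
    abstract long[] getOutputShape(long[] inputShape);
}

/** Fully connected layer: (batch, in) -> (batch, units); this sketch stores weights as (units, in). */
class LinearBlock extends ShapeBlock {
    final long units;
    long[] weightShape;

    LinearBlock(long units) {
        this.units = units;
    }

    @Override
    protected void prepare(long[] inputShape) {
        weightShape = new long[] {units, inputShape[1]};
    }

    @Override
    long[] getOutputShape(long[] inputShape) {
        return new long[] {inputShape[0], units};
    }
}

/** Runs children in order, so each child's input shape is the previous child's output shape. */
class ChainBlock extends ShapeBlock {
    final List<ShapeBlock> children = new ArrayList<>();

    ChainBlock add(ShapeBlock child) {
        children.add(child);
        return this;
    }

    @Override
    protected void initializeChildBlocks(long[] inputShape) {
        long[] shape = inputShape;
        for (ShapeBlock child : children) {
            child.initialize(shape);
            shape = child.getOutputShape(shape);
        }
    }

    @Override
    long[] getOutputShape(long[] inputShape) {
        long[] shape = inputShape;
        for (ShapeBlock child : children) {
            shape = child.getOutputShape(shape);
        }
        return shape;
    }
}
```

For an input of shape (32, 784) and two LinearBlocks with 64 and 10 units, the second child is initialized with input shape (32, 64), which is the output shape of the first.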
      • getParameters

        public ParameterList getParameters()
        Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
        Specified by:
        getParameters in interface Block
        Returns:
        the list of all parameters of the block
      • prepare

        protected void prepare(Shape[] inputShapes)
        Sets the shapes of the parameters.
        Parameters:
        inputShapes - the shapes of inputs
      • isInitialized

        public boolean isInitialized()
        Returns whether the block is initialized.
        Specified by:
        isInitialized in interface Block
        Returns:
        whether the block is initialized
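getParameters collects parameters recursively, and one natural way to answer isInitialized is to check every parameter in that collection. The sketch below assumes exactly that; TreeBlock and TrackedParam are hypothetical stand-ins, and both the ordering (direct parameters first, then children depth first) and this definition of isInitialized are assumptions rather than documented behavior.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for Parameter: tracks only whether it has been initialized. */
class TrackedParam {
    final String name;
    boolean initialized;

    TrackedParam(String name) {
        this.name = name;
    }
}

/** Hypothetical block with direct parameters and named children. */
class TreeBlock {
    final Map<String, TrackedParam> parameters = new LinkedHashMap<>();
    final Map<String, TreeBlock> children = new LinkedHashMap<>();

    /** Direct parameters first, then each child's parameters, depth first. */
    List<TrackedParam> getParameters() {
        List<TrackedParam> all = new ArrayList<>(parameters.values());
        for (TreeBlock child : children.values()) {
            all.addAll(child.getParameters());
        }
        return all;
    }

    /** Initialized only if every parameter in the whole subtree is initialized. */
    boolean isInitialized() {
        for (TrackedParam p : getParameters()) {
            if (!p.initialized) {
                return false;
            }
        }
        return true;
    }
}
```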
      • clear

        public void clear()
        Closes all the parameters of the block. All the updates made during training will be lost.
        Specified by:
        clear in interface Block
      • cast

        public void cast(DataType dataType)
        Guaranteed to throw an exception. Not yet implemented.
        Specified by:
        cast in interface Block
        Parameters:
        dataType - the data type to cast to
      • saveParameters

        public void saveParameters(java.io.DataOutputStream os)
                            throws java.io.IOException
        Writes the parameters of the block to the given output stream.
        Specified by:
        saveParameters in interface Block
        Parameters:
        os - the output stream to save the parameters to
        Throws:
        java.io.IOException - if an I/O error occurs
      • loadParameters

        public void loadParameters(NDManager manager,
                                   java.io.DataInputStream is)
                            throws java.io.IOException,
                                   MalformedModelException
        Loads the parameters from the given input stream.
        Specified by:
        loadParameters in interface Block
        Parameters:
        manager - an NDManager to create the parameter arrays
        is - the input stream that streams the parameter values
        Throws:
        java.io.IOException - if an I/O error occurs
        MalformedModelException - if the model file is corrupted or unsupported
      • saveMetadata

        protected void saveMetadata(java.io.DataOutputStream os)
                             throws java.io.IOException
        Override this method to save additional data apart from parameter values.

        This default implementation saves the currently set input shapes.

        Parameters:
        os - the non-null output stream the parameter values and metadata are written to
        Throws:
        java.io.IOException - saving failed
      • loadMetadata

        protected void loadMetadata(byte loadVersion,
                                    java.io.DataInputStream is)
                             throws java.io.IOException,
                                    MalformedModelException
        Override this method to load additional metadata with the parameter values.

        If you override saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. This default implementation checks whether the version number matches and, if it does not, throws a MalformedModelException. After that it restores the input shapes.

        Parameters:
        loadVersion - the version used for loading this metadata.
        is - the input stream we are loading from
        Throws:
        java.io.IOException - loading failed
        MalformedModelException - data can be loaded but has wrong format
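The version check described for loadMetadata can be sketched with java.io data streams: the writer records a version byte ahead of the metadata, and the reader rejects a version it does not expect before restoring the input shapes. VersionedBlock, its save and load methods, the byte layout, and BadModelException (standing in for MalformedModelException) are all hypothetical; this is not DJL's file format, and parameter values are omitted.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical stand-in for MalformedModelException. */
class BadModelException extends Exception {
    BadModelException(String message) {
        super(message);
    }
}

/** Hypothetical block that persists its version byte and input shapes. */
class VersionedBlock {
    final byte version;
    long[][] inputShapes = new long[0][];

    VersionedBlock(byte version) {
        this.version = version;
    }

    /** Writes the version byte first, then the metadata (parameter values omitted). */
    void save(DataOutputStream os) throws IOException {
        os.writeByte(version);
        saveMetadata(os);
        os.flush();
    }

    /** Reads the version byte and hands it to loadMetadata for checking. */
    void load(DataInputStream is) throws IOException, BadModelException {
        byte loadVersion = is.readByte();
        loadMetadata(loadVersion, is);
    }

    /** Hypothetical layout: shape count, then for each shape its rank followed by its dimensions. */
    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeInt(inputShapes.length);
        for (long[] shape : inputShapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    /** Rejects an unexpected version, then restores the input shapes. */
    protected void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, BadModelException {
        if (loadVersion != version) {
            throw new BadModelException(
                    "Unsupported version " + loadVersion + ", expected " + version);
        }
        int count = is.readInt();
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            long[] shape = new long[is.readInt()];
            for (int j = 0; j < shape.length; j++) {
                shape[j] = is.readLong();
            }
            shapes[i] = shape;
        }
        inputShapes = shapes;
    }
}
```

A subclass that needs backward compatibility could override loadMetadata to accept an older version byte and read the older layout instead of throwing.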
      • saveInputShapes

        protected void saveInputShapes(java.io.DataOutputStream os)
                                throws java.io.IOException
        Throws:
        java.io.IOException
      • readInputShapes

        protected void readInputShapes(java.io.DataInputStream is)
                                throws java.io.IOException
        Throws:
        java.io.IOException
      • toString

        public java.lang.String toString()
        Overrides:
        toString in class java.lang.Object