public class Linear extends AbstractBlock
It has the following shapes:
- input X: [x1, x2, …, xn, input_dim]
- weight W: [units, input_dim]
- bias b: [units]
The Linear block should be constructed using Linear.Builder.
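As a rough illustration of how the shapes relate, the following plain-Java sketch computes an output shape from an input shape and a number of units. It assumes the usual linear-layer convention that the last input dimension (input_dim) is replaced by units. The class LinearShapeSketch and its outputShape method are hypothetical helpers written for this example and are not part of DJL.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of DJL.
// Shapes are modelled as plain long[] arrays instead of ai.djl.ndarray.types.Shape.
final class LinearShapeSketch {

    // Assumption: the last dimension of the input is input_dim and is
    // replaced by units; all leading dimensions are kept unchanged.
    static long[] outputShape(long[] inputShape, long units) {
        if (inputShape.length == 0) {
            throw new IllegalArgumentException("input needs at least one dimension");
        }
        long[] out = Arrays.copyOf(inputShape, inputShape.length);
        out[out.length - 1] = units;
        return out;
    }

    public static void main(String[] args) {
        // Input [32, 10, 128] with 64 units.
        long[] out = outputShape(new long[] {32, 10, 128}, 64);
        System.out.println(Arrays.toString(out)); // prints [32, 10, 64]
    }
}
```

In an actual DJL program the equivalent question would be asked of the block itself through getOutputShapes; the sketch only mirrors the shape arithmetic.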
Modifier and Type | Class and Description |
---|---|
static class | Linear.Builder |
Fields inherited from class AbstractBlock:
children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version
Modifier and Type | Method and Description |
---|---|
void | beforeInitialize(Shape[] inputShapes): Performs any action necessary before initialization. |
static Linear.Builder | builder(): Creates a builder to build a Linear. |
ai.djl.util.PairList<java.lang.String,Shape> | describeInput(): Returns a PairList of input names and shapes. |
NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): Applies the operating function of the block once. |
Shape[] | getOutputShapes(NDManager manager, Shape[] inputs): Returns the expected output shapes of the block for the specified input shapes. |
static NDList | linear(NDArray input, NDArray weight): Applies a linear transformation to the incoming data. |
static NDList | linear(NDArray input, NDArray weight, NDArray bias): Applies a linear transformation to the incoming data. |
void | loadMetadata(byte version, java.io.DataInputStream is): Override this to load additional metadata with the parameter values. |
protected void | saveMetadata(java.io.DataOutputStream os): Override this method to save additional data apart from parameter values. |
Methods inherited from class AbstractBlock:
addChildBlock, addParameter, addParameter, addParameter, cast, clear, getChildren, getDirectParameters, getParameters, getParameterShape, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveParameters, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block:
forward, forward, validateLayout
public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs)
Parameters:
manager - an NDManager
inputs - the shapes of the inputs
public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns a PairList of input names and shapes.
Specified by: describeInput in interface Block
Overrides: describeInput in class AbstractBlock
Returns: PairList of input names and shapes
public void beforeInitialize(Shape[] inputShapes)
Overrides: beforeInitialize in class AbstractBlock
Parameters:
inputShapes - the expected shapes of the input
protected void saveMetadata(java.io.DataOutputStream os) throws java.io.IOException
This default implementation saves the currently set input shapes.
Overrides: saveMetadata in class AbstractBlock
Parameters:
os - the non-null output stream the parameter values and metadata are written to
Throws:
java.io.IOException - saving failed
public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method. This default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. After that it restores the input shapes.
Overrides: loadMetadata in class AbstractBlock
Parameters:
version - the version used for loading this metadata
is - the input stream we are loading from
Throws:
java.io.IOException - loading failed
MalformedModelException - data can be loaded but has wrong format
public static NDList linear(NDArray input, NDArray weight)
Parameters:
input - input X: [x1, x2, …, xn, input_dim]
weight - weight W: [units, input_dim]
public static NDList linear(NDArray input, NDArray weight, NDArray bias)
Parameters:
input - input X: [x1, x2, …, xn, input_dim]
weight - weight W: [units, input_dim]
bias - bias b: [units]
public static Linear.Builder builder()
Creates a builder to build a Linear.
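To make the shapes used by the static linear methods concrete, here is a minimal plain-Java sketch of a linear transformation on a two-dimensional input. It assumes the conventional formula Y = X·W^T + b, which is consistent with the weight shape [units, input_dim] and bias shape [units] documented above but is not spelled out on this page. LinearTransformSketch is a hypothetical class that works on double arrays and does not use NDArray or any DJL call.

```java
import java.util.Arrays;

// Hypothetical illustration only; this is not DJL's implementation.
final class LinearTransformSketch {

    // input:  [batch][inputDim]
    // weight: [units][inputDim]
    // bias:   [units], or null for the variant without bias
    // result: [batch][units], computed as input * weight^T (+ bias)
    static double[][] linear(double[][] input, double[][] weight, double[] bias) {
        int units = weight.length;
        if (bias != null && bias.length != units) {
            throw new IllegalArgumentException("bias length must equal units");
        }
        double[][] out = new double[input.length][units];
        for (int i = 0; i < input.length; i++) {
            for (int u = 0; u < units; u++) {
                if (weight[u].length != input[i].length) {
                    throw new IllegalArgumentException("input_dim mismatch");
                }
                double sum = (bias == null) ? 0.0 : bias[u];
                for (int k = 0; k < input[i].length; k++) {
                    sum += input[i][k] * weight[u][k];
                }
                out[i][u] = sum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}};            // one row, input_dim = 3
        double[][] w = {{1, 0, 0}, {0, 1, 1}}; // units = 2
        double[] b = {0.5, -1};
        double[][] y = linear(x, w, b);
        System.out.println(Arrays.toString(y[0])); // prints [1.5, 4.0]
    }
}
```

Passing null as the bias in this sketch plays the role of the two-argument linear(input, weight) overload; inputs with more than two dimensions would apply the same computation over the last dimension.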