public class Linear extends AbstractBlock
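A Linear block applies a linear transformation Y = XW^T + b to the incoming data. It has the following shapes: input X: [x1, x2, …, xn, input_dim]; weight W: [units, input_dim]; bias b: [units]; output Y: [x1, x2, …, xn, units].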
The Linear block should be constructed using Linear.Builder.
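The shape convention above can be checked without DJL. The plain-Java sketch below uses the layout documented for the static linear methods (input X: [x1, x2, …, xn, input_dim], weight W: [units, input_dim]) to compute an output shape, roughly what getOutputShapes is documented to return. LinearShapes and outputShape are hypothetical names, and the rule that only the last axis changes (to units) is an assumption derived from those documented shapes, not DJL code.

```java
import java.util.Arrays;

// Hypothetical helper, not part of DJL. Assumes the documented layout:
// input X: [x1, ..., xn, input_dim], weight W: [units, input_dim]
// => output Y: [x1, ..., xn, units].
final class LinearShapes {
    private LinearShapes() {}

    static long[] outputShape(long[] inputShape, long[] weightShape) {
        if (inputShape.length < 1) {
            throw new IllegalArgumentException("input must have at least one axis");
        }
        if (weightShape.length != 2) {
            throw new IllegalArgumentException("weight must have shape [units, input_dim]");
        }
        long inputDim = inputShape[inputShape.length - 1];
        if (weightShape[1] != inputDim) {
            throw new IllegalArgumentException(
                    "input_dim mismatch: input has " + inputDim + ", weight expects " + weightShape[1]);
        }
        // Leading axes are kept; the last axis becomes units.
        long[] out = Arrays.copyOf(inputShape, inputShape.length);
        out[out.length - 1] = weightShape[0];
        return out;
    }

    public static void main(String[] args) {
        long[] y = outputShape(new long[] {32, 10, 128}, new long[] {64, 128});
        System.out.println(Arrays.toString(y)); // [32, 10, 64]
    }
}
```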
| Modifier and Type | Class and Description |
|---|---|
| static class | Linear.Builder |
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version

| Modifier and Type | Method and Description |
|---|---|
| void | beforeInitialize(Shape[] inputShapes): Performs any action necessary before initialization. |
| static Linear.Builder | builder(): Creates a builder to build a Linear. |
| ai.djl.util.PairList<java.lang.String,Shape> | describeInput(): Returns a PairList of input names and shapes. |
| protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) |
| Shape[] | getOutputShapes(NDManager manager, Shape[] inputs): Returns the expected output shapes of the block for the specified input shapes. |
| static NDList | linear(NDArray input, NDArray weight): Applies a linear transformation to the incoming data, without a bias term. |
| static NDList | linear(NDArray input, NDArray weight, NDArray bias): Applies a linear transformation to the incoming data, adding the bias. |
| void | loadMetadata(byte version, java.io.DataInputStream is): Override this method to load additional metadata along with the parameter values. |
| protected void | saveMetadata(java.io.DataOutputStream os): Override this method to save additional data apart from parameter values. |
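The two linear overloads differ only in whether a bias is added. As a rough illustration of the arithmetic, assuming the documented shapes and the convention Y = X W^T + b (the transpose follows from W being stored as [units, input_dim]), here is a dependency-free Java sketch restricted to a 2-D input. LinearMath is a hypothetical name, and passing null for the bias stands in for the two-argument overload; DJL itself performs this on NDArrays, not with these loops.

```java
import java.util.Arrays;

// Hypothetical, dependency-free illustration of Y = X W^T + b for a 2-D input.
// Not DJL code: DJL's static linear(...) methods operate on NDArrays instead.
final class LinearMath {
    private LinearMath() {}

    /**
     * x: [batch][inputDim], w: [units][inputDim], b: [units], or null to skip the bias
     * (the analogue of the two-argument overload).
     */
    static double[][] linear(double[][] x, double[][] w, double[] b) {
        int units = w.length;
        if (b != null && b.length != units) {
            throw new IllegalArgumentException("bias must have shape [units]");
        }
        double[][] y = new double[x.length][units];
        for (int i = 0; i < x.length; i++) {
            for (int u = 0; u < units; u++) {
                if (w[u].length != x[i].length) {
                    throw new IllegalArgumentException("weight must have shape [units, input_dim]");
                }
                double sum = (b == null) ? 0.0 : b[u];
                for (int k = 0; k < x[i].length; k++) {
                    // Row u of W is column u of W^T, so no explicit transpose is needed.
                    sum += x[i][k] * w[u][k];
                }
                y[i][u] = sum;
            }
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}};            // batch = 1, input_dim = 3
        double[][] w = {{1, 0, 0}, {0, 1, 1}}; // units = 2
        double[] b = {0.5, -1};
        System.out.println(Arrays.toString(linear(x, w, b)[0])); // [1.5, 4.0]
    }
}
```

Flattening the leading axes [x1, …, xn] into a single batch axis reduces the general case to this 2-D one.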
Methods inherited from class AbstractBlock: addChildBlock, addParameter, addParameter, addParameter, cast, clear, forward, getChildren, getDirectParameters, getParameters, getParameterShape, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveParameters, setInitializer, setInitializer, toString

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface Block: forward, forward, validateLayout

protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)

- Specified by: forwardInternal in class AbstractBlock

public Shape[] getOutputShapes(NDManager manager, Shape[] inputs)
- Parameters: manager - an NDManager; inputs - the shapes of the inputs

public ai.djl.util.PairList<java.lang.String,Shape> describeInput()

Returns a PairList of input names and shapes.

- Specified by: describeInput in interface Block
- Overrides: describeInput in class AbstractBlock
- Returns: PairList of input names and shapes

public void beforeInitialize(Shape[] inputShapes)
- Overrides: beforeInitialize in class AbstractBlock
- Parameters: inputShapes - the expected shapes of the input

protected void saveMetadata(java.io.DataOutputStream os) throws java.io.IOException

This default implementation saves the currently set input shapes.

- Overrides: saveMetadata in class AbstractBlock
- Parameters: os - the non-null output stream the parameter values and metadata are written to
- Throws: java.io.IOException - saving failed

public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. This default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. After that, it restores the input shapes.

- Overrides: loadMetadata in class AbstractBlock
- Parameters: version - the version used for loading this metadata; is - the input stream we are loading from
- Throws: java.io.IOException - loading failed; MalformedModelException - data can be loaded but has the wrong format

public static NDList linear(NDArray input, NDArray weight)
- Parameters: input - input X: [x1, x2, …, xn, input_dim]; weight - weight W: [units, input_dim]

public static NDList linear(NDArray input, NDArray weight, NDArray bias)

- Parameters: input - input X: [x1, x2, …, xn, input_dim]; weight - weight W: [units, input_dim]; bias - bias b: [units]

public static Linear.Builder builder()

Creates a builder to build a Linear.
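The saveMetadata/loadMetadata contract described above (write extra data alongside the parameters; on load, reject a version that does not match, then restore the input shapes) can be sketched with plain java.io streams. Everything here is a hypothetical illustration, not DJL code: VersionedMetadata, MalformedMetadataException (a stand-in for MalformedModelException), and the [rank, dims...] encoding are assumptions. As with the documented loadMetadata(byte version, …) signature, the caller reads the version byte before calling the loader.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

// Hypothetical, dependency-free sketch of a versioned save/load pair.
// Not DJL code; names and the [rank, dims...] encoding are assumptions.
final class VersionedMetadata {
    static final byte VERSION = 1;

    /** Stand-in for MalformedModelException in this sketch. */
    static final class MalformedMetadataException extends Exception {
        MalformedMetadataException(String message) {
            super(message);
        }
    }

    private VersionedMetadata() {}

    /** Writes the format version followed by the input shape. */
    static void saveMetadata(DataOutputStream os, long[] inputShape) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(inputShape.length);
        for (long dim : inputShape) {
            os.writeLong(dim);
        }
    }

    /** Rejects an unknown version, then restores the input shape. */
    static long[] loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedMetadataException {
        if (version != VERSION) {
            throw new MalformedMetadataException("Unsupported metadata version: " + version);
        }
        int rank = is.readInt();
        long[] shape = new long[rank];
        for (int i = 0; i < rank; i++) {
            shape[i] = is.readLong();
        }
        return shape;
    }

    /** Saves and reloads a shape in memory; checked exceptions are rethrown unchecked. */
    static long[] roundTrip(long[] inputShape) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(bytes)) {
                saveMetadata(os, inputShape);
            }
            try (DataInputStream is =
                    new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                byte version = is.readByte(); // the caller reads the version first
                return loadMetadata(version, is);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (MalformedMetadataException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(roundTrip(new long[] {32, 128}))); // [32, 128]
    }
}
```

A loader that must also read older files could branch on the version instead of rejecting everything except the current one, which is the backward-compatibility case the loadMetadata description mentions.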