Class KerasEmbedding
- java.lang.Object
  - org.deeplearning4j.nn.modelimport.keras.KerasLayer
    - org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding
public class KerasEmbedding extends KerasLayer
Imports an Embedding layer from Keras.
- Author:
- [email protected]
Nested Class Summary
Nested classes/interfaces inherited from class org.deeplearning4j.nn.modelimport.keras.KerasLayer
KerasLayer.DimOrder
Field Summary
Fields inherited from class org.deeplearning4j.nn.modelimport.keras.KerasLayer
className, conf, dimOrder, dropout, inboundLayerNames, inputShape, kerasMajorVersion, layer, layerName, originalLayerConfig, outboundLayerNames, vertex, weightL1Regularization, weightL2Regularization, weights
Constructor Summary
Constructors
KerasEmbedding()
    Pass-through constructor for unit tests.
KerasEmbedding(Map<String,Object> layerConfig)
    Constructor from parsed Keras layer configuration dictionary.
KerasEmbedding(Map<String,Object> layerConfig, boolean enforceTrainingConfig)
    Constructor from parsed Keras layer configuration dictionary.
Method Summary
Methods
EmbeddingSequenceLayer getEmbeddingLayer()
    Get DL4J Embedding Sequence layer.
int getNumParams()
    Returns number of trainable parameters in layer.
InputType getOutputType(InputType... inputType)
    Get layer output type.
void setWeights(Map<String,INDArray> weights)
    Set weights for layer.
Methods inherited from class org.deeplearning4j.nn.modelimport.keras.KerasLayer
clearCustomLayers, clearLambdaLayers, copyWeightsToLayer, getClassName, getDimOrder, getInboundLayerNames, getInputPreprocessor, getInputShape, getKerasMajorVersion, getLayer, getLayerName, getNInFromConfig, getVertex, getWeights, isInputPreProcessor, isLayer, isValidInboundLayer, isVertex, registerCustomLayer, registerLambdaLayer, setInboundLayerNames, setLayer, usesRegularization
Constructor Detail
KerasEmbedding
public KerasEmbedding() throws UnsupportedKerasConfigurationException
Pass-through constructor for unit tests.
- Throws:
    UnsupportedKerasConfigurationException - Unsupported Keras config
KerasEmbedding
public KerasEmbedding(Map<String,Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Constructor from parsed Keras layer configuration dictionary.
- Parameters:
    layerConfig - dictionary containing Keras layer configuration
- Throws:
    InvalidKerasConfigurationException - Invalid Keras config
    UnsupportedKerasConfigurationException - Unsupported Keras config
KerasEmbedding
public KerasEmbedding(Map<String,Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException
Constructor from parsed Keras layer configuration dictionary.
- Parameters:
    layerConfig - dictionary containing Keras layer configuration
    enforceTrainingConfig - whether to enforce training-related configuration options
- Throws:
    InvalidKerasConfigurationException - Invalid Keras config
    UnsupportedKerasConfigurationException - Unsupported Keras config
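To make the "parsed Keras layer configuration dictionary" more concrete, here is a minimal sketch of what such a map might look like for an Embedding layer. The field names (`class_name`, `config`, `input_dim`, `output_dim`, `input_length`, `mask_zero`) follow the JSON that Keras writes for this layer type. Which of them KerasEmbedding actually reads is an assumption here, and `EmbeddingConfigSketch`, `exampleLayerConfig`, and `intField` are hypothetical names used only for illustration. The sketch uses plain `java.util` types and does not call DL4J.

```java
import java.util.HashMap;
import java.util.Map;

public class EmbeddingConfigSketch {

    /** Builds a nested map mirroring the JSON Keras writes for an Embedding layer. */
    public static Map<String, Object> exampleLayerConfig() {
        Map<String, Object> inner = new HashMap<>();
        inner.put("name", "embedding_1");
        inner.put("input_dim", 1000);   // vocabulary size
        inner.put("output_dim", 64);    // length of each embedding vector
        inner.put("input_length", 10);  // number of tokens per sequence
        inner.put("mask_zero", false);

        Map<String, Object> layerConfig = new HashMap<>();
        layerConfig.put("class_name", "Embedding");
        layerConfig.put("config", inner);
        return layerConfig;
    }

    /** Reads an integer field from the nested "config" map (hypothetical helper). */
    @SuppressWarnings("unchecked")
    public static int intField(Map<String, Object> layerConfig, String key) {
        Map<String, Object> inner = (Map<String, Object>) layerConfig.get("config");
        return ((Number) inner.get(key)).intValue();
    }

    public static void main(String[] args) {
        Map<String, Object> cfg = exampleLayerConfig();
        System.out.println(cfg.get("class_name") + ": "
                + intField(cfg, "input_dim") + " x " + intField(cfg, "output_dim"));
    }
}
```

A map of this shape is what would be passed as `layerConfig`. In practice it normally comes from parsing a saved Keras model rather than being built by hand.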
Method Detail
getEmbeddingLayer
public EmbeddingSequenceLayer getEmbeddingLayer()
Get DL4J Embedding Sequence layer.
- Returns:
    Embedding Sequence layer
getOutputType
public InputType getOutputType(InputType... inputType) throws InvalidKerasConfigurationException
Get layer output type.
- Overrides:
    getOutputType in class KerasLayer
- Parameters:
    inputType - Array of InputTypes
- Returns:
    output type as InputType
- Throws:
    InvalidKerasConfigurationException - Invalid Keras config
getNumParams
public int getNumParams()
Returns number of trainable parameters in layer.
- Overrides:
    getNumParams in class KerasLayer
- Returns:
    number of trainable parameters (1)
setWeights
public void setWeights(Map<String,INDArray> weights) throws InvalidKerasConfigurationException
Set weights for layer.
- Overrides:
    setWeights in class KerasLayer
- Parameters:
    weights - Embedding layer weights
- Throws:
    InvalidKerasConfigurationException
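As a rough illustration of what the embedding layer weights represent, the sketch below treats them as a lookup table with one row per vocabulary index. It assumes the Keras layout, in which the table has shape [input_dim, output_dim]. Plain `float[][]` arrays stand in for `INDArray`, and `EmbeddingLookupSketch` and `lookup` are hypothetical names, not part of DL4J.

```java
import java.util.Arrays;

public class EmbeddingLookupSketch {

    /** Maps each token index to a copy of the matching row of the weight table. */
    public static float[][] lookup(float[][] table, int[] indices) {
        float[][] out = new float[indices.length][];
        for (int t = 0; t < indices.length; t++) {
            int idx = indices[t];
            if (idx < 0 || idx >= table.length) {
                throw new IllegalArgumentException("index out of vocabulary: " + idx);
            }
            out[t] = Arrays.copyOf(table[idx], table[idx].length);
        }
        return out;
    }

    public static void main(String[] args) {
        // Toy table: vocabulary of 3 tokens, embedding size 2.
        float[][] table = { {0.0f, 0.1f}, {1.0f, 1.1f}, {2.0f, 2.1f} };
        float[][] embedded = lookup(table, new int[] {2, 0, 2});
        System.out.println(Arrays.deepToString(embedded));
    }
}
```

A sequence of token indices becomes a sequence of vectors, which matches the single weight array reported by getNumParams.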