Class ResNetV1
- java.lang.Object
  - ai.djl.basicmodelzoo.cv.classification.ResNetV1
public final class ResNetV1 extends java.lang.Object
ResNetV1 contains a generic implementation of ResNet, adapted from https://github.com/tornadomeet/ResNet/blob/master/symbol_resnet.py (original author Wei Wu) by Antti-Pekka Hynninen. It implements the original ResNet, the ILSVRC 2015 winning network, from Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, "Deep Residual Learning for Image Recognition".
- See Also:
- The D2L chapter on ResNet
Nested Class Summary
| Modifier and Type | Class | Description |
| --- | --- | --- |
| static class | ResNetV1.Builder | The Builder to construct a ResNetV1 object. |
Method Summary
| Modifier and Type | Method | Description |
| --- | --- | --- |
| static ResNetV1.Builder | builder() | Creates a builder to build a ResNetV1. |
| static ai.djl.nn.Block | residualUnit(int numFilters, ai.djl.ndarray.types.Shape stride, boolean dimMatch, boolean bottleneck, float batchNormMomentum) | Builds a Block that represents a residual unit used in the implementation of the ResNet model. |
| static ai.djl.nn.SequentialBlock | resnet(ResNetV1.Builder builder) | |
Method Detail
residualUnit
public static ai.djl.nn.Block residualUnit(int numFilters, ai.djl.ndarray.types.Shape stride, boolean dimMatch, boolean bottleneck, float batchNormMomentum)
Builds a Block that represents a residual unit used in the implementation of the ResNet model.
- Parameters:
- numFilters - the number of output channels
- stride - the stride of the convolution in each dimension
- dimMatch - whether the number of channels between input and output has to remain the same
- bottleneck - whether to use bottleneck architecture
- batchNormMomentum - the momentum to be used for BatchNorm
- Returns:
- a Block that represents a residual unit
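To make the idea of a residual unit concrete, the following is a minimal plain-Java sketch of the computation described in "Deep Residual Learning for Image Recognition": the output is the sum of a transformed branch and a shortcut branch, followed by ReLU. It is not DJL's implementation. The class `ResidualSketch`, its methods, and the use of `float[]` in place of tensors are hypothetical names chosen for illustration, and the mapping of an identity shortcut to `dimMatch = true` is an assumption based on the parameter description above.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

/** Hypothetical illustration of a residual connection; not part of DJL. */
final class ResidualSketch {

    private ResidualSketch() {}

    /** Computes relu(body(x) + shortcut(x)) element-wise. */
    static float[] residualUnit(
            float[] x, UnaryOperator<float[]> body, UnaryOperator<float[]> shortcut) {
        float[] main = body.apply(x);
        float[] skip = shortcut.apply(x);
        if (main.length != skip.length) {
            // Both branches must produce the same shape before they can be added.
            throw new IllegalArgumentException("Branch outputs differ in length");
        }
        float[] out = new float[main.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = Math.max(0f, main[i] + skip[i]);
        }
        return out;
    }

    /** Stand-in for the convolutional branch: scales every element by a constant. */
    static float[] scale(float[] x, float factor) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = x[i] * factor;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] x = {1f, -2f, 3f};
        // Assumption: an identity shortcut corresponds to the case where
        // input and output dimensions already match.
        float[] y = residualUnit(x, v -> scale(v, 0.5f), UnaryOperator.identity());
        System.out.println(Arrays.toString(y)); // prints [1.5, 0.0, 4.5]
    }
}
```

When the two branches produce different shapes, the shortcut would need its own transformation (in the paper, a projection) so that the addition is well defined; the length check above stands in for that constraint.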
resnet
public static ai.djl.nn.SequentialBlock resnet(ResNetV1.Builder builder)
- Parameters:
- builder - the ResNetV1.Builder with the necessary arguments
- Returns:
- a Block that represents the required ResNet model
builder
public static ResNetV1.Builder builder()
Creates a builder to build a ResNetV1.
- Returns:
- a new builder