Package ai.djl.ndarray.types
Class Shape
java.lang.Object
  ai.djl.ndarray.types.Shape

public class Shape extends java.lang.Object
A class that represents the NDArray's shape information.

Constructor Summary
- Shape(long... shape): Constructs and initializes a Shape with the specified dimensions.
- Shape(long[] shape, LayoutType[] layout): Constructs and initializes a Shape with the specified dimensions and layout.
- Shape(long[] shape, java.lang.String layout): Constructs and initializes a Shape with the specified dimensions and layout.
- Shape(ai.djl.util.PairList<java.lang.Long,LayoutType> shape): Constructs and initializes a Shape with the specified shape and layout pair list.
- Shape(java.util.List<java.lang.Long> shape): Constructs and initializes a Shape with the specified dimensions.

Method Summary
- Shape add(long... axes): Joins this shape with the given axes.
- Shape addAll(Shape other): Joins this shape with the specified other shape.
- static Shape decode(java.io.DataInputStream dis): Decodes the data in the given DataInputStream and converts it into the corresponding Shape object.
- int dimension(): Returns the number of dimensions of this Shape.
- boolean equals(java.lang.Object o)
- Shape filterByLayoutType(java.util.function.Predicate<LayoutType> predicate): Returns only the axes of the Shape whose layout types match the predicate.
- long get(int dimension): Returns the size of the given dimension.
- byte[] getEncoded(): Gets the byte array representation of this Shape for serialization.
- LayoutType[] getLayout(): Returns the layout type of each axis in this shape.
- LayoutType getLayoutType(int dimension): Returns the layout type of the given dimension.
- int getLeadingOnes(): Returns the number of leading ones in the array shape.
- long[] getShape(): Returns the dimensions of the Shape.
- int getTrailingOnes(): Returns the number of trailing ones in the array shape.
- long getUnknownValueCount(): Returns the number of unknown values in this Shape.
- int hashCode()
- boolean hasZeroDimension(): Returns true if any dimension of the NDArray has size zero.
- long head(): Returns the size of the first dimension (the head) of the shape.
- boolean isLayoutKnown(): Returns true if a layout is set.
- boolean isScalar(): Returns true if the NDArray is a scalar.
- Shape map(java.util.function.Function<ai.djl.util.Pair<java.lang.Long,LayoutType>,ai.djl.util.Pair<java.lang.Long,LayoutType>> mapper): Returns a new Shape obtained by applying mapper to each (size, layout type) pair of this shape.
- long size(): Returns the total size.
- long size(int... dimensions): Returns the size of a specific dimension or several specific dimensions.
- Shape slice(int beginIndex): Creates a new Shape whose content is a slice of this shape.
- Shape slice(int beginIndex, int endIndex): Creates a new Shape whose content is a slice of this shape.
- java.util.stream.Stream<ai.djl.util.Pair<java.lang.Long,LayoutType>> stream(): Returns a stream of the (size, layout type) pairs of this shape.
- long tail(): Returns the size of the last dimension (the tail) of the shape.
- java.lang.String toLayoutString(): Returns the layout type of each axis in this shape as a string.
- java.lang.String toString()
- static Shape update(Shape shape, int dimension, long value): Returns a new shape with the given dimension altered.

Constructor Detail

Shape
public Shape(long... shape)
Constructs and initializes a Shape with the specified dimensions.
- Parameters:
  shape - the dimensions of the shape
- Throws:
  java.lang.IllegalArgumentException - thrown if any element in the shape is invalid (no element may be less than -1). Also thrown if the shape and layout do not have equal sizes.

Shape
public Shape(java.util.List<java.lang.Long> shape)
Constructs and initializes a Shape with the specified dimensions.
- Parameters:
  shape - the dimensions of the shape
- Throws:
  java.lang.IllegalArgumentException - thrown if any element in the shape is invalid (no element may be less than -1). Also thrown if the shape and layout do not have equal sizes.

Shape
public Shape(ai.djl.util.PairList<java.lang.Long,LayoutType> shape)
Constructs and initializes a Shape with the specified shape and layout pair list.
- Parameters:
  shape - the dimensions and layout of the shape
- Throws:
  java.lang.IllegalArgumentException - thrown if any element in the shape is invalid (no element may be less than -1). Also thrown if the shape and layout do not have equal sizes.

Shape
public Shape(long[] shape, java.lang.String layout)
Constructs and initializes a Shape with the specified dimensions and layout.
- Parameters:
  shape - the size of each axis of the shape
  layout - the LayoutType of each axis in the shape
- Throws:
  java.lang.IllegalArgumentException - thrown if any element in the shape is invalid (no element may be less than -1). Also thrown for an invalid layout, or if the shape and layout do not have equal sizes.

Shape
public Shape(long[] shape, LayoutType[] layout)
Constructs and initializes a Shape with the specified dimensions and layout.
- Parameters:
  shape - the size of each axis of the shape
  layout - the LayoutType of each axis in the shape
- Throws:
  java.lang.IllegalArgumentException - thrown if any element in the shape is invalid (no element may be less than -1). Also thrown if the shape and layout do not have equal sizes.
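The failure cases listed in the Throws clauses of these constructors can be illustrated with a small plain-Java sketch. `ShapeArgs`, `checkDims` and `checkDimsAndLayout` are hypothetical names, not part of DJL; the sketch works on a raw `long[]`, treats -1 as the only allowed negative value (the documented lower bound), and assumes a layout string carries one character per axis. It does not reproduce how DJL maps layout characters to `LayoutType` values.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the argument checks the Shape constructors document.
 * Not DJL's source: it only mirrors the documented failure cases.
 */
final class ShapeArgs {
    private ShapeArgs() {}

    /** Returns a defensive copy of dims after checking that no entry is below -1. */
    static long[] checkDims(long[] dims) {
        for (long d : dims) {
            if (d < -1) {
                // -1 is the lowest accepted value; anything smaller is rejected
                throw new IllegalArgumentException(
                        "Invalid dimension " + d + " in " + Arrays.toString(dims));
            }
        }
        return dims.clone();
    }

    /** Checks dims, then checks (by assumption) one layout character per axis. */
    static long[] checkDimsAndLayout(long[] dims, String layout) {
        long[] copy = checkDims(dims);
        if (layout.length() != copy.length) {
            throw new IllegalArgumentException("Shape has " + copy.length
                    + " axes but layout \"" + layout + "\" has " + layout.length());
        }
        return copy;
    }
}
```

A 4-axis shape paired with a 3-character layout string is rejected by the second check, which corresponds to the documented "shape and layout do not have equal sizes" case.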

Method Detail

update
public static Shape update(Shape shape, int dimension, long value)
Returns a new shape with the given dimension set to the given value.
- Parameters:
  shape - the shape to update
  dimension - the dimension to alter
  value - the value to set the dimension to
- Returns:
  a new shape with the update applied
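Because `update` returns a new shape rather than modifying its argument, the operation amounts to copy-then-set. A minimal sketch on a plain `long[]`, with the hypothetical helper `ShapeUpdate.update` standing in for the real method:

```java
import java.util.Arrays;

/** Hypothetical sketch of Shape.update semantics on a plain long[]; not DJL's source. */
final class ShapeUpdate {
    private ShapeUpdate() {}

    /** Returns a copy of dims with the given axis set to value; the input is left unchanged. */
    static long[] update(long[] dims, int dimension, long value) {
        long[] copy = Arrays.copyOf(dims, dims.length);
        copy[dimension] = value; // ArrayIndexOutOfBoundsException for an invalid axis
        return copy;
    }
}
```

Copying first keeps the original array intact, matching the documented "returns a new shape" behavior.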

getShape
public long[] getShape()
Returns the dimensions of the Shape.
- Returns:
  the dimensions of the Shape

get
public long get(int dimension)
Returns the size of the given dimension.
- Parameters:
  dimension - the dimension whose size to return
- Returns:
  the size of the given dimension

getLayoutType
public LayoutType getLayoutType(int dimension)
Returns the layout type of the given dimension.
- Parameters:
  dimension - the dimension whose layout type to return
- Returns:
  the layout type of the given dimension

size
public long size(int... dimensions)
Returns the size of a specific dimension, or the product of the sizes of several specific dimensions.
- Parameters:
  dimensions - the dimension or dimensions to find the size of
- Returns:
  the size of the specified dimension(s), or -1 if the size is indeterminate
- Throws:
  java.lang.IllegalArgumentException - thrown if passed an invalid dimension

size
public long size()
Returns the total size, that is, the number of elements in an array of this shape.
- Returns:
  the total size, or -1 if the size is indeterminate
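Both `size` overloads multiply dimension sizes and report -1 when a size cannot be determined. The sketch below assumes that -1 is the marker for an unknown dimension (the only negative value the constructors accept) and that the multi-dimension overload returns the product of the selected sizes. `ShapeSize` is a hypothetical helper working on a plain `long[]`, not DJL's implementation.

```java
/**
 * Hypothetical sketch of the two size() overloads on a plain long[]; not DJL's source.
 * A dimension of -1 is treated as unknown, so any product that touches it is -1.
 */
final class ShapeSize {
    private ShapeSize() {}

    /** Product of all dimensions, or -1 if any of them is unknown. */
    static long size(long[] dims) {
        long total = 1; // an empty (scalar) shape has size 1 in this sketch
        for (long d : dims) {
            if (d == -1) {
                return -1;
            }
            total *= d;
        }
        return total;
    }

    /** Product of the selected dimensions, or -1 if any selected one is unknown. */
    static long size(long[] dims, int... axes) {
        long total = 1;
        for (int axis : axes) {
            if (axis < 0 || axis >= dims.length) {
                throw new IllegalArgumentException("Invalid dimension " + axis);
            }
            if (dims[axis] == -1) {
                return -1;
            }
            total *= dims[axis];
        }
        return total;
    }
}
```

Note that an unknown dimension only poisons the result when it is actually selected: for dimensions {2, -1, 4}, the total size is -1, but the size of axes 0 and 2 is still 8.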

dimension
public int dimension()
Returns the number of dimensions of this Shape.
- Returns:
  the number of dimensions of this Shape

getUnknownValueCount
public long getUnknownValueCount()
Returns the number of unknown values in this Shape.
- Returns:
  the number of unknown values in this Shape
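A sketch of `getUnknownValueCount`, assuming that an unknown value is a dimension stored as -1 (the value `size` uses to signal an indeterminate size). `UnknownCount` is a hypothetical helper on a plain `long[]`.

```java
import java.util.Arrays;

/** Hypothetical sketch: counting unknown (-1) dimensions in a plain long[]; not DJL's source. */
final class UnknownCount {
    private UnknownCount() {}

    static long unknownValueCount(long[] dims) {
        // Each -1 is assumed to mark a dimension whose size is not known yet
        return Arrays.stream(dims).filter(d -> d == -1).count();
    }
}
```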

slice
public Shape slice(int beginIndex)
Creates a new Shape whose content is a slice of this shape. The sub shape begins at the specified beginIndex and extends to the last dimension of this shape.
- Parameters:
  beginIndex - the beginning index, inclusive
- Returns:
  a new Shape whose content is a slice of this shape

slice
public Shape slice(int beginIndex, int endIndex)
Creates a new Shape whose content is a slice of this shape. The sub shape begins at the specified beginIndex and extends to endIndex - 1.
- Parameters:
  beginIndex - the beginning index, inclusive
  endIndex - the ending index, exclusive
- Returns:
  a new Shape whose content is a slice of this shape
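The two `slice` overloads follow the usual half-open convention: `beginIndex` is inclusive, `endIndex` exclusive, and the one-argument form runs to the last axis. A sketch on a plain `long[]` using `Arrays.copyOfRange`; `ShapeSlice` is hypothetical, and the explicit bounds check is there because `copyOfRange` would otherwise pad with zeros when `endIndex` exceeds the array length.

```java
import java.util.Arrays;

/** Hypothetical sketch of the two slice overloads on a plain long[]; not DJL's source. */
final class ShapeSlice {
    private ShapeSlice() {}

    /** Axes from beginIndex (inclusive) to endIndex (exclusive). */
    static long[] slice(long[] dims, int beginIndex, int endIndex) {
        if (beginIndex < 0 || endIndex > dims.length || beginIndex > endIndex) {
            throw new IndexOutOfBoundsException(
                    "Invalid slice [" + beginIndex + ", " + endIndex + ") of " + dims.length + " axes");
        }
        return Arrays.copyOfRange(dims, beginIndex, endIndex);
    }

    /** Axes from beginIndex (inclusive) through the last axis. */
    static long[] slice(long[] dims, int beginIndex) {
        return slice(dims, beginIndex, dims.length);
    }
}
```

For dimensions {8, 3, 224, 224}, `slice(dims, 1)` drops the leading axis and `slice(dims, 2, 4)` keeps only the last two.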

filterByLayoutType
public Shape filterByLayoutType(java.util.function.Predicate<LayoutType> predicate)
Returns only the axes of the Shape whose layout types match the predicate.
- Parameters:
  predicate - the predicate to test the layout type of each axis against
- Returns:
  a new filtered Shape

map
public Shape map(java.util.function.Function<ai.djl.util.Pair<java.lang.Long,LayoutType>,ai.djl.util.Pair<java.lang.Long,LayoutType>> mapper)
Returns a new Shape obtained by applying mapper to each (size, layout type) pair of this shape.
- Parameters:
  mapper - the function to map each element of the Shape by
- Returns:
  a new mapped Shape

stream
public java.util.stream.Stream<ai.djl.util.Pair<java.lang.Long,LayoutType>> stream()
Returns a stream of the (size, layout type) pairs of the Shape, one per axis.
- Returns:
  the stream of the Shape
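`stream`, `map` and `filterByLayoutType` all operate on one (size, layout type) pair per axis, as their signatures show. The sketch below models that with parallel arrays, using `java.util.Map.Entry<Long, Character>` as a stand-in for DJL's `Pair<Long, LayoutType>` and a layout character such as 'N' or 'H' as a stand-in for a `LayoutType` value. `ShapeAxes` and its methods are hypothetical, and they return lists rather than `Shape` objects.

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Hypothetical sketch of stream(), map() and filterByLayoutType() on parallel arrays.
 * Map.Entry<Long, Character> stands in for Pair<Long, LayoutType>; not DJL's source.
 */
final class ShapeAxes {
    private ShapeAxes() {}

    /** One (size, layout) pair per axis, in axis order. */
    static Stream<Map.Entry<Long, Character>> stream(long[] dims, char[] layout) {
        return IntStream.range(0, dims.length).mapToObj(i -> Map.entry(dims[i], layout[i]));
    }

    /** Applies mapper to every axis and collects the resulting pairs. */
    static List<Map.Entry<Long, Character>> map(
            long[] dims, char[] layout,
            Function<Map.Entry<Long, Character>, Map.Entry<Long, Character>> mapper) {
        return stream(dims, layout).map(mapper).collect(Collectors.toList());
    }

    /** Keeps only the axes whose layout character satisfies the predicate. */
    static List<Map.Entry<Long, Character>> filterByLayout(
            long[] dims, char[] layout, Predicate<Character> predicate) {
        return stream(dims, layout)
                .filter(e -> predicate.test(e.getValue()))
                .collect(Collectors.toList());
    }
}
```

With dimensions {8, 3, 224, 224} and layout {'N', 'C', 'H', 'W'}, filtering on 'H' or 'W' keeps the two 224-sized spatial axes, and a mapper can, for example, replace the batch size with -1 while leaving the other axes unchanged.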

add
public Shape add(long... axes)
Joins this shape with the given axes, which are appended after the axes of this shape.
- Parameters:
  axes - the axes to join
- Returns:
  the joined Shape

addAll
public Shape addAll(Shape other)
Joins this shape with the specified other shape, whose axes are appended after the axes of this shape.
- Parameters:
  other - the shape to join
- Returns:
  the joined Shape
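A sketch of `add` and `addAll`, assuming that joining means appending the new axes after the existing ones; layout information is ignored here. `ShapeJoin` is a hypothetical helper on plain `long[]` arrays, with `add` delegating to `addAll` so that both share one concatenation routine.

```java
/** Hypothetical sketch of add(long...) and addAll(Shape) on plain long[] arrays; not DJL's source. */
final class ShapeJoin {
    private ShapeJoin() {}

    /** Returns a new array holding the axes of first followed by the axes of second. */
    static long[] addAll(long[] first, long[] second) {
        long[] joined = new long[first.length + second.length];
        System.arraycopy(first, 0, joined, 0, first.length);
        System.arraycopy(second, 0, joined, first.length, second.length);
        return joined;
    }

    /** Appends the given axes to the end of dims. */
    static long[] add(long[] dims, long... axes) {
        return addAll(dims, axes);
    }
}
```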

head
public long head()
Returns the head of the shape, that is, the size of its first dimension.
- Returns:
  the size of the first dimension of the shape
- Throws:
  java.lang.IndexOutOfBoundsException - thrown if the shape is empty

tail
public long tail()
Returns the tail of the shape, that is, the size of its last dimension.
- Returns:
  the size of the last dimension of the shape
- Throws:
  java.lang.IndexOutOfBoundsException - thrown if the shape is empty
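A sketch of `head` and `tail`, assuming they return the sizes of the first and last axes and, as documented, throw `IndexOutOfBoundsException` for a shape with no axes. `ShapeEnds` is a hypothetical helper on a plain `long[]`.

```java
/** Hypothetical sketch of head() and tail() on a plain long[]; not DJL's source. */
final class ShapeEnds {
    private ShapeEnds() {}

    static long head(long[] dims) {
        if (dims.length == 0) {
            throw new IndexOutOfBoundsException("A shape with no axes has no head");
        }
        return dims[0]; // size of the first axis
    }

    static long tail(long[] dims) {
        if (dims.length == 0) {
            throw new IndexOutOfBoundsException("A shape with no axes has no tail");
        }
        return dims[dims.length - 1]; // size of the last axis
    }
}
```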

getTrailingOnes
public int getTrailingOnes()
Returns the number of trailing ones in the array shape. For example, a rank 3 array with shape [10, 1, 1] would return 2 for this method.
- Returns:
  the number of trailing ones in the shape

getLeadingOnes
public int getLeadingOnes()
Returns the number of leading ones in the array shape. For example, a rank 3 array with shape [1, 10, 1] would return 1 for this method.
- Returns:
  the number of leading ones in the shape
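Counting trailing or leading ones is a scan from one end of the dimension array that stops at the first value other than 1. This hypothetical `ShapeOnes` helper, working on a plain `long[]`, reproduces the two documented examples.

```java
/** Hypothetical sketch of getTrailingOnes() and getLeadingOnes() on a plain long[]; not DJL's source. */
final class ShapeOnes {
    private ShapeOnes() {}

    /** Counts consecutive 1s from the last axis backwards. */
    static int trailingOnes(long[] dims) {
        int count = 0;
        for (int i = dims.length - 1; i >= 0 && dims[i] == 1; i--) {
            count++;
        }
        return count;
    }

    /** Counts consecutive 1s from the first axis forwards. */
    static int leadingOnes(long[] dims) {
        int count = 0;
        for (int i = 0; i < dims.length && dims[i] == 1; i++) {
            count++;
        }
        return count;
    }
}
```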

isScalar
public boolean isScalar()
Returns true if the NDArray is a scalar, that is, if its shape has no dimensions.
- Returns:
  whether the NDArray is a scalar

hasZeroDimension
public boolean hasZeroDimension()
Returns true if any dimension of the NDArray has size zero, meaning the array holds no elements.
- Returns:
  whether the NDArray has a dimension of size zero
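`isScalar` and `hasZeroDimension` test different things: a scalar has rank 0 (no axes at all), while an array with a zero-sized axis has at least one axis but no elements. The sketch assumes that reading of "contains zero dimensions"; `ShapeChecks` is a hypothetical helper on a plain `long[]`.

```java
/** Hypothetical sketch of isScalar() and hasZeroDimension() on a plain long[]; not DJL's source. */
final class ShapeChecks {
    private ShapeChecks() {}

    /** A scalar has no axes at all. */
    static boolean isScalar(long[] dims) {
        return dims.length == 0;
    }

    /** True if some axis has size 0, so an array of this shape holds no elements. */
    static boolean hasZeroDimension(long[] dims) {
        for (long d : dims) {
            if (d == 0) {
                return true;
            }
        }
        return false;
    }
}
```

Under this reading a scalar shape is not "zero-dimensional" in the `hasZeroDimension` sense: it has no axes, so none of them can be zero.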

isLayoutKnown
public boolean isLayoutKnown()
Returns true if a layout is set.
- Returns:
  whether a layout has been set

getLayout
public LayoutType[] getLayout()
Returns the layout type of each axis in this shape.
- Returns:
  the layout type of each axis in this shape

toLayoutString
public java.lang.String toLayoutString()
Returns the layout type of each axis in this shape as a string.
- Returns:
  the layout type of each axis in this shape as a string

getEncoded
public byte[] getEncoded()
Gets the byte array representation of this Shape for serialization.
- Returns:
  a byte array representation of this Shape

equals
public boolean equals(java.lang.Object o)
- Overrides:
  equals in class java.lang.Object

hashCode
public int hashCode()
- Overrides:
  hashCode in class java.lang.Object

toString
public java.lang.String toString()
- Overrides:
  toString in class java.lang.Object

decode
public static Shape decode(java.io.DataInputStream dis) throws java.io.IOException
Decodes the data in the given DataInputStream and converts it into the corresponding Shape object.
- Parameters:
  dis - the input stream to read from
- Returns:
  the corresponding Shape object
- Throws:
  java.io.IOException - when an I/O error occurs
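`getEncoded` and `decode` form a serialization round trip through `java.io` data streams. The sketch below shows that general pattern with `DataOutputStream` and `DataInputStream`. Its byte layout (an `int` axis count followed by one `long` per axis, with no layout information) is invented for illustration and is not taken from DJL; `ShapeCodec` is a hypothetical helper.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/**
 * Hypothetical round trip in the spirit of getEncoded()/decode(DataInputStream).
 * The byte layout here is an illustrative assumption, not DJL's format.
 */
final class ShapeCodec {
    private ShapeCodec() {}

    static byte[] encode(long[] dims) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(dims.length); // number of axes
            for (long d : dims) {
                out.writeLong(d); // one 8-byte size per axis
            }
        } catch (IOException e) {
            // Writing to an in-memory stream is not expected to fail
            throw new IllegalStateException(e);
        }
        return bytes.toByteArray();
    }

    static long[] decode(DataInputStream dis) throws IOException {
        int rank = dis.readInt();
        long[] dims = new long[rank];
        for (int i = 0; i < rank; i++) {
            dims[i] = dis.readLong(); // throws EOFException if the stream ends early
        }
        return dims;
    }
}
```

Whatever the real byte layout is, the design constraint is the same: `decode` must read fields in exactly the order and widths that the encoder wrote them.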