Class Shape

java.lang.Object
ai.djl.ndarray.types.Shape

public class Shape extends Object
A class that represents the shape information of an NDArray.
  • Constructor Summary

    Constructors
    Constructor
    Description
    Shape(long... shape)
    Constructs and initializes a Shape with the specified dimensions, passed as varargs (long... shape).
    Shape(long[] shape, LayoutType[] layout)
    Constructs and initializes a Shape with the specified dimensions and layout.
    Shape(long[] shape, String layout)
    Constructs and initializes a Shape with the specified dimensions and layout.
    Shape(ai.djl.util.PairList<Long,LayoutType> shape)
    Constructs and initializes a Shape with the specified PairList of dimensions and layout types.
    Shape(List<Long> shape)
    Constructs and initializes a Shape with the specified dimensions.
  • Method Summary

    Modifier, Type and Method
    Description
    Shape add(long... axes)
    Joins this shape with axes.
    Shape addAll(Shape other)
    Joins this shape with specified other shape.
    static Shape decode(DataInputStream dis)
    Decodes the data in the given DataInputStream and converts it into the corresponding Shape object.
    static Shape decode(ByteBuffer bb)
    Decodes the data in the given ByteBuffer and converts it into the corresponding Shape object.
    int dimension()
    Returns the number of dimensions of this Shape.
    boolean equals(Object o)
    Overrides equals in class Object.
    Shape filterByLayoutType(Predicate<LayoutType> predicate)
    Returns only the axes of the Shape whose layout types match the predicate.
    long get(int dimension)
    Returns the size of the given dimension.
    byte[] getEncoded()
    Gets the byte array representation of this Shape for serialization.
    long getLastDimension()
    Returns the size of the last dimension.
    LayoutType[] getLayout()
    Returns the layout type for each axis in this shape.
    LayoutType getLayoutType(int dimension)
    Returns the layout type in the given dimension.
    int getLeadingOnes()
    Returns the number of leading ones in the array shape.
    long[] getShape()
    Returns the dimensions of the Shape.
    int getTrailingOnes()
    Returns the number of trailing ones in the array shape.
    long getUnknownValueCount()
    Returns the number of unknown values in this Shape.
    int hashCode()
    Overrides hashCode in class Object.
    boolean hasZeroDimension()
    Returns true if the NDArray contains zero dimensions.
    long head()
    Returns the head index of the shape.
    boolean isLayoutKnown()
    Returns true if a layout is set.
    boolean isRankOne()
    Returns whether the array is rank-1, as inferred from the shape.
    boolean isScalar()
    Returns true if the NDArray is a scalar.
    Shape map(Function<ai.djl.util.Pair<Long,LayoutType>,ai.djl.util.Pair<Long,LayoutType>> mapper)
    Returns a mapped shape.
    long size()
    Returns the total size.
    long size(int... dimensions)
    Returns the size of a specific dimension or several specific dimensions.
    Shape slice(int beginIndex)
    Creates a new Shape whose content is a slice of this shape.
    Shape slice(int beginIndex, int endIndex)
    Creates a new Shape whose content is a slice of this shape.
    Stream<ai.djl.util.Pair<Long,LayoutType>> stream()
    Returns a stream of the Shape.
    long tail()
    Returns the tail index of the shape.
    String toLayoutString()
    Returns the string layout type for each axis in this shape.
    String toString()
    Overrides toString in class Object.
    static Shape update(Shape shape, int dimension, long value)
    Returns a new shape altering the given dimension.

    Methods inherited from class java.lang.Object

    clone, finalize, getClass, notify, notifyAll, wait, wait, wait
  • Constructor Details

    • Shape

      public Shape(long... shape)
      Constructs and initializes a Shape with the specified dimensions, passed as varargs (long... shape).
      Parameters:
      shape - the dimensions of the shape
      Throws:
      IllegalArgumentException - thrown if any element of the shape is less than -1, or if the shape and layout do not have equal sizes
    • Shape

      public Shape(List<Long> shape)
      Constructs and initializes a Shape with the specified dimensions.
      Parameters:
      shape - the dimensions of the shape
      Throws:
      IllegalArgumentException - thrown if any element of the shape is less than -1, or if the shape and layout do not have equal sizes
    • Shape

      public Shape(ai.djl.util.PairList<Long,LayoutType> shape)
      Constructs and initializes a Shape with the specified PairList of dimensions and layout types.
      Parameters:
      shape - the dimensions and layout of the shape
      Throws:
      IllegalArgumentException - thrown if any element of the shape is less than -1, or if the shape and layout do not have equal sizes
    • Shape

      public Shape(long[] shape, String layout)
      Constructs and initializes a Shape with the specified dimensions and layout.
      Parameters:
      shape - the size of each axis of the shape
      layout - the LayoutType of each axis in the shape
      Throws:
      IllegalArgumentException - thrown if any element of the shape is less than -1, if the layout is invalid, or if the shape and layout do not have equal sizes
    • Shape

      public Shape(long[] shape, LayoutType[] layout)
      Constructs and initializes a Shape with the specified dimensions and layout.
      Parameters:
      shape - the size of each axis of the shape
      layout - the LayoutType of each axis in the shape
      Throws:
      IllegalArgumentException - thrown if any element of the shape is less than -1, or if the shape and layout do not have equal sizes
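
The argument checks described for the constructors can be pictured with a small plain-Java sketch. It is not DJL's code: the class `ShapeArgsSketch`, the method `validate`, and the use of a `String[]` in place of `LayoutType[]` are all hypothetical stand-ins, and treating -1 as the marker for an unknown dimension is an assumption drawn from the size() documentation.

```java
import java.util.Arrays;

/** Hypothetical stand-in for the documented constructor checks; not DJL's code. */
final class ShapeArgsSketch {

    /** Returns a defensive copy of dims after applying the documented checks. */
    static long[] validate(long[] dims, String[] layout) {
        // A layout, when supplied, must describe every axis exactly once.
        if (layout != null && layout.length != dims.length) {
            throw new IllegalArgumentException("shape and layout must have equal sizes");
        }
        for (long d : dims) {
            // Assumption: -1 marks an unknown dimension; anything smaller is rejected.
            if (d < -1) {
                throw new IllegalArgumentException("invalid dimension: " + d);
            }
        }
        return Arrays.copyOf(dims, dims.length);
    }
}
```

Under these assumptions, `validate(new long[] {2, -1, 3}, null)` is accepted, while `{2, -2}` or a two-axis shape paired with a one-entry layout is rejected.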
  • Method Details

    • update

      public static Shape update(Shape shape, int dimension, long value)
      Returns a new shape altering the given dimension.
      Parameters:
      shape - the shape to update
      dimension - the dimension to update
      value - the value to set the dimension to
      Returns:
      a new shape with the update applied
    • getShape

      public long[] getShape()
      Returns the dimensions of the Shape.
      Returns:
      the dimensions of the Shape
    • get

      public long get(int dimension)
      Returns the size of the given dimension.
      Parameters:
      dimension - the dimension whose size is returned
      Returns:
      the size of the given dimension
    • getLastDimension

      public long getLastDimension()
      Returns the size of the last dimension.
      Returns:
      the size of the last dimension
    • getLayoutType

      public LayoutType getLayoutType(int dimension)
      Returns the layout type in the given dimension.
      Parameters:
      dimension - the dimension to get the layout type in
      Returns:
      the layout type in the given dimension
    • size

      public long size(int... dimensions)
      Returns the size of a specific dimension or several specific dimensions.
      Parameters:
      dimensions - the dimension or dimensions to find the size of
      Returns:
      the size of specific dimension(s) or -1 for indeterminate size
      Throws:
      IllegalArgumentException - thrown if passed an invalid dimension
    • size

      public long size()
      Returns the total size.
      Returns:
      the total size or -1 for indeterminate size
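
As an illustration of the two size overloads, the sketch below computes the same quantities on a raw `long[]`. It is a minimal model, not DJL's implementation: `ShapeSizeSketch`, `totalSize`, and `sizeOf` are hypothetical names, and the sketch assumes the size is the product of the relevant dimensions and that any negative dimension makes the result indeterminate.

```java
/** Hypothetical model of the documented size() behavior; not DJL's code. */
final class ShapeSizeSketch {

    /** Product of all dimensions, or -1 if any dimension is unknown (negative). */
    static long totalSize(long[] dims) {
        long total = 1;
        for (long d : dims) {
            if (d < 0) {
                return -1;
            }
            total *= d;
        }
        return total;
    }

    /** Product of the selected dimensions, or -1 if any of them is unknown. */
    static long sizeOf(long[] dims, int... axes) {
        long total = 1;
        boolean unknown = false;
        for (int axis : axes) {
            if (axis < 0 || axis >= dims.length) {
                throw new IllegalArgumentException("invalid dimension: " + axis);
            }
            if (dims[axis] < 0) {
                unknown = true;
            } else {
                total *= dims[axis];
            }
        }
        return unknown ? -1 : total;
    }
}
```

With these definitions a shape of [2, 3, 4] has a total size of 24, dimensions 0 and 2 together have a size of 8, and replacing any dimension with -1 turns the total into -1.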
    • dimension

      public int dimension()
      Returns the number of dimensions of this Shape.
      Returns:
      the number of dimensions of this Shape
    • getUnknownValueCount

      public long getUnknownValueCount()
      Returns the number of unknown values in this Shape.
      Returns:
      the number of unknown values in this Shape
    • slice

      public Shape slice(int beginIndex)
      Creates a new Shape whose content is a slice of this shape.

      The sub shape begins at the specified beginIndex and extends to the end of this shape.

      Parameters:
      beginIndex - the beginning index, inclusive
      Returns:
      a new Shape whose content is a slice of this shape
    • slice

      public Shape slice(int beginIndex, int endIndex)
      Creates a new Shape whose content is a slice of this shape.

      The sub shape begins at the specified beginIndex and extends to endIndex - 1.

      Parameters:
      beginIndex - the beginning index, inclusive
      endIndex - the ending index, exclusive
      Returns:
      a new Shape whose content is a slice of this shape
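
The inclusive/exclusive index convention of slice can be sketched on a plain `long[]`. This is an illustration only: `ShapeSliceSketch` is a hypothetical class, the one-argument overload is assumed to run to the end of the shape, and the bounds check is this sketch's own choice rather than documented behavior.

```java
import java.util.Arrays;

/** Hypothetical model of slice(beginIndex, endIndex); not DJL's code. */
final class ShapeSliceSketch {

    /** Copies dims[beginIndex] through dims[endIndex - 1] into a new array. */
    static long[] slice(long[] dims, int beginIndex, int endIndex) {
        // Sketch-only guard: copyOfRange would otherwise pad with zeros past the end.
        if (beginIndex < 0 || endIndex > dims.length || beginIndex > endIndex) {
            throw new IllegalArgumentException("slice range out of bounds");
        }
        return Arrays.copyOfRange(dims, beginIndex, endIndex);
    }

    /** Assumed behavior of the one-argument overload: slice to the end. */
    static long[] slice(long[] dims, int beginIndex) {
        return slice(dims, beginIndex, dims.length);
    }
}
```

For a shape of [2, 3, 4, 5], slicing from 1 to 3 keeps [3, 4], and slicing from 2 keeps [4, 5]; the input array is left unchanged in both cases.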
    • filterByLayoutType

      public Shape filterByLayoutType(Predicate<LayoutType> predicate)
      Returns only the axes of the Shape whose layout types match the predicate.
      Parameters:
      predicate - the predicate to compare the axes of the Shape with
      Returns:
      a new filtered Shape
    • map

      public Shape map(Function<ai.djl.util.Pair<Long,LayoutType>,ai.djl.util.Pair<Long,LayoutType>> mapper)
      Returns a mapped shape.
      Parameters:
      mapper - the function to map each element of the Shape by
      Returns:
      a new mapped Shape
    • stream

      public Stream<ai.djl.util.Pair<Long,LayoutType>> stream()
      Returns a stream of the Shape.
      Returns:
      the stream of the Shape
    • add

      public Shape add(long... axes)
      Joins this shape with axes.
      Parameters:
      axes - the axes to join
      Returns:
      the joined Shape
    • addAll

      public Shape addAll(Shape other)
      Joins this shape with specified other shape.
      Parameters:
      other - the shape to join
      Returns:
      the joined Shape
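
A rough picture of what "joining" means for add and addAll is array concatenation. The sketch below assumes the new axes are appended after the existing ones, which the text above does not state explicitly; `ShapeJoinSketch` and `join` are hypothetical names, and layout handling is left out entirely.

```java
import java.util.Arrays;

/** Hypothetical model of add/addAll as concatenation; not DJL's code. */
final class ShapeJoinSketch {

    /** Returns a new array holding dims followed by axes (assumed ordering). */
    static long[] join(long[] dims, long... axes) {
        // copyOf leaves room at the end, which arraycopy then fills with axes.
        long[] joined = Arrays.copyOf(dims, dims.length + axes.length);
        System.arraycopy(axes, 0, joined, dims.length, axes.length);
        return joined;
    }
}
```

Under that assumption, joining [2, 3] with the axes 4 and 5 yields [2, 3, 4, 5], and the original array is not modified.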
    • head

      public long head()
      Returns the head index of the shape.
      Returns:
      the head index of the shape
      Throws:
      IndexOutOfBoundsException - Thrown if the shape is empty
    • tail

      public long tail()
      Returns the tail index of the shape.
      Returns:
      the tail index of the shape
      Throws:
      IndexOutOfBoundsException - Thrown if the shape is empty
    • getTrailingOnes

      public int getTrailingOnes()
      Returns the number of trailing ones in the array shape.

      For example, a rank 3 array with shape [10, 1, 1] would return 2 for this method.

      Returns:
      the number of trailing ones in the shape
    • getLeadingOnes

      public int getLeadingOnes()
      Returns the number of leading ones in the array shape.

      For example, a rank 3 array with shape [1, 10, 1] would return 1 for this method.

      Returns:
      the number of leading ones in the shape
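
The two examples for getTrailingOnes and getLeadingOnes can be reproduced with a short counting sketch. `ShapeOnesSketch` and its methods are hypothetical names, not DJL's code. How a shape made up entirely of ones is treated is not stated above, so the sketch simply counts every one in that case.

```java
/** Hypothetical model of leading/trailing-one counting; not DJL's code. */
final class ShapeOnesSketch {

    /** Counts consecutive ones starting from the first dimension. */
    static int leadingOnes(long[] dims) {
        int count = 0;
        while (count < dims.length && dims[count] == 1) {
            count++;
        }
        return count;
    }

    /** Counts consecutive ones starting from the last dimension. */
    static int trailingOnes(long[] dims) {
        int count = 0;
        while (count < dims.length && dims[dims.length - 1 - count] == 1) {
            count++;
        }
        return count;
    }
}
```

This matches the documented cases: [10, 1, 1] has two trailing ones, and [1, 10, 1] has one leading one.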
    • isScalar

      public boolean isScalar()
      Returns true if the NDArray is a scalar.
      Returns:
      whether the NDArray is a scalar
    • hasZeroDimension

      public boolean hasZeroDimension()
      Returns true if the NDArray contains zero dimensions.
      Returns:
      whether the NDArray contains zero dimensions
    • isLayoutKnown

      public boolean isLayoutKnown()
      Returns true if a layout is set.
      Returns:
      whether a layout has been set
    • getLayout

      public LayoutType[] getLayout()
      Returns the layout type for each axis in this shape.
      Returns:
      the layout type for each axis in this shape
    • toLayoutString

      public String toLayoutString()
      Returns the string layout type for each axis in this shape.
      Returns:
      the string layout type for each axis in this shape
    • getEncoded

      public byte[] getEncoded()
      Gets the byte array representation of this Shape for serialization.
      Returns:
      a byte array representation of this Shape
    • equals

      public boolean equals(Object o)
      Overrides:
      equals in class Object
    • hashCode

      public int hashCode()
      Overrides:
      hashCode in class Object
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • decode

      public static Shape decode(DataInputStream dis) throws IOException
      Decodes the data in the given DataInputStream and converts it into the corresponding Shape object.
      Parameters:
      dis - the input stream to read from
      Returns:
      the corresponding Shape object
      Throws:
      IOException - when an I/O error occurs
    • decode

      public static Shape decode(ByteBuffer bb)
      Decodes the data in the given ByteBuffer and converts it into the corresponding Shape object.
      Parameters:
      bb - the ByteBuffer to read from
      Returns:
      the corresponding Shape object
    • isRankOne

      public boolean isRankOne()
      Returns whether the array is rank-1, as inferred from the shape.

      For example, an array with shape [1, 10, 1] returns true. An array with an indeterminate size (-1) returns false.

      Returns:
      whether the array is rank-1