Class NDIndex

java.lang.Object
ai.djl.ndarray.index.NDIndex

public class NDIndex extends Object
The NDIndex allows you to specify a subset of an NDArray that can be used for fetching or updating.

It accepts a different index option for each dimension, given in the order of the dimensions. Each dimension can be indexed in one of the following ways:

  • All values in the dimension - Use addAllDim, or : (or *) in the index string
  • A single value in the dimension - Pass the value to addIndices; a negative index -i corresponds to [dimensionLength - i]
  • A range of values - Use addSliceDim

We recommend creating the NDIndex using NDIndex(String, Object...).

  • Constructor Details

    • NDIndex

      public NDIndex()
      Creates an empty NDIndex to append values to.
    • NDIndex

      public NDIndex(String indices, Object... args)
      Creates an NDIndex from the given index string and arguments.

      Here are some examples of the indices format.

           // Assumes an existing NDManager named manager.
           NDArray a = manager.ones(new Shape(5, 4, 3));
      
           // Gets a subsection of the NDArray in the first axis.
           assertEquals(a.get(new NDIndex("2")).getShape(), new Shape(4, 3));
      
           // Gets a subsection of the NDArray indexing from the end (-i == length - i).
           assertEquals(a.get(new NDIndex("-2")).getShape(), new Shape(4, 3));
      
           // Gets everything in the first axis and a subsection in the second axis.
           // You can use either : or * to represent everything.
           assertEquals(a.get(new NDIndex(":, 2")).getShape(), new Shape(5, 3));
           assertEquals(a.get(new NDIndex("*, 2")).getShape(), new Shape(5, 3));
      
           // Gets a range of values along the second axis; the start is inclusive and the end is exclusive.
           assertEquals(a.get(new NDIndex(":, 1:3")).getShape(), new Shape(5, 2, 3));
      
           // Omits the start or the end of the range to extend it to the beginning or end of the axis.
           assertEquals(a.get(new NDIndex(":, :3")).getShape(), new Shape(5, 3, 3));
           assertEquals(a.get(new NDIndex(":, 1:")).getShape(), new Shape(5, 3, 3));
      
           // Uses the value after the second colon in a slicing range, the step, to get every other result.
           assertEquals(a.get(new NDIndex(":, 1::2")).getShape(), new Shape(5, 2, 3));
      
           // Uses a negative step to reverse along the dimension.
           assertEquals(a.get(new NDIndex("::-1")).getShape(), new Shape(5, 4, 3));
      
           // Uses variable arguments in the index.
           // Any number in any of these formats can be replaced with {}; the values for the {}
           // placeholders are given, in order, as arguments following the indices string.
           assertEquals(a.get(new NDIndex("{}, {}:{}", 0, 1, 3)).getShape(), new Shape(2, 3));
      
           // Uses an ellipsis to insert as many full slices as needed.
           assertEquals(a.get(new NDIndex("...")).getShape(), new Shape(5, 4, 3));
      
           // Uses an ellipsis to select all dimensions except the last axis, where only a subsection is taken.
           assertEquals(a.get(new NDIndex("..., 2")).getShape(), new Shape(5, 4));
      
           // Uses null to add an extra axis to the output array.
           assertEquals(a.get(new NDIndex(":2, null, 0, :2")).getShape(), new Shape(2, 1, 2));
      
           // Gets entries of an NDArray with a mixed index.
           NDArray index1 = manager.create(new long[] {0, 1}, new Shape(2));
           NDArray bool1 = manager.create(new boolean[] {true, false, true});
           assertEquals(a.get(new NDIndex(":{}, {}, {}, {}", 2, index1, bool1, null)).getShape(), new Shape(2, 2, 1));
      
       
      Parameters:
      indices - a comma-separated list of index entries applied to the dimensions in order, where each entry selects a single value, everything (: or *), or a slice of that dimension
      args - values that replace the "{}" placeholders in the indices string, in order. Each can be an integer, a long, a boolean NDArray, or an integer NDArray.
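      To make the shape rules in the constructor examples concrete, here is a minimal, stdlib-only sketch that computes the result shape of a simplified index string. IndexShapeSketch, resultShape and sliceLength are hypothetical names, not DJL API. The sketch assumes Python-style slice defaults and clamping, supports only integers, slices, : or *, a single ..., and null, and does not handle {} placeholders or NDArray arguments.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch, not DJL code: result shape of a simplified index string.
// Supports integers, "start:stop:step" slices, ":" or "*", a single "...", and "null".
// Placeholders "{}" and NDArray arguments are not handled.
final class IndexShapeSketch {
    private IndexShapeSketch() {}

    static long[] resultShape(long[] shape, String indices) {
        String[] items = indices.split(",");
        // Entries other than "..." and "null" each consume one input axis.
        int consumed = 0;
        for (String raw : items) {
            String item = raw.trim();
            if (!item.equals("...") && !item.equals("null")) {
                consumed++;
            }
        }
        List<Long> out = new ArrayList<>();
        int axis = 0;
        for (String raw : items) {
            String item = raw.trim();
            if (item.equals("...")) {
                // Expand into as many full slices as needed to cover the unconsumed axes.
                for (int i = 0; i < shape.length - consumed; i++) {
                    out.add(shape[axis++]);
                }
            } else if (item.equals("null")) {
                out.add(1L); // new axis of length 1; consumes no input axis
            } else if (item.equals(":") || item.equals("*")) {
                out.add(shape[axis++]); // keep the whole axis
            } else if (item.contains(":")) {
                out.add(sliceLength(item, shape[axis++]));
            } else {
                Long.parseLong(item); // a single value: validate it, then drop the axis
                axis++;
            }
        }
        // Axes not mentioned in the index string are kept whole.
        while (axis < shape.length) {
            out.add(shape[axis++]);
        }
        long[] result = new long[out.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = out.get(i);
        }
        return result;
    }

    // Number of elements that "start:stop[:step]" selects on an axis of length n,
    // using Python-style defaults and clamping (an assumption of this sketch).
    static long sliceLength(String slice, long n) {
        String[] parts = slice.split(":", -1);
        long step = parts.length > 2 && !parts[2].trim().isEmpty()
                ? Long.parseLong(parts[2].trim()) : 1;
        if (step == 0) {
            throw new IllegalArgumentException("step must not be zero");
        }
        Long start = parse(parts[0]);
        Long stop = parse(parts[1]);
        if (step > 0) {
            long begin = start == null ? 0 : clamp(start < 0 ? start + n : start, 0, n);
            long end = stop == null ? n : clamp(stop < 0 ? stop + n : stop, 0, n);
            return end > begin ? (end - begin + step - 1) / step : 0;
        }
        // Negative step: walk from begin down to (but not including) end.
        long begin = start == null ? n - 1 : clamp(start < 0 ? start + n : start, -1, n - 1);
        long end = stop == null ? -1 : clamp(stop < 0 ? stop + n : stop, -1, n - 1);
        return begin > end ? (begin - end - step - 1) / -step : 0;
    }

    private static Long parse(String s) {
        String t = s.trim();
        return t.isEmpty() ? null : Long.valueOf(t);
    }

    private static long clamp(long v, long min, long max) {
        return Math.max(min, Math.min(max, v));
    }

    public static void main(String[] args) {
        long[] shape = {5, 4, 3};
        System.out.println(Arrays.toString(resultShape(shape, ":, 1:3"))); // [5, 2, 3]
        System.out.println(Arrays.toString(resultShape(shape, "..., 2"))); // [5, 4]
    }
}
```

      Each shape this sketch produces for the (5, 4, 3) array matches the corresponding assertion in the constructor examples, which is the only property it is meant to demonstrate.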
    • NDIndex

      public NDIndex(long... indices)
      Creates an NDIndex that fixes each of the given indices as a single value in the corresponding dimension of the NDArray.
      Parameters:
      indices - the indices, one per dimension in order; negative indices count from the end of the dimension
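      As a rough illustration of this constructor's rule, the following stdlib-only sketch (hypothetical FixedIndexSketch.remainingShape, not DJL API) resolves a negative index -i as dimensionLength - i and shows that fixing one value in each leading dimension removes those dimensions from the result shape. The bounds check and its exception type are this sketch's own choice, not documented DJL behavior.

```java
import java.util.Arrays;

// Hypothetical sketch, not DJL code: shape left after fixing one value per leading dimension.
final class FixedIndexSketch {
    private FixedIndexSketch() {}

    static long[] remainingShape(long[] shape, long... indices) {
        if (indices.length > shape.length) {
            throw new IllegalArgumentException("more indices than dimensions");
        }
        for (int d = 0; d < indices.length; d++) {
            // A negative index -i refers to position shape[d] - i.
            long resolved = indices[d] < 0 ? shape[d] + indices[d] : indices[d];
            if (resolved < 0 || resolved >= shape[d]) {
                throw new IndexOutOfBoundsException(
                        "index " + indices[d] + " is out of range for dimension " + d);
            }
        }
        // The first indices.length dimensions are fixed; the rest are kept whole.
        return Arrays.copyOfRange(shape, indices.length, shape.length);
    }

    public static void main(String[] args) {
        long[] shape = {5, 4, 3};
        System.out.println(Arrays.toString(remainingShape(shape, 2)));     // [4, 3]
        System.out.println(Arrays.toString(remainingShape(shape, -1, 0))); // [3]
    }
}
```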
  • Method Details

    • sliceAxis

      public static NDIndex sliceAxis(int axis, long min, long max)
      Creates an NDIndex that contains a single slice on the given axis.
      Parameters:
      axis - the axis to slice
      min - the start of the slice (inclusive)
      max - the end of the slice (exclusive)
      Returns:
      a new NDIndex with the given slice
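      One way to picture sliceAxis is as the index string it would correspond to, assuming every axis before the given one is kept whole. The sketch below is a hypothetical helper (SliceAxisSketch.sliceAxisString, not part of DJL) that builds that string; it does not call or reproduce DJL's implementation.

```java
// Hypothetical sketch, not DJL code: the index string a single-axis slice would correspond to,
// assuming every earlier axis is selected in full.
final class SliceAxisSketch {
    private SliceAxisSketch() {}

    static String sliceAxisString(int axis, long min, long max) {
        if (axis < 0) {
            throw new IllegalArgumentException("axis must be non-negative");
        }
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < axis; i++) {
            sb.append(":, "); // keep each earlier axis whole
        }
        // min is inclusive and max exclusive, matching the "min:max" string syntax.
        return sb.append(min).append(':').append(max).toString();
    }

    public static void main(String[] args) {
        System.out.println(sliceAxisString(1, 1, 3)); // :, 1:3
    }
}
```

      Under that assumption, NDIndex.sliceAxis(1, 1, 3) on the (5, 4, 3) array from the constructor examples would select the same (5, 2, 3) subarray as new NDIndex(":, 1:3").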
    • getRank

      public int getRank()
      Returns the number of dimensions specified in the index.
      Returns:
      the number of dimensions specified in the index
    • getEllipsisIndex

      public int getEllipsisIndex()
      Returns the position of the ellipsis within this index.
      Returns:
      the position of the ellipsis within this index, or -1 if there is none
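      The sketch below models, in plain Java, how an ellipsis position could be found (returning -1 when there is none, as getEllipsisIndex does) and how the ellipsis expands into full slices for an array of a given rank. EllipsisSketch, ellipsisIndex and expand are hypothetical names that operate on index-string entries, not on DJL's NDIndexElement list.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical sketch, not DJL code: locating and expanding "..." in a list of index entries.
final class EllipsisSketch {
    private EllipsisSketch() {}

    // Position of the ellipsis, or -1 if there is none.
    static int ellipsisIndex(List<String> entries) {
        return entries.indexOf("...");
    }

    // Replaces the ellipsis with enough ":" entries to cover every axis of an array of this rank.
    static List<String> expand(List<String> entries, int rank) {
        int pos = ellipsisIndex(entries);
        if (pos < 0) {
            return entries;
        }
        // Entries other than "..." and "null" each consume one axis.
        int consumed = 0;
        for (String e : entries) {
            if (!e.equals("...") && !e.equals("null")) {
                consumed++;
            }
        }
        List<String> out = new ArrayList<>(entries.subList(0, pos));
        out.addAll(Collections.nCopies(Math.max(0, rank - consumed), ":"));
        out.addAll(entries.subList(pos + 1, entries.size()));
        return out;
    }

    public static void main(String[] args) {
        List<String> entries = Arrays.asList("...", "2");
        System.out.println(ellipsisIndex(entries)); // 0
        System.out.println(expand(entries, 3));     // [:, :, 2]
    }
}
```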
    • get

      public NDIndexElement get(int dimension)
      Returns the index affecting the given dimension.
      Parameters:
      dimension - the affected dimension
      Returns:
      the index affecting the given dimension
    • getIndices

      public List<NDIndexElement> getIndices()
      Returns the list of index elements in this NDIndex.
      Returns:
      the list of index elements
    • addIndices

      public final NDIndex addIndices(String indices, Object... args)
      Updates this NDIndex by appending the indices parsed from the given string.
      Parameters:
      indices - the indices to add, in the same format as NDIndex(String, Object...)
      args - values that replace the "{}" placeholders in the indices string, in order. Each can be an integer, a long, a boolean NDArray, or an integer NDArray.
      Returns:
      the updated NDIndex
    • addIndices

      public final NDIndex addIndices(long... indices)
      Updates this NDIndex by appending each of the given indices as a single fixed value in the next dimension of the NDArray.
      Parameters:
      indices - the indices, one per dimension in order; negative indices count from the end of the dimension
      Returns:
      the updated NDIndex
    • addBooleanIndex

      public NDIndex addBooleanIndex(NDArray index)
      Updates this NDIndex by appending a boolean NDArray.

      The boolean NDArray's shape should match the dimensions it indexes; the result contains the elements at the positions where the boolean NDArray is nonzero (true).

      Parameters:
      index - a boolean NDArray where all nonzero elements correspond to elements to return
      Returns:
      the updated NDIndex
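      A boolean index keeps only the positions where the mask is true, so the length of the indexed dimension becomes the number of true entries. This hypothetical plain-Java sketch (BooleanMaskSketch, not DJL API) shows that for a one-dimensional mask; the length check mirrors the requirement that the mask match the indexed dimension.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical sketch, not DJL code: selecting values with a one-dimensional boolean mask.
final class BooleanMaskSketch {
    private BooleanMaskSketch() {}

    // Positions where the mask is true, in ascending order.
    static long[] nonzero(boolean[] mask) {
        return IntStream.range(0, mask.length)
                .filter(i -> mask[i])
                .asLongStream()
                .toArray();
    }

    // Values at the true positions of the mask.
    static long[] select(long[] values, boolean[] mask) {
        if (values.length != mask.length) {
            throw new IllegalArgumentException("mask shape must match the indexed dimension");
        }
        long[] positions = nonzero(mask);
        long[] out = new long[positions.length];
        for (int i = 0; i < positions.length; i++) {
            out[i] = values[(int) positions[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        boolean[] mask = {true, false, true};
        System.out.println(Arrays.toString(nonzero(mask)));                         // [0, 2]
        System.out.println(Arrays.toString(select(new long[] {10, 20, 30}, mask))); // [10, 30]
    }
}
```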
    • addEllipseDim

      public NDIndex addEllipseDim()
      Appends an ellipsis index at the current dimension.
      Returns:
      the updated NDIndex
    • addAllDim

      public NDIndex addAllDim()
      Appends a new index to get all values in the dimension.
      Returns:
      the updated NDIndex
    • addAllDim

      public NDIndex addAllDim(int count)
      Appends count new indices, each selecting all values in its dimension.
      Parameters:
      count - how many axes of NDIndexAll to add.
      Returns:
      the updated NDIndex
      Throws:
      IllegalArgumentException - if count is negative
    • addSliceDim

      public NDIndex addSliceDim(long min, long max)
      Appends a new index that slices the dimension to a range of values.
      Parameters:
      min - the start of the range (inclusive)
      max - the end of the range (exclusive)
      Returns:
      the updated NDIndex
    • addSliceDim

      public NDIndex addSliceDim(long min, long max, long step)
      Appends a new index that slices the dimension to a range of values.
      Parameters:
      min - the start of the range (inclusive)
      max - the end of the range (exclusive)
      step - the step of the slice
      Returns:
      the updated NDIndex
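      To show which positions a slice with a step selects, here is a hypothetical plain-Java sketch (SlicePositionsSketch.positions, not DJL API). It treats min as inclusive and max as exclusive, as in the string examples, and assumes that negative bounds count from the end and that out-of-range bounds are clamped to the dimension. It handles positive steps only.

```java
import java.util.Arrays;
import java.util.stream.LongStream;

// Hypothetical sketch, not DJL code: positions selected by a slice [min, max) with a positive step
// on a dimension of length n. Negative bounds count from the end; bounds are clamped to [0, n].
final class SlicePositionsSketch {
    private SlicePositionsSketch() {}

    static long[] positions(long min, long max, long step, long n) {
        if (step <= 0) {
            throw new IllegalArgumentException("this sketch only handles positive steps");
        }
        long start = Math.max(0, Math.min(n, min < 0 ? min + n : min));
        long stop = Math.max(0, Math.min(n, max < 0 ? max + n : max));
        long count = stop > start ? (stop - start + step - 1) / step : 0;
        return LongStream.iterate(start, i -> i + step).limit(count).toArray();
    }

    public static void main(String[] args) {
        // Second axis (length 4) of the (5, 4, 3) example: min 1, max 4, step 2.
        System.out.println(Arrays.toString(positions(1, 4, 2, 4))); // [1, 3]
    }
}
```

      For the second axis of the (5, 4, 3) example array, min 1, max 4 and step 2 select positions 1 and 3, consistent with the (5, 2, 3) shape that ":, 1::2" produces.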
    • addPickDim

      public NDIndex addPickDim(NDArray index)
      Appends a picking index that selects values by index along the axis.
      Parameters:
      index - an NDArray of indices. Each element of the indices array acts like a fixed index into the axis and returns an element of the corresponding shape, so the final shape is indices.getShape().addAll(target.getShape().slice(1)) (assuming this is the first index element).
      Returns:
      the updated NDIndex
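      The shape rule stated for addPickDim can be checked with a small plain-Java sketch (hypothetical PickSketch, not DJL code). pickShape concatenates the index shape with the target shape minus its first axis, and pickRows applies the same idea to a 2-D array, where each index picks one whole row. It models only the documented rule for a pick in the first index position, not DJL's implementation.

```java
import java.util.Arrays;

// Hypothetical sketch, not DJL code: the documented pick shape rule for a first-position pick.
final class PickSketch {
    private PickSketch() {}

    // indexShape followed by targetShape without its first axis.
    static long[] pickShape(long[] indexShape, long[] targetShape) {
        long[] out = Arrays.copyOf(indexShape, indexShape.length + targetShape.length - 1);
        System.arraycopy(targetShape, 1, out, indexShape.length, targetShape.length - 1);
        return out;
    }

    // Each index selects one whole row (an element of shape targetShape[1:]).
    static long[][] pickRows(long[][] target, long[] indices) {
        long[][] out = new long[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            out[i] = target[(int) indices[i]].clone();
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(pickShape(new long[] {2}, new long[] {5, 4, 3}))); // [2, 4, 3]
        long[][] rows = pickRows(new long[][] {{1, 2}, {3, 4}, {5, 6}}, new long[] {2, 0});
        System.out.println(Arrays.deepToString(rows)); // [[5, 6], [1, 2]]
    }
}
```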
    • stream

      public Stream<NDIndexElement> stream()
      Returns a stream of the NDIndexElements.
      Returns:
      a stream of the NDIndexElements