Class ArrayDataset

  • All Implemented Interfaces:
    Dataset

    public class ArrayDataset
    extends RandomAccessDataset
    ArrayDataset is an implementation of RandomAccessDataset that consists entirely of large NDArrays. It is recommended only for datasets that are small enough to fit in memory and already come in array form. Otherwise, consider using RandomAccessDataset directly instead.

    The dataset can hold multiple data NDArrays and multiple label NDArrays. Each sample is retrieved by indexing every NDArray along its first dimension.
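
    The first-dimension indexing described above can be sketched in plain Java. This is an illustration only, not the DJL implementation: float[][] stands in for an NDArray whose first dimension is the sample axis, and the class and method names (FirstAxisIndexingSketch, sampleAt) are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration only; float[][] stands in for an NDArray whose
// first dimension is the sample axis. Names here are hypothetical.
class FirstAxisIndexingSketch {

    /** Returns row `index` of every array, one entry per array. */
    static List<float[]> sampleAt(List<float[][]> arrays, int index) {
        List<float[]> sample = new ArrayList<>();
        for (float[][] array : arrays) {
            sample.add(array[index]);
        }
        return sample;
    }

    public static void main(String[] args) {
        float[][] data1 = {{1f, 2f}, {3f, 4f}, {5f, 6f}}; // 3 samples, 2 features
        float[][] data2 = {{10f}, {20f}, {30f}};          // 3 samples, 1 feature
        List<float[]> sample = sampleAt(Arrays.asList(data1, data2), 1);
        System.out.println(Arrays.toString(sample.get(0))); // [3.0, 4.0]
        System.out.println(Arrays.toString(sample.get(1))); // [20.0]
    }
}
```

    All arrays are expected to share the same length along the first dimension, so that one index selects a matching row from each of them.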

    The following is an example of how to use ArrayDataset:

         ArrayDataset dataset = new ArrayDataset.Builder()
                                  .setData(data1, data2)
                                  .optLabels(labels1, labels2, labels3)
                                  .setSampling(20, false)
                                  .build();
     

    Suppose you get a Batch from trainer.iterateDataset(dataset) or dataset.getData(manager). The data of this batch is an NDList with one NDArray for each data input; in this case, that is 2 arrays. Similarly, the labels are an NDList with 3 arrays.
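
    The batch layout described above can be illustrated without DJL. The following is a minimal sketch under stated assumptions: float[][] stands in for an NDArray, List<float[][]> stands in for an NDList, and SketchBatch is a hypothetical holder class, not DJL's Batch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration; float[][] stands in for an NDArray and
// List<float[][]> for an NDList. SketchBatch is a hypothetical holder,
// not DJL's Batch class.
class BatchLayoutSketch {

    static final class SketchBatch {
        final List<float[][]> data;
        final List<float[][]> labels;

        SketchBatch(List<float[][]> data, List<float[][]> labels) {
            this.data = data;
            this.labels = labels;
        }
    }

    /** Slices rows [from, to) out of every array, keeping one slice per array. */
    static List<float[][]> sliceAll(List<float[][]> arrays, int from, int to) {
        List<float[][]> out = new ArrayList<>();
        for (float[][] array : arrays) {
            out.add(Arrays.copyOfRange(array, from, to));
        }
        return out;
    }

    /** Builds the first batch; assumes batchSize does not exceed the sample count. */
    static SketchBatch firstBatch(List<float[][]> data, List<float[][]> labels, int batchSize) {
        return new SketchBatch(sliceAll(data, 0, batchSize), sliceAll(labels, 0, batchSize));
    }

    public static void main(String[] args) {
        int samples = 100;
        List<float[][]> data = Arrays.asList(new float[samples][4], new float[samples][8]);
        List<float[][]> labels = Arrays.asList(
                new float[samples][1], new float[samples][1], new float[samples][1]);
        SketchBatch batch = firstBatch(data, labels, 20);
        System.out.println(batch.data.size());        // 2
        System.out.println(batch.labels.size());      // 3
        System.out.println(batch.data.get(0).length); // 20
    }
}
```

    Each array in the batch keeps its own trailing dimensions; only the first dimension shrinks to the batch size.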

    See Also:
    Dataset
    • Method Detail

      • availableSize

        protected long availableSize()
        Returns the number of records available to be read in this Dataset.
        Specified by:
        availableSize in class RandomAccessDataset
        Returns:
        the number of records available to be read in this Dataset
      • get

        public Record get​(NDManager manager,
                          long index)
        Gets the Record for the given index from the dataset.
        Specified by:
        get in class RandomAccessDataset
        Parameters:
        manager - the manager used to create the arrays
        index - the index of the requested data item
        Returns:
        a Record that contains the data and label of the requested data item
      • getByIndices

        public Batch getByIndices​(NDManager manager,
                                  long... indices)
        Gets the Batch for the given indices from the dataset.
        Parameters:
        manager - the manager used to create the arrays
        indices - indices of the requested data items
        Returns:
        a Batch that contains the data and label of the requested data items
      • getByRange

        public Batch getByRange​(NDManager manager,
                                long fromIndex,
                                long toIndex)
        Gets the Batch for the given range from the dataset.
        Parameters:
        manager - the manager used to create the arrays
        fromIndex - low endpoint (inclusive) of the requested range
        toIndex - high endpoint (exclusive) of the requested range
        Returns:
        a Batch that contains the data and label of the requested data items
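
        The half-open [fromIndex, toIndex) convention means consecutive ranges tile a dataset with no gaps or overlaps. The plain-Java sketch below illustrates that arithmetic only. It assumes sequential (non-random) sampling and assumes that a final partial batch is kept; neither is confirmed by this page, and SequentialBatchSketch and batchRanges are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

// Plain-Java illustration of half-open index ranges; not part of DJL.
class SequentialBatchSketch {

    /** Splits [0, size) into consecutive {from, to} pairs covering at most batchSize items each. */
    static List<long[]> batchRanges(long size, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<long[]> ranges = new ArrayList<>();
        for (long from = 0; from < size; from += batchSize) {
            // The upper bound is exclusive, so the next range starts exactly at `to`.
            long to = Math.min(from + batchSize, size);
            ranges.add(new long[] {from, to});
        }
        return ranges;
    }

    public static void main(String[] args) {
        // 50 samples with batch size 20: the last range is shorter (assumption: it is kept).
        for (long[] r : batchRanges(50, 20)) {
            System.out.println("[" + r[0] + ", " + r[1] + ")");
        }
        // prints [0, 20), [20, 40), [40, 50) on separate lines
    }
}
```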
      • getData

        public java.lang.Iterable<Batch> getData​(NDManager manager,
                                                 Sampler sampler,
                                                 java.util.concurrent.ExecutorService executorService)
                                          throws java.io.IOException,
                                                 TranslateException
        Fetches an Iterable that iterates through the Dataset using a custom sampler and multiple threads.
        Overrides:
        getData in class RandomAccessDataset
        Parameters:
        manager - the manager used to create the arrays
        sampler - the sampler to use to iterate through the dataset
        executorService - the ExecutorService used for multi-threading
        Returns:
        an Iterable of Batch that contains batches of data from the dataset
        Throws:
        java.io.IOException - for various exceptions depending on the dataset
        TranslateException - if there is an error while processing input
      • prepare

        public void prepare​(ai.djl.util.Progress progress)
                     throws java.io.IOException
        Prepares the dataset for use with tracked progress.
        Parameters:
        progress - the progress tracker
        Throws:
        java.io.IOException - for various exceptions depending on the dataset