Package ai.djl.training.dataset
Class ArrayDataset
java.lang.Object
    ai.djl.training.dataset.RandomAccessDataset
        ai.djl.training.dataset.ArrayDataset
- All Implemented Interfaces:
Dataset
ArrayDataset is an implementation of RandomAccessDataset that consists entirely of large NDArrays. It is recommended only for datasets that are small enough to fit in memory and that come in array formats. Otherwise, consider using RandomAccessDataset directly instead.
There can be multiple data and label NDArrays within the dataset. Each sample is retrieved by indexing each NDArray along the first dimension.
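To make the indexing rule concrete, here is a small sketch in plain Java, assuming each NDArray can be modeled as a float[][] whose first dimension is the sample axis. SampleSketch and its methods are hypothetical names for illustration, not DJL API, and the requirement that all arrays share the same first dimension is this sketch's own assumption.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, DJL-free model of ArrayDataset indexing: each float[][] stands in
 * for an NDArray whose first dimension is the sample axis.
 */
final class SampleSketch {
    final List<float[]> data;   // one row per data array
    final List<float[]> labels; // one row per label array

    private SampleSketch(List<float[]> data, List<float[]> labels) {
        this.data = data;
        this.labels = labels;
    }

    /** Sample count; this sketch (an assumption) requires all arrays to share the first dimension. */
    static int size(List<float[][]> data, List<float[][]> labels) {
        int size = data.get(0).length;
        for (float[][] array : data) {
            requireLength(array, size);
        }
        for (float[][] array : labels) {
            requireLength(array, size);
        }
        return size;
    }

    /** Builds sample {@code index} by taking row {@code index} of every array. */
    static SampleSketch get(List<float[][]> data, List<float[][]> labels, int index) {
        return new SampleSketch(rows(data, index), rows(labels, index));
    }

    private static List<float[]> rows(List<float[][]> arrays, int index) {
        List<float[]> out = new ArrayList<>(arrays.size());
        for (float[][] array : arrays) {
            out.add(array[index]); // index along the first dimension
        }
        return out;
    }

    private static void requireLength(float[][] array, int size) {
        if (array.length != size) {
            throw new IllegalArgumentException("arrays disagree on the first dimension");
        }
    }
}
```

With two data arrays and one label array of three rows each, get(data, labels, 1) returns row 1 of every array, which is analogous to the data and label content of the Record returned for index 1.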
The following is an example of how to use ArrayDataset:

    ArrayDataset dataset = new ArrayDataset.Builder()
            .setData(data1, data2)                // two data inputs
            .optLabels(labels1, labels2, labels3) // three label arrays
            .setSampling(20, false)               // batch size 20, no shuffling
            .build();
Suppose you get a Batch from trainer.iterateDataset(dataset) or dataset.getData(manager). The data of this batch is an NDList with one NDArray for each data input; in this case, 2 arrays. Similarly, the labels are an NDList of 3 arrays.
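The batch layout can be modeled the same way: for each array, the rows selected for a batch are stacked into one block, so the result has one entry per array. The sketch below is a hypothetical plain-Java illustration (float[][] standing in for NDArray, a List standing in for NDList); BatchLayoutSketch and stack are invented names, and the sketch does not reproduce DJL's batchifiers.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of batch layout: one stacked block per input array. */
final class BatchLayoutSketch {
    private BatchLayoutSketch() {}

    /** For every array, stacks the rows at {@code indices} into one block (batch axis first). */
    static List<float[][]> stack(List<float[][]> arrays, int... indices) {
        List<float[][]> out = new ArrayList<>(arrays.size());
        for (float[][] array : arrays) {
            float[][] block = new float[indices.length][];
            for (int i = 0; i < indices.length; i++) {
                block[i] = array[indices[i]]; // row selected along the first dimension
            }
            out.add(block); // one entry per array, like one NDArray per input in the NDList
        }
        return out;
    }
}
```

With two data arrays and three label arrays, as in the builder call above, stacking a batch yields two data blocks and three label blocks.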
Nested Class Summary
Nested classes/interfaces inherited from class ai.djl.training.dataset.RandomAccessDataset
RandomAccessDataset.BaseBuilder<T extends RandomAccessDataset.BaseBuilder<T>>
Nested classes/interfaces inherited from interface ai.djl.training.dataset.Dataset
Dataset.Usage
Field Summary
Fields inherited from class ai.djl.training.dataset.RandomAccessDataset
dataBatchifier, device, labelBatchifier, limit, pipeline, prefetchNumber, sampler, targetPipeline
Constructor Summary
Constructor | Description
ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder) | Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.
Method Summary
Modifier and Type | Method | Description
protected long | availableSize() | Returns the number of records available to be read in this Dataset.
Record | get(NDManager manager, long index) | Gets the Record for the given index from the dataset.
Batch | getByIndices(NDManager manager, long... indices) | Gets the Batch for the given indices from the dataset.
Batch | getByRange(NDManager manager, long fromIndex, long toIndex) | Gets the Batch for the given range from the dataset.
Iterable<Batch> | getData(NDManager manager, Sampler sampler, ExecutorService executorService) | Fetches an iterator that iterates through the Dataset using a custom sampler and multiple threads.
protected RandomAccessDataset | newSubDataset(int[] indices, int from, int to) |
protected RandomAccessDataset | newSubDataset(List<Long> subIndices) |
void | prepare(ai.djl.util.Progress progress) | Prepares the dataset for use with tracked progress.
Methods inherited from class ai.djl.training.dataset.RandomAccessDataset
getData, getData, getData, randomSplit, size, subDataset, subDataset, subDataset, subDataset, toArray
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface ai.djl.training.dataset.Dataset
matchingTranslatorOptions, prepare
Field Details
data
labels
Constructor Details
ArrayDataset
Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.
- Parameters:
builder - a builder with the required arguments
Method Details
availableSize
protected long availableSize()
Returns the number of records available to be read in this Dataset.
- Specified by:
availableSize in class RandomAccessDataset
- Returns:
the number of records available to be read in this Dataset
get
Gets the Record for the given index from the dataset.
- Specified by:
get in class RandomAccessDataset
- Parameters:
manager - the manager used to create the arrays
index - the index of the requested data item
- Returns:
a Record that contains the data and label of the requested data item
getByIndices
Gets the Batch for the given indices from the dataset.
- Parameters:
manager - the manager used to create the arrays
indices - indices of the requested data items
- Returns:
a Batch that contains the data and label of the requested data items
getByRange
Gets the Batch for the given range from the dataset.
- Parameters:
manager - the manager used to create the arrays
fromIndex - low endpoint (inclusive) of the requested range
toIndex - high endpoint (exclusive) of the requested range
- Returns:
a Batch that contains the data and label of the requested data items
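getByRange takes a half-open range [fromIndex, toIndex), while getByIndices accepts arbitrary indices. A minimal plain-Java sketch of that relationship follows, assuming float[][] in place of NDArray; RangeSketch, rangeToIndices, and gather are hypothetical names, and expressing a range as an explicit index list is an illustration only, not a claim about how ArrayDataset implements getByRange.

```java
import java.util.stream.LongStream;

/** Hypothetical model of range vs. explicit-index selection along the first dimension. */
final class RangeSketch {
    private RangeSketch() {}

    /** Expands the half-open range [fromIndex, toIndex) into explicit indices. */
    static long[] rangeToIndices(long fromIndex, long toIndex) {
        if (fromIndex < 0 || toIndex < fromIndex) {
            throw new IllegalArgumentException("expected 0 <= fromIndex <= toIndex");
        }
        return LongStream.range(fromIndex, toIndex).toArray(); // toIndex is excluded
    }

    /** Gathers rows by explicit, possibly non-contiguous indices (the getByIndices case). */
    static float[][] gather(float[][] array, long... indices) {
        float[][] out = new float[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            out[i] = array[Math.toIntExact(indices[i])];
        }
        return out;
    }
}
```

For example, rangeToIndices(2, 5) yields the indices 2, 3 and 4, and an empty range such as (3, 3) yields no indices.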
newSubDataset
- Overrides:
newSubDataset in class RandomAccessDataset
newSubDataset
- Overrides:
newSubDataset in class RandomAccessDataset
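No description is given for the two newSubDataset overloads. Judging only from their parameters, a plausible reading is that newSubDataset(int[] indices, int from, int to) exposes the slice indices[from..to) of the parent dataset, and newSubDataset(List<Long> subIndices) exposes the listed parent indices. The sketch below models only that assumed index mapping in plain Java; SubDatasetIndexSketch and parentIndex are hypothetical names and this is not DJL's implementation.

```java
import java.util.List;

/** Hypothetical model of the ASSUMED index mapping behind the two newSubDataset overloads. */
final class SubDatasetIndexSketch {
    private SubDatasetIndexSketch() {}

    /** Assumption: local index i maps to indices[from + i], for 0 <= i < to - from. */
    static long parentIndex(int[] indices, int from, int to, long localIndex) {
        if (localIndex < 0 || localIndex >= to - from) {
            throw new IndexOutOfBoundsException("local index " + localIndex);
        }
        return indices[from + (int) localIndex];
    }

    /** Assumption: local index i maps to subIndices.get(i). */
    static long parentIndex(List<Long> subIndices, long localIndex) {
        return subIndices.get(Math.toIntExact(localIndex));
    }
}
```

Under this reading, a shuffled index array could be cut into contiguous slices, one per sub-dataset, without copying any arrays.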
getData
public Iterable<Batch> getData(NDManager manager, Sampler sampler, ExecutorService executorService) throws IOException, TranslateException
Fetches an iterator that iterates through the Dataset using a custom sampler and multiple threads.
- Overrides:
getData in class RandomAccessDataset
- Parameters:
manager - the manager to create the arrays
sampler - the sampler to use to iterate through the dataset
executorService - the executorService to multi-thread with
- Returns:
an Iterable of Batch that contains batches of data from the dataset
- Throws:
IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
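The general pattern behind fetching batches with a sampler and an ExecutorService can be sketched in plain Java: a sequential, non-shuffled batching of indices stands in for the sampler, and each batch is fetched as a task on the executor. ParallelBatchSketch, sequentialBatches, and fetchAll are hypothetical names. How DJL's iterator schedules work (for example, how many batches it prefetches, given the inherited prefetchNumber field) is not described here; this sketch simply submits every batch up front and returns results in sampler order.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Function;

/** Hypothetical model of sampler-driven, multi-threaded batch fetching. */
final class ParallelBatchSketch {
    private ParallelBatchSketch() {}

    /** Sequential (non-shuffled) batches of indices; the last batch may be shorter. */
    static List<long[]> sequentialBatches(long size, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<long[]> batches = new ArrayList<>();
        for (long start = 0; start < size; start += batchSize) {
            int length = (int) Math.min(batchSize, size - start);
            long[] batch = new long[length];
            for (int i = 0; i < length; i++) {
                batch[i] = start + i;
            }
            batches.add(batch);
        }
        return batches;
    }

    /** Fetches every batch on the executor, then returns the results in sampler order. */
    static <B> List<B> fetchAll(List<long[]> batches, Function<long[], B> fetch, ExecutorService executor) {
        List<Future<B>> pending = new ArrayList<>();
        for (long[] indices : batches) {
            pending.add(executor.submit(() -> fetch.apply(indices))); // runs concurrently
        }
        List<B> results = new ArrayList<>(pending.size());
        for (Future<B> future : pending) {
            try {
                results.add(future.get()); // blocks until this batch is ready
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            } catch (ExecutionException e) {
                throw new IllegalStateException(e.getCause());
            }
        }
        return results;
    }
}
```

Collecting the futures in submission order keeps the output deterministic even though batches may finish out of order on the worker threads.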
prepare
Prepares the dataset for use with tracked progress.
- Parameters:
progress - the progress tracker
- Throws:
IOException - for various exceptions depending on the dataset