Package ai.djl.training.dataset
Class ArrayDataset
- java.lang.Object
  - ai.djl.training.dataset.RandomAccessDataset
    - ai.djl.training.dataset.ArrayDataset
All Implemented Interfaces: Dataset
public class ArrayDataset extends RandomAccessDataset
ArrayDataset is an implementation of RandomAccessDataset that consists entirely of large NDArrays. It is recommended only for datasets that come in array formats and are small enough to fit in memory. Otherwise, consider using RandomAccessDataset directly instead.

There can be multiple data and label NDArrays within the dataset. Each sample is retrieved by indexing each NDArray along the first dimension.

The following is an example of how to use ArrayDataset:
    ArrayDataset dataset = new ArrayDataset.Builder()
            .setData(data1, data2)                // two data NDArrays
            .optLabels(labels1, labels2, labels3) // three label NDArrays
            .setSampling(20, false)               // batch size 20, no shuffling
            .build();
Suppose you get a Batch from trainer.iterateDataset(dataset) or dataset.getData(manager). The data of this batch is an NDList with one NDArray for each data input; in this example, that is 2 arrays. Similarly, the labels are an NDList with 3 arrays.

See Also:
Dataset
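To make the batch layout concrete, here is a plain-Java model (not DJL code) of slicing one batch out of several in-memory arrays. Each `double[][]` stands in for an `NDArray` whose first dimension indexes samples, each `List<double[][]>` stands in for an `NDList`, and the class name `BatchLayoutModel` is hypothetical. The sketch assumes a batch covers a contiguous run of sample indices, as sequential (unshuffled) sampling would produce.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical plain-Java model of the batch layout; not DJL's implementation.
 * A double[][] plays the role of an NDArray (first index = sample index) and a
 * List<double[][]> plays the role of an NDList.
 */
final class BatchLayoutModel {

    private BatchLayoutModel() {}

    /**
     * Slices samples [from, to) out of every array. The result holds one entry per
     * input array, so two data arrays give two entries and three label arrays give three.
     */
    static List<double[][]> sliceBatch(double[][][] arrays, int from, int to) {
        List<double[][]> batch = new ArrayList<>();
        for (double[][] array : arrays) {
            // Index along the first dimension only; the inner rows are shared, not copied.
            batch.add(Arrays.copyOfRange(array, from, to));
        }
        return batch;
    }
}
```

For example, if `data1`, `data2` and `labels1` to `labels3` each hold 100 samples, `sliceBatch(new double[][][] {data1, data2}, 0, 20)` returns a list of 2 arrays with 20 rows each, and the same call over the three label arrays returns a list of 3.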
Nested Class Summary
- static class ArrayDataset.Builder: The Builder to construct an ArrayDataset.

Nested classes/interfaces inherited from class ai.djl.training.dataset.RandomAccessDataset:
- RandomAccessDataset.BaseBuilder<T extends RandomAccessDataset.BaseBuilder<T>>

Nested classes/interfaces inherited from interface ai.djl.training.dataset.Dataset:
- Dataset.Usage
Field Summary
- protected NDArray[] data
- protected NDArray[] labels

Fields inherited from class ai.djl.training.dataset.RandomAccessDataset:
dataBatchifier, device, labelBatchifier, limit, pipeline, prefetchNumber, sampler, targetPipeline
Constructor Summary
- ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder): Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.
Method Summary (all instance methods are concrete)
- protected long availableSize(): Returns the number of records available to be read in this Dataset.
- Record get(NDManager manager, long index): Gets the Record for the given index from the dataset.
- Batch getByIndices(NDManager manager, long... indices): Gets the Batch for the given indices from the dataset.
- Batch getByRange(NDManager manager, long fromIndex, long toIndex): Gets the Batch for the given range from the dataset.
- java.lang.Iterable<Batch> getData(NDManager manager, Sampler sampler, java.util.concurrent.ExecutorService executorService): Fetches an iterator that can iterate through the Dataset with a custom sampler, using multiple threads.
- protected RandomAccessDataset newSubDataset(int[] indices, int from, int to)
- protected RandomAccessDataset newSubDataset(java.util.List<java.lang.Long> subIndices)
- void prepare(ai.djl.util.Progress progress): Prepares the dataset for use with tracked progress.
Methods inherited from class ai.djl.training.dataset.RandomAccessDataset:
getData, getData, getData, randomSplit, size, subDataset, subDataset, subDataset, subDataset, toArray

Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface ai.djl.training.dataset.Dataset:
matchingTranslatorOptions, prepare
Constructor Detail

ArrayDataset
public ArrayDataset(RandomAccessDataset.BaseBuilder<?> builder)
Creates a new instance of ArrayDataset with the arguments in ArrayDataset.Builder.
Parameters:
builder - a builder with the required arguments
Method Detail

availableSize
protected long availableSize()
Returns the number of records available to be read in this Dataset.
Specified by:
availableSize in class RandomAccessDataset
Returns:
the number of records available to be read in this Dataset
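For an array-backed dataset, a natural reading of "records available" is the length of the first dimension shared by all arrays. The plain-Java sketch below models that reading; it is an assumption, not DJL's source. `double[][]` stands in for `NDArray`, the class `AvailableSizeModel` is hypothetical, and the check that every data and label array has the same first dimension is an illustrative assumption about what a consistent dataset requires.

```java
/**
 * Hypothetical plain-Java model; not DJL's implementation.
 * Assumes the record count is the first-dimension length shared by every array.
 */
final class AvailableSizeModel {

    private AvailableSizeModel() {}

    static long availableSize(double[][][] data, double[][][] labels) {
        if (data.length == 0) {
            throw new IllegalArgumentException("at least one data array is required");
        }
        long size = data[0].length;
        for (double[][] array : data) {
            requireLength(array, size);
        }
        if (labels != null) { // labels are optional in this model
            for (double[][] array : labels) {
                requireLength(array, size);
            }
        }
        return size;
    }

    // Every array must be indexable by the same sample indices 0..size-1.
    private static void requireLength(double[][] array, long expected) {
        if (array.length != expected) {
            throw new IllegalArgumentException(
                    "first dimension mismatch: expected " + expected + ", got " + array.length);
        }
    }
}
```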
get
public Record get(NDManager manager, long index)
Gets the Record for the given index from the dataset.
Specified by:
get in class RandomAccessDataset
Parameters:
manager - the manager used to create the arrays
index - the index of the requested data item
Returns:
a Record that contains the data and label of the requested data item
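The lookup described above, where one sample is taken by indexing every array along its first dimension, can be modeled in plain Java as follows. This is a sketch, not DJL's implementation: `double[][]` stands in for `NDArray`, the nested `Sample` class stands in for `Record`, and the names `RecordLookupModel` and `Sample` are hypothetical. Treating missing labels as an empty list is also an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical plain-Java model of get(manager, index); not DJL's implementation.
 * A double[][] plays the role of an NDArray; one sample's row is a double[].
 */
final class RecordLookupModel {

    private RecordLookupModel() {}

    /** Stand-in for a Record: one row per data array and one row per label array. */
    static final class Sample {
        final List<double[]> data;
        final List<double[]> labels;

        Sample(List<double[]> data, List<double[]> labels) {
            this.data = Collections.unmodifiableList(data);
            this.labels = Collections.unmodifiableList(labels);
        }
    }

    static Sample get(double[][][] data, double[][][] labels, long index) {
        int i = Math.toIntExact(index); // throws ArithmeticException if the index overflows an int
        return new Sample(rowsAt(data, i), rowsAt(labels, i));
    }

    // Takes row i (the index along the first dimension) from every array.
    private static List<double[]> rowsAt(double[][][] arrays, int i) {
        List<double[]> rows = new ArrayList<>();
        if (arrays == null) {
            return rows; // no labels supplied
        }
        for (double[][] array : arrays) {
            rows.add(array[i]);
        }
        return rows;
    }
}
```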
getByIndices
public Batch getByIndices(NDManager manager, long... indices)
Gets the Batch for the given indices from the dataset.
Parameters:
manager - the manager used to create the arrays
indices - indices of the requested data items
Returns:
a Batch that contains the data and label of the requested data items
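Fetching a batch by explicit indices amounts to a gather along the first dimension of every array. The plain-Java sketch below models that; it is not DJL's implementation. `double[][]` stands in for `NDArray`, `List<double[][]>` for `NDList`, and `IndexGatherModel` is a hypothetical name. The model keeps samples in the order the caller lists the indices; the documentation above does not state whether DJL does the same.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical plain-Java model of gathering a batch by explicit indices;
 * not DJL's implementation. double[][] stands in for NDArray, List for NDList.
 */
final class IndexGatherModel {

    private IndexGatherModel() {}

    /** Gathers the listed samples from one array, in the order the indices are given. */
    static double[][] gather(double[][] array, long... indices) {
        double[][] out = new double[indices.length][];
        for (int k = 0; k < indices.length; k++) {
            out[k] = array[Math.toIntExact(indices[k])];
        }
        return out;
    }

    /** Applies the same indices to every array, keeping one entry per array. */
    static List<double[][]> gatherAll(double[][][] arrays, long... indices) {
        List<double[][]> batch = new ArrayList<>();
        for (double[][] array : arrays) {
            batch.add(gather(array, indices));
        }
        return batch;
    }
}
```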
getByRange
public Batch getByRange(NDManager manager, long fromIndex, long toIndex)
Gets the Batch for the given range from the dataset.
Parameters:
manager - the manager used to create the arrays
fromIndex - the low endpoint (inclusive) of the requested range
toIndex - the high endpoint (exclusive) of the requested range
Returns:
a Batch that contains the data and label of the requested data items
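Because `fromIndex` is inclusive and `toIndex` is exclusive, a range request returns `toIndex - fromIndex` samples. The plain-Java sketch below models that half-open slicing for one array; it is not DJL's implementation. `double[][]` stands in for `NDArray`, `RangeSliceModel` is a hypothetical name, and rejecting negative, reversed, or out-of-bounds ranges is an assumption of the sketch.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java model of a half-open range lookup [fromIndex, toIndex);
 * not DJL's implementation. double[][] stands in for NDArray.
 */
final class RangeSliceModel {

    private RangeSliceModel() {}

    static double[][] slice(double[][] array, long fromIndex, long toIndex) {
        if (fromIndex < 0 || toIndex > array.length || fromIndex > toIndex) {
            throw new IndexOutOfBoundsException(
                    "invalid range [" + fromIndex + ", " + toIndex + ") for size " + array.length);
        }
        // fromIndex is included and toIndex is excluded, so the result has toIndex - fromIndex samples.
        return Arrays.copyOfRange(array, (int) fromIndex, (int) toIndex);
    }
}
```

For a 10-sample array, `slice(array, 2, 5)` yields samples 2, 3 and 4, and `slice(array, 4, 4)` yields an empty result.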
newSubDataset
protected RandomAccessDataset newSubDataset(int[] indices, int from, int to)
Overrides:
newSubDataset in class RandomAccessDataset
newSubDataset
protected RandomAccessDataset newSubDataset(java.util.List<java.lang.Long> subIndices)
Overrides:
newSubDataset in class RandomAccessDataset
getData
public java.lang.Iterable<Batch> getData(NDManager manager, Sampler sampler, java.util.concurrent.ExecutorService executorService) throws java.io.IOException, TranslateException
Fetches an iterator that can iterate through the Dataset with a custom sampler, using multiple threads.
Overrides:
getData in class RandomAccessDataset
Parameters:
manager - the manager to create the arrays
sampler - the sampler to use to iterate through the dataset
executorService - the executor service used to fetch batches on multiple threads
Returns:
an Iterable of Batch that contains batches of data from the dataset
Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
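The plain-Java sketch below models the idea behind `getData` with a sampler and an `ExecutorService`: a sampler turns sample indices into batches of indices, each batch is fetched as a task on the executor, and the caller receives batches in sampler order. The sequential sampler with batch size 20 mirrors `setSampling(20, false)` in the class example, assuming `false` means unshuffled sampling and that a final, smaller batch is kept. All names are hypothetical, `double[][]` stands in for `NDArray`, and unlike a real iterator this sketch submits every batch up front instead of bounding how many are in flight.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

/**
 * Hypothetical plain-Java model of multi-threaded batch fetching; not DJL's implementation.
 * A double[][] stands in for an NDArray whose first dimension indexes samples.
 */
final class ParallelFetchModel {

    private ParallelFetchModel() {}

    /** Sequential batch sampler: [0, b), [b, 2b), ...; a shorter final batch is kept (assumption). */
    static List<long[]> sequentialBatches(long size, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<long[]> batches = new ArrayList<>();
        for (long start = 0; start < size; start += batchSize) {
            long end = Math.min(start + batchSize, size);
            long[] indices = new long[(int) (end - start)];
            for (int k = 0; k < indices.length; k++) {
                indices[k] = start + k;
            }
            batches.add(indices);
        }
        return batches;
    }

    /** Fetches each batch as a separate task, then returns the batches in sampler order. */
    static List<double[][]> fetchAll(double[][] array, List<long[]> batches, ExecutorService executor)
            throws InterruptedException, ExecutionException {
        List<Future<double[][]>> pending = new ArrayList<>();
        for (long[] indices : batches) {
            pending.add(executor.submit(() -> gather(array, indices)));
        }
        List<double[][]> result = new ArrayList<>();
        for (Future<double[][]> future : pending) {
            result.add(future.get()); // waits for this batch; tasks may finish in any order
        }
        return result;
    }

    // Gathers the requested samples along the first dimension.
    private static double[][] gather(double[][] array, long[] indices) {
        double[][] out = new double[indices.length][];
        for (int k = 0; k < indices.length; k++) {
            out[k] = array[(int) indices[k]];
        }
        return out;
    }
}
```

With 45 samples and a batch size of 20, this sampler yields batches of 20, 20 and 5 indices, and `fetchAll` returns them in that order regardless of which task finishes first.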
prepare
public void prepare(ai.djl.util.Progress progress) throws java.io.IOException
Prepares the dataset for use with tracked progress.
Parameters:
progress - the progress tracker
Throws:
java.io.IOException - for various exceptions depending on the dataset