Class GoEmotions

java.lang.Object
ai.djl.training.dataset.RandomAccessDataset
ai.djl.basicdataset.nlp.TextDataset
ai.djl.basicdataset.nlp.GoEmotions
All Implemented Interfaces:
ai.djl.training.dataset.Dataset

public class GoEmotions extends TextDataset
GoEmotions is a corpus of 58k carefully curated comments extracted from Reddit, human-annotated with 27 emotion categories or Neutral. This version of the data is filtered by rater agreement on top of the raw data and comes with a train/test/validation split. The emotion categories are: admiration, amusement, anger, annoyance, approval, caring, confusion, curiosity, desire, disappointment, disapproval, disgust, embarrassment, excitement, fear, gratitude, grief, joy, love, nervousness, optimism, pride, realization, relief, remorse, sadness, surprise.
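To make the label set concrete, here is a small sketch that maps label ids to category names and turns a comment's label ids into a multi-hot vector. It is a hypothetical helper (EmotionLabels is not part of DJL), and it assumes that label id i refers to the i-th category in the alphabetical list above, with Neutral as the final id, 27.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical helper, not part of DJL. Assumption: id i is the i-th category
 * in the alphabetical list, and "neutral" is the last id (27).
 */
final class EmotionLabels {
    // The 27 emotion categories followed by neutral: 28 labels in total.
    static final List<String> NAMES = Arrays.asList(
            "admiration", "amusement", "anger", "annoyance", "approval", "caring",
            "confusion", "curiosity", "desire", "disappointment", "disapproval",
            "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
            "joy", "love", "nervousness", "optimism", "pride", "realization",
            "relief", "remorse", "sadness", "surprise", "neutral");

    private EmotionLabels() {}

    /** Returns the category name for a label id. */
    static String nameOf(int id) {
        return NAMES.get(id);
    }

    /** Converts the label ids of one comment into a multi-hot vector of length 28. */
    static float[] multiHot(int[] ids) {
        float[] vector = new float[NAMES.size()];
        for (int id : ids) {
            if (id < 0 || id >= vector.length) {
                throw new IllegalArgumentException("Unknown label id: " + id);
            }
            vector[id] = 1f; // mark every emotion the raters assigned
        }
        return vector;
    }
}
```

A multi-hot vector is one common way to feed multi-label annotations to a classifier, since a single comment may carry more than one emotion.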
  • Nested Class Summary

    Nested Classes
    Modifier and Type
    Class
    Description
    static final class
    GoEmotions.Builder
    A builder to construct a GoEmotions.

    Nested classes/interfaces inherited from class ai.djl.basicdataset.nlp.TextDataset

    TextDataset.Sample

    Nested classes/interfaces inherited from class ai.djl.training.dataset.RandomAccessDataset

    ai.djl.training.dataset.RandomAccessDataset.BaseBuilder<T extends ai.djl.training.dataset.RandomAccessDataset.BaseBuilder<T>>

    Nested classes/interfaces inherited from interface ai.djl.training.dataset.Dataset

    ai.djl.training.dataset.Dataset.Usage
  • Field Summary

    Fields inherited from class ai.djl.basicdataset.nlp.TextDataset

    manager, mrl, prepared, samples, sourceTextData, targetTextData, usage

    Fields inherited from class ai.djl.training.dataset.RandomAccessDataset

    dataBatchifier, device, labelBatchifier, limit, pipeline, prefetchNumber, sampler, targetPipeline
  • Method Summary

    Modifier and Type
    Method
    Description
    protected long
    availableSize()
    Returns the number of records available to be read in this Dataset.
    static GoEmotions.Builder
    builder()
    Creates a builder to build a GoEmotions.
    ai.djl.training.dataset.Record
    get(ai.djl.ndarray.NDManager manager, long index)
    Gets the Record for the given index from the dataset.
    void
    prepare(ai.djl.util.Progress progress)
    Prepares the dataset for use with tracked progress.

    Methods inherited from class ai.djl.basicdataset.nlp.TextDataset

    getProcessedText, getRawText, getSamples, getTextEmbedding, getVocabulary, preprocess

    Methods inherited from class ai.djl.training.dataset.RandomAccessDataset

    getData, getData, getData, getData, newSubDataset, newSubDataset, randomSplit, size, subDataset, subDataset, subDataset, subDataset, toArray

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

    Methods inherited from interface ai.djl.training.dataset.Dataset

    matchingTranslatorOptions, prepare
  • Method Details

    • prepare

      public void prepare(ai.djl.util.Progress progress) throws IOException, ai.djl.modality.nlp.embedding.EmbeddingException
      Prepares the dataset for use with tracked progress. This method parses the TSV file for the selected usage (train, test, or validation) and preprocesses the text it contains.
      Parameters:
      progress - the progress tracker
      Throws:
      IOException - for various exceptions depending on the dataset
      ai.djl.modality.nlp.embedding.EmbeddingException - if an error occurs while embedding the text during preprocessing
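The parsing step can be sketched without DJL as follows. This is a hypothetical illustration, not the library's implementation: GoEmotionsTsv and its methods are invented names, and the sketch assumes each TSV row holds the comment text, a comma-separated list of emotion ids, and a comment id, in that order, with no tab characters inside the text.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch, not DJL code. Assumed row layout:
 * text <TAB> comma-separated emotion ids <TAB> comment id
 */
final class GoEmotionsTsv {
    /** One parsed row: the comment text and its emotion label ids. */
    static final class Row {
        final String text;
        final int[] labelIds;

        Row(String text, int[] labelIds) {
            this.text = text;
            this.labelIds = labelIds;
        }
    }

    private GoEmotionsTsv() {}

    /** Parses a single TSV line into a Row. */
    static Row parseLine(String line) {
        // Limit -1 keeps trailing empty columns instead of dropping them.
        String[] columns = line.split("\t", -1);
        if (columns.length < 2) {
            throw new IllegalArgumentException("Expected at least 2 columns: " + line);
        }
        // A comment may have several emotions, e.g. "7,26".
        String[] ids = columns[1].split(",");
        int[] labelIds = new int[ids.length];
        for (int i = 0; i < ids.length; i++) {
            labelIds[i] = Integer.parseInt(ids[i].trim());
        }
        return new Row(columns[0], labelIds);
    }

    /** Parses every non-empty line, skipping blank ones. */
    static List<Row> parseAll(List<String> lines) {
        List<Row> rows = new ArrayList<>();
        for (String line : lines) {
            if (!line.isEmpty()) {
                rows.add(parseLine(line));
            }
        }
        return rows;
    }
}
```

Keeping the labels as an int array per comment preserves the multi-label nature of the annotations until a later step decides how to encode them.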
    • get

      public ai.djl.training.dataset.Record get(ai.djl.ndarray.NDManager manager, long index) throws IOException
      Gets the Record for the given index from the dataset.
      Specified by:
      get in class ai.djl.training.dataset.RandomAccessDataset
      Parameters:
      manager - the manager used to create the arrays
      index - the index of the requested data item
      Returns:
      a Record that contains the data and label of the requested data item. The data NDList contains the embedded comment text, and the label NDList contains the emotion label(s) annotated for that comment.
      Throws:
      IOException
    • availableSize

      protected long availableSize()
      Returns the number of records available to be read in this Dataset. In this implementation, this is the number of text samples parsed from the TSV file for the selected usage.
      Specified by:
      availableSize in class ai.djl.training.dataset.RandomAccessDataset
      Returns:
      the number of records available to be read in this Dataset
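The random-access contract behind get and availableSize can be pictured with a small in-memory stand-in. Everything here (InMemoryEmotionData, Item) is hypothetical and not DJL's API: a real Record holds NDList data and labels created through an NDManager, whereas this sketch stores plain strings and int arrays, and it assumes the size is simply the number of parsed comments.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, dependency-free analogue of a random-access text dataset (not DJL code). */
final class InMemoryEmotionData {
    /** Minimal stand-in for a record: the data (comment text) plus its label ids. */
    static final class Item {
        final String data;
        final int[] label;

        Item(String data, int[] label) {
            this.data = data;
            this.label = label;
        }
    }

    private final List<String> texts = new ArrayList<>();
    private final List<int[]> labels = new ArrayList<>();

    /** Stores one parsed sample; the label array is copied defensively. */
    void add(String text, int[] labelIds) {
        texts.add(text);
        labels.add(labelIds.clone());
    }

    /** Number of samples that can be read, i.e. how many were added. */
    long availableSize() {
        return texts.size();
    }

    /** Returns the sample at the given index; indices use long like the DJL signature. */
    Item get(long index) {
        int i = Math.toIntExact(index); // throws ArithmeticException if index exceeds int range
        return new Item(texts.get(i), labels.get(i).clone());
    }
}
```

Labels are cloned on the way in and out so callers cannot mutate the stored samples through a returned Item.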
    • builder

      public static GoEmotions.Builder builder()
      Creates a builder to build a GoEmotions.
      Returns:
      a new builder