public class DBOW<T extends SequenceElement> extends Object implements SequenceLearningAlgorithm<T>
Modifier and Type | Field and Description |
---|---|
protected VectorsConfiguration |
configuration |
protected WeightLookupTable<T> |
lookupTable |
protected double |
negative |
protected SkipGram<T> |
skipGram |
protected boolean |
useAdaGrad |
protected VocabCache<T> |
vocabCache |
protected int |
window |
Constructor and Description |
---|
DBOW() |
Modifier and Type | Method and Description |
---|---|
void |
configure(@NonNull VocabCache<T> vocabCache,
@NonNull WeightLookupTable<T> lookupTable,
@NonNull VectorsConfiguration configuration) |
protected void |
dbow(int i,
Sequence<T> sequence,
int b,
AtomicLong nextRandom,
double alpha,
boolean isInference,
org.nd4j.linalg.api.ndarray.INDArray inferenceVector,
BatchSequences<T> batchSequences) |
void |
finish() |
String |
getCodeName() |
ElementsLearningAlgorithm<T> |
getElementsLearningAlgorithm() |
org.nd4j.linalg.api.ndarray.INDArray |
inferSequence(Sequence<T> sequence,
long nextRandom,
double learningRate,
double minLearningRate,
int iterations)
This method does training on a previously unseen paragraph and returns the inferred vector
|
boolean |
isEarlyTerminationHit()
DBOW has no reason for early termination
|
double |
learnSequence(@NonNull Sequence<T> sequence,
@NonNull AtomicLong nextRandom,
double learningRate,
BatchSequences<T> batchSequences)
This method does training over the sequence of elements passed into it
|
void |
pretrain(SequenceIterator<T> iterator)
DBOW doesn't involve any pretraining
|
protected VocabCache<T extends SequenceElement> vocabCache
protected WeightLookupTable<T extends SequenceElement> lookupTable
protected VectorsConfiguration configuration
protected int window
protected boolean useAdaGrad
protected double negative
protected SkipGram<T extends SequenceElement> skipGram
public ElementsLearningAlgorithm<T> getElementsLearningAlgorithm()
getElementsLearningAlgorithm
in interface SequenceLearningAlgorithm<T extends SequenceElement>
public String getCodeName()
getCodeName
in interface SequenceLearningAlgorithm<T extends SequenceElement>
public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable, @NonNull VectorsConfiguration configuration)
configure
in interface SequenceLearningAlgorithm<T extends SequenceElement>
public void pretrain(SequenceIterator<T> iterator)
pretrain
in interface SequenceLearningAlgorithm<T extends SequenceElement>
iterator -
public double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate, BatchSequences<T> batchSequences)
SequenceLearningAlgorithm
learnSequence
in interface SequenceLearningAlgorithm<T extends SequenceElement>
public boolean isEarlyTerminationHit()
isEarlyTerminationHit
in interface SequenceLearningAlgorithm<T extends SequenceElement>
protected void dbow(int i, Sequence<T> sequence, int b, AtomicLong nextRandom, double alpha, boolean isInference, org.nd4j.linalg.api.ndarray.INDArray inferenceVector, BatchSequences<T> batchSequences)
public org.nd4j.linalg.api.ndarray.INDArray inferSequence(Sequence<T> sequence, long nextRandom, double learningRate, double minLearningRate, int iterations)
inferSequence
in interface SequenceLearningAlgorithm<T extends SequenceElement>
sequence -
nextRandom -
learningRate -
public void finish()
finish
in interface SequenceLearningAlgorithm<T extends SequenceElement>
Copyright © 2021. All rights reserved.