Package ai.djl.repository.zoo
Class BaseModelLoader
java.lang.Object
ai.djl.repository.zoo.BaseModelLoader
- All Implemented Interfaces:
ModelLoader
Shared code for the ModelLoader implementations.
Field Summary
Fields
- mrl
- defaultFactory
Constructor Summary
Constructors
- BaseModelLoader(MRL mrl): Constructs a ModelLoader for the given MRL.
Method Summary
- protected Model createModel(Path modelPath, String name, Device device, Block block, Map<String, Object> arguments, String engine)
- <I, O> void downloadModel(Criteria<I, O> criteria, ai.djl.util.Progress progress): Downloads the model artifacts to a local directory.
- Application getApplication(): Returns the application of the ModelLoader.
- String getArtifactId(): Returns the artifact ID of the ModelLoader.
- String getGroupId(): Returns the group ID of the ModelLoader.
- MRL getMrl(): Returns the MRL of the ModelLoader.
- protected TranslatorFactory getTranslatorFactory(Criteria<?, ?> criteria, Map<String, Object> arguments)
- <I, O> boolean isDownloaded(Criteria<I, O> criteria): Returns true if the model has been downloaded to a local directory.
- <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria): Loads the model with the given criteria.
- String toString()
Field Details
- mrl
- defaultFactory
Constructor Details

BaseModelLoader(MRL mrl)
Constructs a ModelLoader for the given MRL.
- Parameters:
  mrl - the MRL of the model to load
Method Details

getGroupId
public String getGroupId()
Returns the group ID of the ModelLoader.
- Specified by: getGroupId in interface ModelLoader
- Returns: the group ID of the ModelLoader

getArtifactId
public String getArtifactId()
Returns the artifact ID of the ModelLoader.
- Specified by: getArtifactId in interface ModelLoader
- Returns: the artifact ID of the ModelLoader

getApplication
public Application getApplication()
Returns the application of the ModelLoader.
- Specified by: getApplication in interface ModelLoader
- Returns: the application of the ModelLoader

getMrl
public MRL getMrl()
Returns the MRL of the ModelLoader.
- Specified by: getMrl in interface ModelLoader
- Returns: the MRL of the ModelLoader
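The four getters describe the model's coordinates, and the constructor takes only an MRL, so a natural reading is that the base class stores the MRL once and answers each getter from it. The pure-Java sketch below illustrates that shared-code pattern under this assumption. ModelCoordinates, CoordinateLoader, AbstractCoordinateLoader, and DemoLoader are hypothetical stand-ins for MRL, ModelLoader, and BaseModelLoader; they are not DJL classes, and the application is simplified to a String.

```java
import java.util.Objects;

// Hypothetical stand-in for DJL's MRL: only the coordinates the getters expose.
// The application is simplified to a String for this sketch.
final class ModelCoordinates {
    private final String application;
    private final String groupId;
    private final String artifactId;

    ModelCoordinates(String application, String groupId, String artifactId) {
        this.application = Objects.requireNonNull(application, "application");
        this.groupId = Objects.requireNonNull(groupId, "groupId");
        this.artifactId = Objects.requireNonNull(artifactId, "artifactId");
    }

    String getApplication() { return application; }
    String getGroupId() { return groupId; }
    String getArtifactId() { return artifactId; }
}

// Hypothetical, trimmed-down stand-in for the ModelLoader interface (getters only).
interface CoordinateLoader {
    String getGroupId();
    String getArtifactId();
    String getApplication();
    ModelCoordinates getMrl();
}

// Shared code: the base class keeps the locator and answers every getter from it.
abstract class AbstractCoordinateLoader implements CoordinateLoader {
    protected final ModelCoordinates mrl;

    protected AbstractCoordinateLoader(ModelCoordinates mrl) {
        this.mrl = Objects.requireNonNull(mrl, "mrl");
    }

    @Override public String getGroupId() { return mrl.getGroupId(); }
    @Override public String getArtifactId() { return mrl.getArtifactId(); }
    @Override public String getApplication() { return mrl.getApplication(); }
    @Override public ModelCoordinates getMrl() { return mrl; }
}

// A concrete loader inherits the getters and would only add its own loading logic.
final class DemoLoader extends AbstractCoordinateLoader {
    DemoLoader(ModelCoordinates mrl) { super(mrl); }
}

class CoordinateLoaderDemo {
    public static void main(String[] args) {
        CoordinateLoader loader = new DemoLoader(
                new ModelCoordinates("image_classification", "ai.djl.example", "resnet"));
        System.out.println(loader.getGroupId() + ":" + loader.getArtifactId()); // ai.djl.example:resnet
    }
}
```

Putting the getters in the base class means each concrete loader only has to supply the behavior that actually differs between repositories.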

loadModel
public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria) throws IOException, ModelNotFoundException, MalformedModelException
Loads the model with the given criteria.
- Specified by: loadModel in interface ModelLoader
- Type Parameters:
  I - the input data type
  O - the output data type
- Parameters:
  criteria - the criteria that the model must match
- Returns: the loaded model
- Throws:
  IOException - if an I/O error occurs while loading data from the repository
  ModelNotFoundException - if no model matching the criteria is found
  MalformedModelException - if the model data is malformed
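This page states only that loadModel loads the model matching the given criteria and throws ModelNotFoundException when none matches. A minimal sketch of that contract follows, assuming a simple first-match rule over key/value filters. ArtifactInfo, CriteriaMatcher, and NoMatchingModelException are hypothetical names, and the rule is an illustration, not DJL's actual matching logic.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical description of a repository artifact: a name plus advertised properties.
final class ArtifactInfo {
    final String name;
    final Map<String, String> properties;

    ArtifactInfo(String name, Map<String, String> properties) {
        this.name = name;
        this.properties = Map.copyOf(properties);
    }
}

// Hypothetical stand-in for ModelNotFoundException.
class NoMatchingModelException extends Exception {
    NoMatchingModelException(String message) {
        super(message);
    }
}

final class CriteriaMatcher {
    // Returns the first candidate whose properties contain every filter key with an equal value.
    static ArtifactInfo match(List<ArtifactInfo> candidates, Map<String, String> filters)
            throws NoMatchingModelException {
        Optional<ArtifactInfo> hit = candidates.stream()
                .filter(a -> filters.entrySet().stream()
                        .allMatch(f -> f.getValue().equals(a.properties.get(f.getKey()))))
                .findFirst();
        return hit.orElseThrow(() -> new NoMatchingModelException("no model matches " + filters));
    }
}

class MatchDemo {
    public static void main(String[] args) throws NoMatchingModelException {
        List<ArtifactInfo> repo = List.of(
                new ArtifactInfo("resnet18", Map.of("layers", "18")),
                new ArtifactInfo("resnet50", Map.of("layers", "50")));
        System.out.println(CriteriaMatcher.match(repo, Map.of("layers", "50")).name); // resnet50
    }
}
```

The first-match rule is chosen only for simplicity; with it, the result depends on the order of the candidate list.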

isDownloaded
public <I, O> boolean isDownloaded(Criteria<I, O> criteria) throws IOException, ModelNotFoundException
Returns true if the model has been downloaded to a local directory.
- Specified by: isDownloaded in interface ModelLoader
- Type Parameters:
  I - the input data type
  O - the output data type
- Parameters:
  criteria - the criteria that the model must match
- Returns: true if the model has been downloaded to a local directory
- Throws:
  IOException - if an I/O error occurs while loading data from the repository
  ModelNotFoundException - if no model matching the criteria is found
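isDownloaded answers whether the model is already available locally. One plausible way to implement such a check is to look for the expected artifact files under a cache directory. The sketch below assumes a hypothetical <root>/<groupId>/<artifactId>/<version> layout and a known list of file names; LocalModelCache and that layout are illustrative only, not DJL's cache format.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

// Hypothetical local cache with an assumed <root>/<groupId>/<artifactId>/<version> layout.
final class LocalModelCache {
    private final Path root;

    LocalModelCache(Path root) {
        this.root = root;
    }

    Path artifactDir(String groupId, String artifactId, String version) {
        return root.resolve(groupId).resolve(artifactId).resolve(version);
    }

    // "Downloaded" here means the artifact directory exists and every expected file is a
    // regular file, so a partially downloaded artifact reports false.
    boolean isDownloaded(String groupId, String artifactId, String version, List<String> files) {
        Path dir = artifactDir(groupId, artifactId, version);
        return Files.isDirectory(dir)
                && files.stream().allMatch(f -> Files.isRegularFile(dir.resolve(f)));
    }
}

class CacheDemo {
    public static void main(String[] args) throws IOException {
        LocalModelCache cache = new LocalModelCache(Files.createTempDirectory("model-cache"));
        List<String> files = List.of("model.bin", "synset.txt");
        System.out.println(cache.isDownloaded("ai.djl.example", "resnet", "0.0.1", files)); // false
    }
}
```

Requiring every file, rather than just the directory, keeps an interrupted download from being mistaken for a complete one.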

downloadModel
public <I, O> void downloadModel(Criteria<I, O> criteria, ai.djl.util.Progress progress) throws IOException, ModelNotFoundException
Downloads the model artifacts to a local directory.
- Specified by: downloadModel in interface ModelLoader
- Type Parameters:
  I - the input data type
  O - the output data type
- Parameters:
  criteria - the criteria that the model must match
  progress - the progress tracker
- Throws:
  IOException - if an I/O error occurs while loading data from the repository
  ModelNotFoundException - if no model matching the criteria is found
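downloadModel copies artifacts to a local directory while reporting to a progress tracker. The sketch below shows one common way to do that with the JDK alone: stream into a temporary file, report each chunk, and move the file into place only after the copy succeeds. ProgressListener and ArtifactDownloader are hypothetical; ProgressListener is deliberately not the real ai.djl.util.Progress interface, whose methods are not listed on this page.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

// Hypothetical progress callback; deliberately NOT the real ai.djl.util.Progress interface.
interface ProgressListener {
    void start(long totalBytes);
    void increment(long bytes);
    void end();
}

final class ArtifactDownloader {
    // Copies source into target, reporting each chunk. Writes to a temporary ".part" file in
    // the same directory first, so a failed copy never leaves a truncated file under the final
    // name. Takes ownership of source and closes it.
    static void download(InputStream source, long totalBytes, Path target, ProgressListener progress)
            throws IOException {
        Path dir = target.toAbsolutePath().getParent();
        Files.createDirectories(dir);
        Path tmp = Files.createTempFile(dir, "download", ".part");
        progress.start(totalBytes);
        try (InputStream in = source; OutputStream out = Files.newOutputStream(tmp)) {
            byte[] buffer = new byte[8192];
            int n;
            while ((n = in.read(buffer)) != -1) {
                out.write(buffer, 0, n);
                progress.increment(n);
            }
        } catch (IOException e) {
            Files.deleteIfExists(tmp); // both streams are already closed here
            throw e;
        } finally {
            progress.end();
        }
        Files.move(tmp, target, StandardCopyOption.REPLACE_EXISTING);
    }
}

class DownloadDemo {
    public static void main(String[] args) throws IOException {
        byte[] payload = new byte[20_000];
        Path target = Files.createTempDirectory("downloads").resolve("model.bin");
        ArtifactDownloader.download(new ByteArrayInputStream(payload), payload.length, target,
                new ProgressListener() {
                    private long total;
                    private long done;

                    @Override public void start(long totalBytes) { total = totalBytes; }
                    @Override public void increment(long bytes) { done += bytes; }
                    @Override public void end() { System.out.println(done + "/" + total + " bytes"); }
                });
    }
}
```

Calling end() from a finally block lets a progress display close cleanly even when the copy fails.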

createModel
protected Model createModel(Path modelPath, String name, Device device, Block block, Map<String, Object> arguments, String engine) throws IOException
- Throws:
  IOException

toString
public String toString()

getTranslatorFactory
protected TranslatorFactory getTranslatorFactory(Criteria<?, ?> criteria, Map<String, Object> arguments)
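getTranslatorFactory receives both the criteria and an arguments map, and the class has a defaultFactory field, which suggests a fallback chain that ends at the default. The sketch below assumes this order: a factory set explicitly, then a factory named in the arguments, then the default. FactoryStub, FactoryResolver, and the "translatorFactory" key are hypothetical, and the precedence is an assumption rather than documented behavior.

```java
import java.util.Map;
import java.util.Objects;

// Hypothetical stand-in for a translator factory; only a name is needed for the sketch.
interface FactoryStub {
    String name();
}

final class FactoryResolver {
    private final FactoryStub defaultFactory;
    // Hypothetical lookup from a factory name to an instance.
    private final Map<String, FactoryStub> registry;

    FactoryResolver(FactoryStub defaultFactory, Map<String, FactoryStub> registry) {
        this.defaultFactory = Objects.requireNonNull(defaultFactory, "defaultFactory");
        this.registry = Map.copyOf(registry);
    }

    // Assumed precedence: explicit factory, then a name under the hypothetical
    // "translatorFactory" argument key, then the default factory.
    FactoryStub resolve(FactoryStub explicit, Map<String, Object> arguments) {
        if (explicit != null) {
            return explicit;
        }
        Object requested = arguments.get("translatorFactory");
        if (requested instanceof String) {
            FactoryStub named = registry.get(requested);
            if (named != null) {
                return named;
            }
        }
        return defaultFactory;
    }
}

class FactoryDemo {
    public static void main(String[] args) {
        FactoryStub fallback = () -> "default";
        FactoryStub image = () -> "image";
        FactoryResolver resolver = new FactoryResolver(fallback, Map.of("image", image));
        System.out.println(resolver.resolve(null, Map.of("translatorFactory", "image")).name()); // image
        System.out.println(resolver.resolve(null, Map.of()).name()); // default
    }
}
```

An unknown factory name falls through to the default here instead of failing; a stricter design could raise an error instead.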