Model preparation
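Someone online used TensorFlow to train a COVID-19 detection model on chest X-rays of COVID-19 patients and of healthy people. Here we use that model to test how well DJL (Deep Java Library) loads and runs it.
Download (Baidu Netdisk): https://pan.baidu.com/s/1gBRKV7Rugou3H2_oNL5bVg (extraction code: rt7d)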
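DJL's TensorFlow engine loads models in TensorFlow's SavedModel format. Assuming the download unpacks to a standard SavedModel directory (a `saved_model.pb` file next to a `variables/` folder; this layout is an assumption, since the archive's contents are not described here), a quick stdlib-only check before pointing DJL at the directory might look like the sketch below. The class and method names are hypothetical.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Hypothetical helper: checks that a directory has the usual SavedModel layout.
final class SavedModelCheck {
    private SavedModelCheck() {}

    // True if dir contains a saved_model.pb file and a variables/ subdirectory.
    static boolean looksLikeSavedModel(Path dir) {
        return Files.isRegularFile(dir.resolve("saved_model.pb"))
                && Files.isDirectory(dir.resolve("variables"));
    }

    public static void main(String[] args) {
        // Defaults to the current directory when no path argument is given.
        Path dir = Paths.get(args.length > 0 ? args[0] : ".");
        System.out.println(dir.toAbsolutePath() + " looks like a SavedModel: "
                + looksLikeSavedModel(dir));
    }
}
```

Run it with the unpacked model directory as the argument; if it prints `false`, the path you later pass to DJL probably points one level too high or too low.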
Add dependencies
<dependency>
    <groupId>ai.djl.tensorflow</groupId>
    <artifactId>tensorflow-model-zoo</artifactId>
    <version>0.5.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.tensorflow</groupId>
    <artifactId>tensorflow-engine</artifactId>
    <version>0.5.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.tensorflow</groupId>
    <artifactId>tensorflow-native-auto</artifactId>
    <version>2.1.0</version>
    <scope>runtime</scope>
</dependency>
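If the dependencies resolve correctly, DJL's API and TensorFlow engine classes should be on the runtime classpath. A minimal sanity check needs only `Class.forName`. The engine class name `ai.djl.tensorflow.engine.TfEngine` is an assumption based on DJL's package layout and may differ between versions, and the helper is hypothetical. This only confirms the Java classes are present; the native TensorFlow library brought in by `tensorflow-native-auto` is resolved separately when the engine starts.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper: reports which of the given classes can be loaded.
final class ClasspathCheck {
    private ClasspathCheck() {}

    // True if the class can be found (without initializing it) by this class loader.
    static boolean isPresent(String className) {
        try {
            Class.forName(className, false, ClasspathCheck.class.getClassLoader());
            return true;
        } catch (ClassNotFoundException | LinkageError e) {
            return false;
        }
    }

    // Maps each class name to whether it is on the classpath, preserving input order.
    static Map<String, Boolean> check(String... classNames) {
        Map<String, Boolean> result = new LinkedHashMap<>();
        for (String name : classNames) {
            result.put(name, isPresent(name));
        }
        return result;
    }

    public static void main(String[] args) {
        // Class names assumed from DJL's package layout; adjust for your DJL version.
        check("ai.djl.engine.Engine", "ai.djl.tensorflow.engine.TfEngine")
                .forEach((name, ok) -> System.out.println(name + ": " + (ok ? "found" : "MISSING")));
    }
}
```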
Load the model
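import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import com.google.gson.Gson;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
// CovidClassifier is an illustrative name; the original snippet omitted the enclosing class.
public class CovidClassifier {
    public static void main(String[] args) throws Exception {
        // Point DJL's local model zoo at the downloaded model directory (placeholder path).
        System.setProperty("ai.djl.repository.zoo.location", "{covid save model path}");
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optTranslator(new MyTranslator())
                        .build();
        // The model and predictor hold native resources, so close them with try-with-resources.
        try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria);
                Predictor<Image, Classifications> predictor = model.newPredictor()) {
            Image image = ImageFactory.getInstance().fromFile(Paths.get(args[0]));
            Classifications classifications = predictor.predict(image);
            // println, not printf: the JSON is data and must not be parsed as a format string.
            System.out.println(new Gson().toJson(classifications.best()));
        }
    }
    private static final class MyTranslator implements Translator<Image, Classifications> {
        // Labels, assumed to match the order of the model's output vector.
        private static final List<String> CLASSES = Arrays.asList("covid-19", "normal");
        @Override
        public NDList processInput(TranslatorContext ctx, Image input) {
            // Decode to a color NDArray, resize to 224x224, and scale pixels to [0, 1].
            NDArray array = input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR);
            array = NDImageUtils.resize(array, 224).div(255.0f);
            return new NDList(array);
        }
        @Override
        public Classifications processOutput(TranslatorContext ctx, NDList list) {
            // The model returns one probability per class.
            NDArray probabilities = list.singletonOrThrow();
            return new Classifications(CLASSES, probabilities);
        }
    }
}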
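The translator wraps the model call in two small pieces of arithmetic: it scales 8-bit pixel values into [0, 1] by dividing by 255, and it pairs each output probability with a label from `CLASSES` so that `best()` can report the highest-scoring one. The sketch below reproduces that logic with plain Java arrays, without DJL, only to make the steps concrete; it is not how DJL implements `Classifications`, and the class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical plain-Java illustration of MyTranslator's pre- and post-processing.
final class TranslatorMath {
    private TranslatorMath() {}

    // Same label order as MyTranslator.CLASSES.
    static final List<String> CLASSES = Arrays.asList("covid-19", "normal");

    // Scale 8-bit pixel values (0..255) to floats in [0, 1], like .div(255.0f).
    static float[] normalize(int[] pixels) {
        float[] out = new float[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            out[i] = pixels[i] / 255.0f;
        }
        return out;
    }

    // Return the label with the highest probability (argmax), like classifications.best().
    static String best(float[] probabilities) {
        if (probabilities.length != CLASSES.size()) {
            throw new IllegalArgumentException("expected " + CLASSES.size() + " probabilities");
        }
        int bestIndex = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[bestIndex]) {
                bestIndex = i;
            }
        }
        return CLASSES.get(bestIndex);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(normalize(new int[] {0, 51, 255})));
        // Made-up probabilities, purely for illustration.
        System.out.println(best(new float[] {0.92f, 0.08f}));
    }
}
```

With the made-up probabilities `{0.92f, 0.08f}`, `best` returns `"covid-19"`, since index 0 holds the larger value. This is also why the order of `CLASSES` must match the model's output order: swapping the two labels would silently invert every prediction.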
Note that DJL's TensorFlow support is still fairly limited at this point (this article uses DJL 0.5.0), and running on Windows is not supported.
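Because this DJL and TensorFlow setup does not support Windows, a program can fail fast with a clear message instead of failing later while loading the engine. A minimal sketch, assuming a check on the standard `os.name` system property is enough for your deployment; the class and method names are hypothetical.

```java
import java.util.Locale;

// Hypothetical guard: refuse to start on Windows, where this DJL + TensorFlow setup is unsupported.
final class PlatformGuard {
    private PlatformGuard() {}

    // os.name values on Windows start with "Windows" (e.g. "Windows 10").
    static boolean isWindows(String osName) {
        return osName != null && osName.toLowerCase(Locale.ROOT).startsWith("windows");
    }

    // Throws if the current JVM reports a Windows operating system.
    static void requireNonWindows() {
        String osName = System.getProperty("os.name");
        if (isWindows(osName)) {
            throw new UnsupportedOperationException(
                    "DJL TensorFlow engine (0.5.0) is not supported on " + osName);
        }
    }

    public static void main(String[] args) {
        requireNonWindows();
        System.out.println("Platform OK: " + System.getProperty("os.name"));
    }
}
```

Calling `PlatformGuard.requireNonWindows()` as the first line of `main`, before any DJL class is touched, keeps the error message about the platform rather than about a missing native library.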