美文网首页
如何将训练好的tensorflow模型导入到DJL

如何将训练好的tensorflow模型导入到DJL

作者: 郭彦超 | 来源:发表于2020-06-08 21:39 被阅读0次

模型准备

基于tensorflow 网上有人基于新冠肺炎患者X光片和正常人光片训练了一个新冠肺炎识别模型;今天我们就使用这个模型测试一下djl加载的效果

链接: https://pan.baidu.com/s/1gBRKV7Rugou3H2_oNL5bVg 提取码: rt7d

添加依赖


        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-model-zoo</artifactId>
            <version>0.5.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-engine</artifactId>
            <version>0.5.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-native-auto</artifactId>
            <version>2.1.0</version>
            <scope>runtime</scope>
        </dependency>

加载模型


public static void main(String[] args) throws Exception {
        System.setProperty("ai.djl.repository.zoo.location", "{covid save model path}");
        Criteria<Image, Classifications> criteria =
                Criteria.builder()
                        .setTypes(Image.class, Classifications.class)
                        .optTranslator(new MyTranslator())
                        .build();
        ZooModel model = ModelZoo.loadModel(criteria);
        Predictor<Image, Classifications> predictor = model.newPredictor();

        Classifications classifications = predictor.predict(ImageFactory.getInstance().fromFile(Paths.get(args[0])));
        System.out.printf(new Gson().toJson(classifications.best()));
    }
   
 private static final class MyTranslator implements Translator<Image, Classifications> {

        private static final List<String> CLASSES = Arrays.asList("covid-19", "normal");

        @Override
        public NDList processInput(TranslatorContext ctx, Image input) {
            NDArray array =
                    input.toNDArray(
                            ctx.getNDManager(), Image.Flag.COLOR);
            array = NDImageUtils.resize(array, 224).div(255.0f);
            return new NDList(array);
        }

        @Override
        public Classifications processOutput(TranslatorContext ctx, NDList list) {
            NDArray probabilities = list.singletonOrThrow();
            return new Classifications(CLASSES, probabilities);
        }
    }

需要注意的是目前djl对tensorflow支持的还不是特别好 ,并且不支持windows环境运行

相关文章

网友评论

      本文标题:如何将训练好的tensorflow模型导入到DJL

      本文链接:https://www.haomeiwen.com/subject/pqpdtktx.html