Clock Image Recognition with Spring Boot + DeepLearning4j (DL4J)

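The approach has two parts. First, train a DL4J image classification model that recognizes the type or state of a clock (for example digital clock, analog clock, or a specific brand). Second, expose an HTTP endpoint with Spring Boot that accepts an image, runs it through the trained model, and returns the recognition result.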

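I. Environment Setup

1. Technology stack and versions

  • JDK 17+ (Spring Boot 3.x requires Java 17 or later)
  • Spring Boot 3.2.x
  • DeepLearning4j 1.0.0-M2.1 (a milestone release; supports CNN image classification)
  • ND4J (the numerical computing library DL4J depends on)
  • OpenCV (optional; for image preprocessing such as resizing and grayscale conversion)

2. Maven dependencies (pom.xml)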

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
        <relativePath/>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>clock-recognition</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>clock-recognition</name>
    <description>Clock image recognition with Spring Boot + DL4J</description>

    <properties>
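        <java.version>17</java.version> <!-- Spring Boot 3.x requires Java 17+ -->
        <dl4j.version>1.0.0-M2.1</dl4j.version>
        <nd4j.backend>nd4j-native-platform</nd4j.backend> <!-- CPU backend; for GPU, use the nd4j-cuda-*-platform artifact published for this DL4J release and your CUDA version -->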
    </properties>

    <dependencies>
        <!-- Spring Boot core -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <!-- DL4J core dependencies -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId> <!-- pretrained model zoo (optional) -->
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j.version}</version>
        </dependency>

        <!-- Image processing (optional; the code in this article does not call it directly) -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId> <!-- bundles OpenCV and simplifies dependencies; keep its JavaCPP version compatible with the one DL4J pulls in -->
            <version>1.5.9</version>
        </dependency>

        <!-- Utilities -->
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.15.1</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>

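II. Core Steps

1. Preparing the dataset

  • Classes: for example digital clock, analog clock, smartwatch (extend as needed)
  • Dataset layout (recommended: one folder per class, ImageNet-style):

dataset/
  train/
    digital_clock/
      img1.jpg
      img2.jpg
      …
    analog_clock/
      img1.jpg
      …
    smartwatch/
      …
  val/            # validation set (about 20% of all data)
    digital_clock/
      …
    …

  • Data volume: at least 50 images per class (more data generally improves accuracy)
  • Data augmentation (optional, improves generalization): rotation, scaling, flipping, brightness adjustment, and so on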
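
The roughly 80/20 train/validation split described above can be scripted instead of done by hand. The following is a minimal sketch, assuming the file names of one class are already collected in a list; the class name `TrainValSplit`, the fixed seed, and the 0.2 ratio are illustrative choices, not part of the original setup.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: splits the file names of ONE class folder into train and val parts. */
class TrainValSplit {

    /** Result holder: two disjoint lists that together contain every input item. */
    static final class Split {
        final List<String> train;
        final List<String> val;

        Split(List<String> train, List<String> val) {
            this.train = train;
            this.val = val;
        }
    }

    /**
     * Shuffles a copy of the input with a fixed seed (so the split is reproducible)
     * and moves round(valFraction * n) items into the validation list.
     */
    static Split split(List<String> files, double valFraction, long seed) {
        if (valFraction < 0.0 || valFraction > 1.0) {
            throw new IllegalArgumentException("valFraction must be in [0, 1]");
        }
        List<String> shuffled = new ArrayList<>(files);
        Collections.shuffle(shuffled, new Random(seed));
        int valCount = (int) Math.round(shuffled.size() * valFraction);
        List<String> val = new ArrayList<>(shuffled.subList(0, valCount));
        List<String> train = new ArrayList<>(shuffled.subList(valCount, shuffled.size()));
        return new Split(train, val);
    }

    public static void main(String[] args) {
        List<String> files = new ArrayList<>();
        for (int i = 1; i <= 10; i++) {
            files.add("img" + i + ".jpg");
        }
        Split s = split(files, 0.2, 123L);
        System.out.println("train=" + s.train.size() + ", val=" + s.val.size());
    }
}
```

Splitting each class separately keeps the class proportions the same in both parts; the actual file moves or copies are left out of the sketch.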

2. Training the DL4J model (CNN classifier)

The convolutional network is built with DL4J's ComputationGraph. The steps are as follows.

(1) Data loading and preprocessing

import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

import java.io.File;
import java.util.Random;

public class DataLoader {
    // Image size: every image is scaled to 224x224 before it enters the CNN
    private static final int HEIGHT = 224;
    private static final int WIDTH = 224;
    private static final int CHANNELS = 3; // RGB color image
    private static final int BATCH_SIZE = 16; // batch size
    private static final int NUM_CLASSES = 3; // digital clock, analog clock, smartwatch

    // Load the training set from <dataPath>/train
    public static DataSetIterator loadTrainData(String dataPath) throws Exception {
        return load(new File(dataPath, "train"));
    }

    // Load the validation set from <dataPath>/val (same logic as the training set)
    public static DataSetIterator loadValData(String dataPath) throws Exception {
        return load(new File(dataPath, "val"));
    }

    private static DataSetIterator load(File parentDir) throws Exception {
        // The label of each image is the name of its parent folder
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        FileSplit fileSplit = new FileSplit(parentDir, NativeImageLoader.ALLOWED_FORMATS, new Random(123));

        // Balance the dataset so that no class dominates
        BalancedPathFilter pathFilter =
            new BalancedPathFilter(new Random(123), NativeImageLoader.ALLOWED_FORMATS, labelMaker);
        InputSplit[] splits = fileSplit.sample(pathFilter, 1.0);

        // Image reader: decodes and resizes each file and appends its label
        ImageRecordReader recordReader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, labelMaker);
        recordReader.initialize(splits[0]);

        // Classification iterator: the label is at index 1 (index 0 is the image).
        // The original 5-argument call with "true" selected regression mode, which is wrong here.
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, 1, NUM_CLASSES);

        // Scale pixel values to [0, 1], the same normalization that is applied at inference time
        iterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));
        return iterator;
    }
}

(2) Building the CNN model

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ClockCNNModel {
    private static final int HEIGHT = 224;
    private static final int WIDTH = 224;
    private static final int CHANNELS = 3;
    private static final int NUM_CLASSES = 3;

    // Build the CNN (a small VGG-style network with two conv/pool blocks)
    public static ComputationGraph buildModel() {
        // Global settings live on NeuralNetConfiguration.Builder; graphBuilder() then defines the graph
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
            // Optimizer: Adam with learning rate 0.001
            .updater(new Adam(0.001))
            // Element-wise gradient clipping to stabilize training
            .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            .graphBuilder()
            // Input: 224x224x3; the input type lets DL4J infer nIn and add the CNN-to-dense preprocessor
            .addInputs("input")
            .setInputTypes(InputType.convolutional(HEIGHT, WIDTH, CHANNELS))
            // Conv layer 1: 32 kernels of 3x3, stride 1, "Same" mode (padding is computed automatically)
            .addLayer("conv1", new ConvolutionLayer.Builder(3, 3)
                .nIn(CHANNELS)
                .nOut(32)
                .stride(1, 1)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU)
                .convolutionMode(ConvolutionMode.Same)
                .build(), "input")
            // Pooling layer 1: 2x2 max pooling
            .addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build(), "conv1")
            // Conv layer 2 + pooling layer 2
            .addLayer("conv2", new ConvolutionLayer.Builder(3, 3)
                .nOut(64)
                .stride(1, 1)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU)
                .convolutionMode(ConvolutionMode.Same)
                .build(), "pool1")
            .addLayer("pool2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build(), "conv2")
            // Fully connected layer 1
            .addLayer("fc1", new DenseLayer.Builder()
                .nOut(512)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.RELU)
                .dropOut(0.5) // dropout against overfitting
                .build(), "pool2")
            // Output layer: softmax activation (multi-class classification)
            .addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(NUM_CLASSES)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.SOFTMAX)
                .build(), "fc1")
            // Declare the output node
            .setOutputs("output")
            .build();

        ComputationGraph model = new ComputationGraph(config);
        model.init(); // initialize the parameters
        return model;
    }
}
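
It helps to work out how large this network actually is, because almost all of its parameters sit in the first fully connected layer. The sketch below redoes the shape arithmetic in plain Java. It assumes "Same" convolutions (output side = ceil(in / stride)) and unpadded 2x2 pooling with stride 2; the class name `CnnShapeMath` is hypothetical.

```java
/** Worked arithmetic for the conv1 -> pool1 -> conv2 -> pool2 -> fc1 stack described above. */
class CnnShapeMath {

    /** Output side length of a "Same"-mode convolution: ceil(in / stride). */
    static int sameConv(int in, int stride) {
        return (in + stride - 1) / stride;
    }

    /** Output side length of an unpadded pooling layer: floor((in - kernel) / stride) + 1. */
    static int pool(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Number of values fed into fc1 for a square input, given 64 channels out of conv2. */
    static long flattenedSize(int inputSide) {
        int side = sameConv(inputSide, 1); // conv1 keeps the spatial size
        side = pool(side, 2, 2);           // pool1 halves it
        side = sameConv(side, 1);          // conv2 keeps the spatial size
        side = pool(side, 2, 2);           // pool2 halves it again
        return (long) side * side * 64;
    }

    /** Weights plus biases of a dense layer. */
    static long denseParams(long nIn, long nOut) {
        return nIn * nOut + nOut;
    }

    public static void main(String[] args) {
        long flat = flattenedSize(224);
        System.out.println("fc1 inputs: " + flat);                   // 56 * 56 * 64 = 200704
        System.out.println("fc1 params: " + denseParams(flat, 512)); // 102760960
    }
}
```

With about 103 million 32-bit parameters, fc1 alone needs roughly 0.4 GB for its weights, and Adam keeps two additional values per parameter during training. Reducing the input to 128x128 shrinks the fc1 input to 32 * 32 * 64 = 65536 values, which is why a smaller image size is an effective remedy for memory problems.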

(3) Training and saving the model

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.io.File;

public class ModelTrainer {
    private static final int EPOCHS = 20; // number of training epochs
    private static final String DATA_PATH = "path/to/your/dataset"; // dataset root
    private static final String MODEL_SAVE_PATH = "src/main/resources/model/clock-model.zip"; // where the model is written

    public static void train() throws Exception {
        // 1. Load the data
        DataSetIterator trainIter = DataLoader.loadTrainData(DATA_PATH);
        DataSetIterator valIter = DataLoader.loadValData(DATA_PATH);

        // 2. Build the model
        ComputationGraph model = ClockCNNModel.buildModel();

        // 3. Train the model
        for (int i = 0; i < EPOCHS; i++) {
            model.fit(trainIter); // one epoch
            System.out.println("Finished epoch " + (i + 1));

            // 4. Evaluate on the validation set
            Evaluation eval = model.evaluate(valIter);
            System.out.println("Validation accuracy: " + eval.accuracy());
            System.out.println(eval.stats());

            // Reset the iterators so the next epoch reads the data from the start
            trainIter.reset();
            valIter.reset();
        }

        // 5. Save the model for the Spring Boot deployment.
        // The path is relative to the working directory (the project root when run from an IDE or Maven).
        File modelFile = new File(MODEL_SAVE_PATH);
        File parentDir = modelFile.getParentFile();
        if (parentDir != null) {
            parentDir.mkdirs(); // create src/main/resources/model if it does not exist yet
        }
        model.save(modelFile);
        System.out.println("Model saved to: " + modelFile.getAbsolutePath());
    }

    public static void main(String[] args) throws Exception {
        train(); // run the training
    }
}
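
The accuracy printed after each epoch is derived from a confusion matrix. DL4J computes it internally; the sketch below only illustrates the arithmetic on a hand-written matrix. The class name `ConfusionStats`, the sample numbers, and the convention "rows are true classes, columns are predicted classes" are assumptions of this example.

```java
/** Illustrative arithmetic: accuracy and per-class recall from a confusion matrix. */
class ConfusionStats {

    /** Accuracy = correctly classified samples (the diagonal) divided by all samples. */
    static double accuracy(int[][] matrix) {
        long correct = 0;
        long total = 0;
        for (int row = 0; row < matrix.length; row++) {
            for (int col = 0; col < matrix[row].length; col++) {
                total += matrix[row][col];
                if (row == col) {
                    correct += matrix[row][col];
                }
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    /** Recall of one class = correctly predicted samples of that class / all samples of that class. */
    static double recall(int[][] matrix, int classIndex) {
        long rowTotal = 0;
        for (int value : matrix[classIndex]) {
            rowTotal += value;
        }
        return rowTotal == 0 ? 0.0 : (double) matrix[classIndex][classIndex] / rowTotal;
    }

    public static void main(String[] args) {
        // Made-up validation result for three classes with 10 images each
        int[][] matrix = {
            {8, 1, 1},
            {2, 7, 1},
            {0, 0, 10}
        };
        System.out.println("accuracy = " + accuracy(matrix));
        System.out.println("recall of class 1 = " + recall(matrix, 1));
    }
}
```

Looking at per-class recall as well as overall accuracy shows whether one class (for example analog clocks) is being confused with another, which accuracy alone can hide.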

3. Deploying the model with Spring Boot (HTTP recognition endpoint)

(1) Model loading configuration (a singleton bean, so the model is loaded only once)

import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;

import java.io.IOException;
import java.io.InputStream;

@Configuration
public class ModelConfig {
    // Classpath location: src/main/resources/model/clock-model.zip is packaged under this path,
    // so loading also works from the built jar (a plain file path under src/ would not)
    private static final String MODEL_PATH = "model/clock-model.zip";

    // Load the model once at startup; Spring beans are singletons by default
    @Bean
    public ComputationGraph getModel() throws IOException {
        try (InputStream in = new ClassPathResource(MODEL_PATH).getInputStream()) {
            // false: the updater state is only needed to continue training, not for inference
            ComputationGraph model = ModelSerializer.restoreComputationGraph(in, false);
            System.out.println("Model loaded successfully");
            return model;
        }
    }

    // Image preprocessing: scale pixel values to [0, 1]
    @Bean
    public ImagePreProcessingScaler getImageScaler() {
        return new ImagePreProcessingScaler(0, 1);
    }

    // Label map (class index -> class name)
    @Bean
    public String[] getLabelMap() {
        // The order must match the label order used in training, which follows the sorted folder names.
        // The analog-clock folder sorts before the digital-clock folder, so it gets index 0.
        return new String[]{"analog clock", "digital clock", "smartwatch"};
    }
}
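
Because a wrong label order silently produces wrong answers, it can be safer to derive the order from the training directory than to type it by hand. The sketch below assumes, as the comment in the configuration class states, that class indices follow the sorted folder names; the class name `LabelOrder` and the folder names in `main` are illustrative.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/** Hypothetical helper: lists the class folders of a training directory in sorted order. */
class LabelOrder {

    /** Returns the names of the immediate sub-directories of trainDir in String natural order. */
    static List<String> sortedLabels(Path trainDir) throws IOException {
        try (Stream<Path> children = Files.list(trainDir)) {
            return children
                    .filter(Files::isDirectory)
                    .map(p -> p.getFileName().toString())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        // Build a throwaway directory tree just to demonstrate the ordering
        Path tmp = Files.createTempDirectory("train");
        for (String name : new String[]{"smartwatch", "digital_clock", "analog_clock"}) {
            Files.createDirectory(tmp.resolve(name));
        }
        System.out.println(sortedLabels(tmp));
    }
}
```

String natural order compares UTF-16 code units, so non-ASCII folder names sort by code point rather than by any locale-specific rule. Printing the labels reported by the record reader during training and comparing them with this list is a cheap sanity check.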

(2) Image recognition helper

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.InputStream;

@Component
public class ClockRecognitionUtil {
    @Autowired
    private ComputationGraph model;

    @Autowired
    private ImagePreProcessingScaler imageScaler;

    @Autowired
    private String[] labelMap;

    private static final int HEIGHT = 224;
    private static final int WIDTH = 224;
    private static final int CHANNELS = 3;

    // Recognize an image (takes an InputStream, returns the class name)
    public String recognize(InputStream imageInputStream) throws Exception {
        // 1. Decode and resize the image; the loader already returns a batch of one: [1, 3, 224, 224]
        NativeImageLoader imageLoader = new NativeImageLoader(HEIGHT, WIDTH, CHANNELS);
        INDArray imageArray = imageLoader.asMatrix(imageInputStream);

        // 2. Preprocess (scale pixel values to [0, 1], as in training)
        imageScaler.transform(imageArray);

        // 3. Run the model
        INDArray[] outputs = model.output(imageArray);
        INDArray predictions = outputs[0]; // probability distribution over the classes

        // 4. Index of the class with the highest probability
        int predictedIndex = predictions.argMax(1).getInt(0);

        // 5. Map the index to the class name
        return labelMap[predictedIndex];
    }
}
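
The helper above returns only the winning label. Since the softmax output is a probability distribution, it can also be useful to report the probability and to refuse an answer when the model is unsure. The following sketch works on a plain `double[]` (the values would first have to be copied out of the `INDArray`); the class name `TopPrediction` and the 0.5 threshold are arbitrary illustrative choices.

```java
import java.util.Locale;

/** Hypothetical post-processing of a softmax output: arg-max plus a confidence threshold. */
class TopPrediction {

    /** Index of the largest value; ties resolve to the first occurrence. */
    static int argMax(double[] probabilities) {
        if (probabilities.length == 0) {
            throw new IllegalArgumentException("empty probability array");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Returns "label (probability)", or "unknown" when the best probability is below minConfidence. */
    static String describe(double[] probabilities, String[] labels, double minConfidence) {
        if (probabilities.length != labels.length) {
            throw new IllegalArgumentException("one probability per label expected");
        }
        int best = argMax(probabilities);
        if (probabilities[best] < minConfidence) {
            return "unknown";
        }
        return String.format(Locale.ROOT, "%s (%.2f)", labels[best], probabilities[best]);
    }

    public static void main(String[] args) {
        String[] labels = {"analog clock", "digital clock", "smartwatch"};
        System.out.println(describe(new double[]{0.10, 0.70, 0.20}, labels, 0.5));
        System.out.println(describe(new double[]{0.40, 0.35, 0.25}, labels, 0.5));
    }
}
```

A threshold like this is a simple way to avoid confidently labeling images that contain no clock at all, although a softmax classifier can still be overconfident on such inputs.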

(3) HTTP endpoint (Controller)

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import java.io.InputStream;

@RestController
public class ClockRecognitionController {
    @Autowired
    private ClockRecognitionUtil recognitionUtil;

    // Recognition endpoint (POST request with a MultipartFile)
    @PostMapping("/api/recognize-clock")
    public ResponseEntity<?> recognizeClock(@RequestParam("image") MultipartFile file) {
        if (file.isEmpty()) {
            return ResponseEntity.badRequest().body("Please upload an image file!");
        }

        // Read the image stream and run the recognition; the stream is closed automatically
        try (InputStream inputStream = file.getInputStream()) {
            String result = recognitionUtil.recognize(inputStream);

            // Return the recognition result
            return ResponseEntity.ok().body("Recognition result: " + result);
        } catch (Exception e) {
            e.printStackTrace();
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR).body("Recognition failed: " + e.getMessage());
        }
    }
}
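
The controller only checks that the upload is not empty, so a non-image file would fail inside the image loader and come back as a 500 error. One possible refinement is to look at the first bytes of the upload and answer with 400 instead. The sketch below checks the well-known JPEG and PNG signatures; the class name `ImageSniffer` and the decision to accept only these two formats are assumptions of this example.

```java
/** Hypothetical pre-check: recognizes JPEG and PNG files by their leading signature bytes. */
class ImageSniffer {

    // The 8-byte PNG signature
    private static final int[] PNG_SIGNATURE = {0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A};

    /** JPEG files start with the bytes FF D8 FF. */
    static boolean isJpeg(byte[] head) {
        return head.length >= 3
                && (head[0] & 0xFF) == 0xFF
                && (head[1] & 0xFF) == 0xD8
                && (head[2] & 0xFF) == 0xFF;
    }

    /** PNG files start with the fixed 8-byte signature. */
    static boolean isPng(byte[] head) {
        if (head.length < PNG_SIGNATURE.length) {
            return false;
        }
        for (int i = 0; i < PNG_SIGNATURE.length; i++) {
            if ((head[i] & 0xFF) != PNG_SIGNATURE[i]) {
                return false;
            }
        }
        return true;
    }

    /** True when the leading bytes look like one of the two supported formats. */
    static boolean looksLikeSupportedImage(byte[] head) {
        return isJpeg(head) || isPng(head);
    }

    public static void main(String[] args) {
        byte[] jpegStart = {(byte) 0xFF, (byte) 0xD8, (byte) 0xFF, (byte) 0xE0};
        byte[] textStart = "hello".getBytes(java.nio.charset.StandardCharsets.US_ASCII);
        System.out.println(looksLikeSupportedImage(jpegStart));
        System.out.println(looksLikeSupportedImage(textStart));
    }
}
```

In the controller this check would run on the beginning of the uploaded bytes before the recognition call. A matching signature does not guarantee that the rest of the file decodes correctly, so the existing exception handling is still needed.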

(4) Application entry class

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class ClockRecognitionApplication {
    public static void main(String[] args) {
        SpringApplication.run(ClockRecognitionApplication.class, args);
    }
}

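III. Testing the Endpoint

1. Start the Spring Boot application

2. Call the endpoint (with Postman or curl)

  • Method: POST
  • URL: http://localhost:8080/api/recognize-clock
  • Parameter: form-data, with key image and the clock image file as the value
  • Example response (returned as plain text, not JSON): Recognition result: analog clock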
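
Besides Postman or curl, the endpoint can be called from Java with the JDK's built-in HTTP client. The JDK has no multipart helper, so the sketch below assembles the form-data body by hand. The class name `RecognizeClient`, the boundary string, and the hard-coded `image/jpeg` content type are assumptions of this example; the URL and the `image` field name are the ones listed above.

```java
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical command-line client for the /api/recognize-clock endpoint. */
class RecognizeClient {

    /** Builds a multipart/form-data body with a single file part named "image". */
    static byte[] multipartBody(String boundary, String fileName, String contentType, byte[] content) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        String head = "--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"image\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: " + contentType + "\r\n\r\n";
        out.writeBytes(head.getBytes(StandardCharsets.UTF_8));
        out.writeBytes(content);
        // The closing delimiter ends with two extra dashes
        out.writeBytes(("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8));
        return out.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        if (args.length == 0) {
            System.out.println("usage: java RecognizeClient.java <image-file>");
            return;
        }
        Path image = Path.of(args[0]);
        String boundary = "----clock" + System.nanoTime();
        byte[] body = multipartBody(boundary, image.getFileName().toString(),
                "image/jpeg", Files.readAllBytes(image));

        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8080/api/recognize-clock"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();
        HttpResponse<String> response =
                HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```

The boundary must not occur inside the file content; a long, unusual string makes a collision unlikely, but the sketch does not verify this.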

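IV. Common Problems

  1. Out-of-memory errors during training:
     • reduce the batch size (for example from 16 to 8)
     • reduce the image size (for example from 224×224 to 128×128)
     • increase the JVM heap (-Xms4g -Xmx8g); ND4J keeps its arrays off-heap, so the off-heap limit (-Dorg.bytedeco.javacpp.maxbytes) may need raising as well
  2. Low recognition accuracy:
     • make the dataset more diverse (clock images with different angles, lighting, and backgrounds)
     • train for more epochs (or use early stopping to prevent overfitting)
     • tune the learning rate (for example lower it in the later stages of training)
  3. Model fails to load when Spring Boot starts:
     • check that the model path is correct (load it through ClassPathResource so that it is resolved from the classpath)
     • make sure the DL4J version and the ND4J backend version are the same

The two core parts are model training and API deployment; the classes, model structure, and parameters can be adjusted to your own requirements.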
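
The early stopping mentioned in item 2 can be reduced to a small amount of bookkeeping: remember the best validation accuracy so far and stop once it has not improved for a given number of epochs. DL4J ships its own early-stopping utilities in the `org.deeplearning4j.earlystopping` package; the sketch below only illustrates the idea in plain Java, and the class name `EarlyStopping` and the patience value are assumptions of this example.

```java
/** Illustrative patience-based early stopping on validation accuracy. */
class EarlyStopping {
    private final int patience;
    private double bestAccuracy = Double.NEGATIVE_INFINITY;
    private int epochsWithoutImprovement = 0;

    /** @param patience number of consecutive non-improving epochs tolerated before stopping */
    EarlyStopping(int patience) {
        if (patience < 1) {
            throw new IllegalArgumentException("patience must be at least 1");
        }
        this.patience = patience;
    }

    /** Records the validation accuracy of one epoch; returns true when training should stop. */
    boolean shouldStop(double validationAccuracy) {
        if (validationAccuracy > bestAccuracy) {
            bestAccuracy = validationAccuracy;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        return epochsWithoutImprovement >= patience;
    }

    /** Best validation accuracy seen so far. */
    double best() {
        return bestAccuracy;
    }

    public static void main(String[] args) {
        EarlyStopping stopper = new EarlyStopping(2);
        double[] accuracies = {0.60, 0.70, 0.65, 0.68, 0.90}; // made-up values
        for (int epoch = 0; epoch < accuracies.length; epoch++) {
            if (stopper.shouldStop(accuracies[epoch])) {
                System.out.println("stopping after epoch " + (epoch + 1) + ", best = " + stopper.best());
                break;
            }
        }
    }
}
```

In the training loop, `shouldStop` would be called with the validation accuracy after each epoch. Saving the model whenever a new best value is reached keeps the best checkpoint rather than the last one.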
