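The approach is to train a DL4J image classification model that recognizes the clock type or state (for example digital clock, analog clock, or a specific brand), then expose an HTTP endpoint through Spring Boot that accepts an image, runs the trained model on it, and returns the recognition result.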

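I. Environment Setup
1. Tech stack versions
- JDK 17+ (Spring Boot 3.x requires Java 17 as its baseline, so JDK 11 will not work with the parent POM below)
- Spring Boot 3.2.x
- DeepLearning4j 1.0.0-M2.1 (a milestone release; supports CNN image classification)
- ND4J (the numerical computing library DL4J depends on)
- OpenCV (optional, for image preprocessing such as resizing and grayscale conversion)
2. Maven dependencies (pom.xml)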
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
        <relativePath/>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>clock-recognition</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>clock-recognition</name>
    <description>Spring Boot + DL4J clock image recognition</description>
    <properties>
        <!-- Spring Boot 3.x needs Java 17 or newer -->
        <java.version>17</java.version>
        <dl4j.version>1.0.0-M2.1</dl4j.version>
        <!-- CPU backend; for GPU, use an nd4j-cuda-X.Y-platform artifact published for this DL4J version -->
        <nd4j.backend>nd4j-native-platform</nd4j.backend>
    </properties>
    <dependencies>
        <!-- Spring Boot core -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <!-- DL4J core dependencies -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <!-- Pretrained model zoo (optional) -->
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j.version}</version>
        </dependency>
        <!-- Image processing -->
        <dependency>
            <groupId>org.bytedeco</groupId>
            <!-- Bundles OpenCV, which simplifies dependency setup -->
            <artifactId>javacv-platform</artifactId>
            <version>1.5.9</version>
        </dependency>
        <!-- Utilities -->
        <dependency>
            <groupId>commons-io</groupId>
            <artifactId>commons-io</artifactId>
            <version>2.15.1</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>
</project>
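II. Core Steps
1. Dataset preparation
- Class definitions: e.g. digital clock, analog clock, smartwatch (extend as needed)
- Dataset structure (ImageNet-style, one folder per class, recommended):
dataset/
    train/
        digital_clock/
            img1.jpg
            img2.jpg
            …
        analog_clock/
            img1.jpg
            …
        smartwatch/
            …
    val/    # validation set (about 20% of all data)
        digital_clock/
            …
        …
- Data volume: at least 50 images per class (more data generally gives better accuracy)
- Data augmentation (optional, improves generalization): rotation, scaling, flipping, brightness adjustment, etc.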
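The roughly 80/20 train/validation split mentioned above can be sketched in plain Java. This is only an illustration: the class name TrainValSplit, the fixed seed, and the rounding rule are assumptions, and the sketch works on file names only, leaving the actual copying of files into train/ and val/ to you.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: splits one class's file names into train and validation lists. */
final class TrainValSplit {
    final List<String> train;
    final List<String> val;

    private TrainValSplit(List<String> train, List<String> val) {
        this.train = train;
        this.val = val;
    }

    /**
     * Shuffles a copy of the names with a fixed seed, then takes the first
     * round(n * valFraction) names as validation and the rest as training.
     */
    static TrainValSplit of(List<String> fileNames, double valFraction, long seed) {
        if (valFraction < 0.0 || valFraction > 1.0) {
            throw new IllegalArgumentException("valFraction must be in [0, 1]");
        }
        List<String> shuffled = new ArrayList<>(fileNames);
        Collections.shuffle(shuffled, new Random(seed));
        int valCount = (int) Math.round(shuffled.size() * valFraction);
        List<String> val = new ArrayList<>(shuffled.subList(0, valCount));
        List<String> train = new ArrayList<>(shuffled.subList(valCount, shuffled.size()));
        return new TrainValSplit(train, val);
    }

    public static void main(String[] args) {
        // 50 images is the per-class minimum suggested above
        List<String> names = new ArrayList<>();
        for (int i = 1; i <= 50; i++) {
            names.add("img" + i + ".jpg");
        }
        TrainValSplit split = TrainValSplit.of(names, 0.2, 123L);
        System.out.println("train=" + split.train.size() + ", val=" + split.val.size());
    }
}
```

Run the split once per class folder so every class keeps the same train/validation ratio.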
2. Training the DL4J model (CNN classifier)
Build a convolutional neural network with DL4J's ComputationGraph as follows:
(1) Data loading and preprocessing
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

import java.io.File;
import java.util.Random;

public class DataLoader {
    // Image size (everything is resized to 224x224 for the CNN)
    private static final int HEIGHT = 224;
    private static final int WIDTH = 224;
    private static final int CHANNELS = 3;    // RGB color images
    private static final int BATCH_SIZE = 16; // batch size
    private static final int NUM_CLASSES = 3; // digital clock, analog clock, smartwatch

    // Load the training set
    public static DataSetIterator loadTrainData(String dataPath) throws Exception {
        return load(new File(dataPath, "train"));
    }

    // Load the validation set (same logic, "val" subfolder)
    public static DataSetIterator loadValData(String dataPath) throws Exception {
        return load(new File(dataPath, "val"));
    }

    private static DataSetIterator load(File parentDir) throws Exception {
        Random rng = new Random(123);
        // Labels are taken from the name of each image's parent folder
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        FileSplit fileSplit = new FileSplit(parentDir, NativeImageLoader.ALLOWED_FORMATS, rng);
        // Balance the dataset (avoids class imbalance)
        BalancedPathFilter pathFilter =
                new BalancedPathFilter(rng, NativeImageLoader.ALLOWED_FORMATS, labelMaker);
        InputSplit balanced = fileSplit.sample(pathFilter)[0];
        // Image reader (decodes and resizes each file into an NDArray)
        ImageRecordReader recordReader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, labelMaker);
        recordReader.initialize(balanced);
        // The label is at index 1 of each record; this is classification, not regression
        DataSetIterator iterator =
                new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, 1, NUM_CLASSES);
        // Scale pixels to [0, 1] so training matches the scaler used at inference time
        iterator.setPreProcessor(new ImagePreProcessingScaler(0, 1));
        return iterator;
    }
}
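Because labels are generated from folder names, the class index the model outputs depends on how those names are ordered. The sketch below assumes the indices follow the sorted order of the folder names, which is worth confirming by printing recordReader.getLabels() in your DL4J version. The English folder names and the class LabelOrder are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/** Hypothetical helper: derives the index-to-label table from class folder names. */
final class LabelOrder {

    /** Returns the folder names in natural String order; position i is class index i. */
    static List<String> indexToLabel(List<String> folderNames) {
        List<String> sorted = new ArrayList<>(folderNames);
        Collections.sort(sorted);
        return sorted;
    }

    public static void main(String[] args) {
        // Folder names in the order they happen to be listed on disk (assumed names)
        List<String> folders = Arrays.asList("smartwatch", "digital_clock", "analog_clock");
        List<String> labels = indexToLabel(folders);
        for (int i = 0; i < labels.size(); i++) {
            System.out.println(i + " -> " + labels.get(i));
        }
    }
}
```

Whatever table you use at inference time should be produced the same way, otherwise a correct prediction is reported under the wrong name.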
(2) Building the CNN model
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ClockCNNModel {
    private static final int HEIGHT = 224;
    private static final int WIDTH = 224;
    private static final int CHANNELS = 3;
    private static final int NUM_CLASSES = 3;

    // Build the CNN (a small VGG-style network, far shallower than VGG16)
    public static ComputationGraph buildModel() {
        // Global settings live on NeuralNetConfiguration.Builder; graphBuilder() then defines the graph
        ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                // Optimizer: Adam (learning rate 0.001)
                .updater(new Adam(0.001))
                // Gradient clipping (stabilizes training)
                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                .gradientNormalizationThreshold(1.0)
                .graphBuilder()
                // Input node: 224x224x3
                .addInputs("input")
                .setInputTypes(InputType.convolutional(HEIGHT, WIDTH, CHANNELS))
                // Conv layer 1: 32 3x3 kernels, stride 1; Same mode computes the padding itself
                .addLayer("conv1", new ConvolutionLayer.Builder(3, 3)
                        .nIn(CHANNELS)
                        .nOut(32)
                        .stride(1, 1)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU)
                        .convolutionMode(ConvolutionMode.Same)
                        .build(), "input")
                // Pooling layer 1: 2x2 max pooling
                .addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build(), "conv1")
                // Conv layer 2 + pooling layer 2
                .addLayer("conv2", new ConvolutionLayer.Builder(3, 3)
                        .nOut(64)
                        .stride(1, 1)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU)
                        .convolutionMode(ConvolutionMode.Same)
                        .build(), "pool1")
                .addLayer("pool2", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build(), "conv2")
                // Fully connected layer 1
                .addLayer("fc1", new DenseLayer.Builder()
                        .nOut(512)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.RELU)
                        .dropOut(0.5) // dropout against overfitting
                        .build(), "pool2")
                // Output layer: softmax activation (multi-class)
                .addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(NUM_CLASSES)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX)
                        .build(), "fc1")
                // Output node
                .setOutputs("output")
                .build();
        ComputationGraph model = new ComputationGraph(config);
        model.init(); // initialize the parameters
        return model;
    }
}
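To get a feel for how large this network is, the layer shapes and parameter counts can be worked out by hand. The sketch below is plain arithmetic, not DL4J code: it assumes stride-1 "same" convolutions keep the spatial size, each 2x2 stride-2 pooling computes (in - 2) / 2 + 1, and weights plus biases are the only parameters. The class name CnnShapeMath is made up for illustration.

```java
/** Hypothetical helper: shape and parameter arithmetic for the two-conv-layer CNN above. */
final class CnnShapeMath {

    /** Output size of a "same"-mode convolution: ceil(in / stride). */
    static int sameConvOut(int in, int stride) {
        return (in + stride - 1) / stride;
    }

    /** Output size of unpadded pooling: (in - kernel) / stride + 1. */
    static int poolOut(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Conv layer parameters: kernel*kernel*nIn*nOut weights plus nOut biases. */
    static long convParams(int kernel, int nIn, int nOut) {
        return (long) kernel * kernel * nIn * nOut + nOut;
    }

    /** Dense layer parameters: nIn*nOut weights plus nOut biases. */
    static long denseParams(long nIn, long nOut) {
        return nIn * nOut + nOut;
    }

    /** Number of values fed into fc1 after conv1, pool1, conv2, pool2 (64 channels). */
    static long flattenedSize(int inputSize) {
        int size = sameConvOut(inputSize, 1); // conv1
        size = poolOut(size, 2, 2);           // pool1
        size = sameConvOut(size, 1);          // conv2
        size = poolOut(size, 2, 2);           // pool2
        return (long) size * size * 64;
    }

    /** Total parameters: conv1 (32 filters), conv2 (64 filters), fc1 (512 units), output. */
    static long totalParams(int inputSize, int channels, int numClasses) {
        return convParams(3, channels, 32)
                + convParams(3, 32, 64)
                + denseParams(flattenedSize(inputSize), 512)
                + denseParams(512, numClasses);
    }

    public static void main(String[] args) {
        System.out.println("fc1 inputs at 224x224: " + flattenedSize(224));
        System.out.println("total params at 224x224: " + totalParams(224, 3, 3));
        System.out.println("total params at 128x128: " + totalParams(128, 3, 3));
    }
}
```

Almost all of the parameters sit in fc1 (56 x 56 x 64 = 200,704 inputs times 512 units), which is why shrinking the input image or adding another pooling stage reduces memory use so sharply.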
(3) Training and saving the model
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

import java.io.File;

public class ModelTrainer {
    private static final int EPOCHS = 20; // number of training epochs
    private static final String DATA_PATH = "path/to/your/dataset"; // dataset path
    // Where the model is saved (relative to the project root, so it is packaged as a resource)
    private static final String MODEL_SAVE_PATH = "src/main/resources/model/clock-model.zip";

    public static void train() throws Exception {
        // 1. Load the data
        DataSetIterator trainIter = DataLoader.loadTrainData(DATA_PATH);
        DataSetIterator valIter = DataLoader.loadValData(DATA_PATH);
        // 2. Build the model
        ComputationGraph model = ClockCNNModel.buildModel();
        // 3. Train the model
        for (int i = 0; i < EPOCHS; i++) {
            model.fit(trainIter); // one epoch
            System.out.println("Finished epoch " + (i + 1));
            // 4. Evaluate on the validation set
            Evaluation eval = model.evaluate(valIter);
            System.out.println("Validation accuracy: " + eval.accuracy());
            System.out.println(eval.stats());
            // Reset the iterators so the next epoch reads the data again
            trainIter.reset();
            valIter.reset();
        }
        // 5. Save the model (for the Spring Boot deployment)
        File modelFile = new File(MODEL_SAVE_PATH);
        File parent = modelFile.getParentFile();
        if (parent != null) {
            parent.mkdirs(); // make sure the target folder exists
        }
        model.save(modelFile);
        System.out.println("Model saved to: " + modelFile.getAbsolutePath());
    }

    public static void main(String[] args) throws Exception {
        train(); // run training
    }
}
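The loop above always runs a fixed number of epochs. A patience-based early stop is one common alternative: stop once validation accuracy has not improved for a few epochs in a row. The sketch below is plain Java with made-up accuracy values; the class EarlyStopping and its methods are hypothetical, and DL4J's own org.deeplearning4j.earlystopping package may be a better fit for real use.

```java
/** Hypothetical helper: stops training when validation accuracy stalls. */
final class EarlyStopping {
    private final int patience;
    private double bestAccuracy = Double.NEGATIVE_INFINITY;
    private int bestEpoch = -1;
    private int epochsWithoutImprovement = 0;

    EarlyStopping(int patience) {
        if (patience < 1) {
            throw new IllegalArgumentException("patience must be at least 1");
        }
        this.patience = patience;
    }

    /** Records one epoch's validation accuracy; returns true when training should stop. */
    boolean shouldStop(int epoch, double valAccuracy) {
        if (valAccuracy > bestAccuracy) {
            bestAccuracy = valAccuracy;
            bestEpoch = epoch;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        return epochsWithoutImprovement >= patience;
    }

    int bestEpoch() {
        return bestEpoch;
    }

    double bestAccuracy() {
        return bestAccuracy;
    }

    public static void main(String[] args) {
        // Illustrative values only, not measured results
        double[] valAccuracy = {0.60, 0.72, 0.75, 0.74, 0.75, 0.73, 0.71};
        EarlyStopping stopper = new EarlyStopping(3);
        for (int epoch = 0; epoch < valAccuracy.length; epoch++) {
            if (stopper.shouldStop(epoch, valAccuracy[epoch])) {
                System.out.println("Stopping after epoch " + epoch
                        + ", best epoch was " + stopper.bestEpoch());
                break;
            }
        }
    }
}
```

In the training loop, the call would go right after the evaluation step, and the model would be saved whenever a new best epoch is recorded.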
3. Deploying the model with Spring Boot (HTTP recognition endpoint)
(1) Model loading configuration (singleton bean, so the model is loaded only once)
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;

import java.io.IOException;
import java.io.InputStream;

@Configuration
public class ModelConfig {
    // Classpath location; src/main/resources/model/clock-model.zip is packaged under this path
    private static final String MODEL_RESOURCE = "model/clock-model.zip";

    // Load the model once at startup and expose it as a bean (used by the recognition utility)
    @Bean
    public ComputationGraph getModel() throws IOException {
        try (InputStream in = new ClassPathResource(MODEL_RESOURCE).getInputStream()) {
            ComputationGraph model = ModelSerializer.restoreComputationGraph(in, true);
            System.out.println("Model loaded successfully!");
            return model;
        }
    }

    // Image preprocessing (normalization)
    @Bean
    public ImagePreProcessingScaler getImageScaler() {
        return new ImagePreProcessingScaler(0, 1); // scale pixel values to [0, 1]
    }

    // Class mapping (index -> class name)
    @Bean
    public String[] getLabelMap() {
        // Note: must match the class order used in training, i.e. the folder names in sorted order
        return new String[]{"analog clock", "digital clock", "smartwatch"};
    }
}
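The ImagePreProcessingScaler(0, 1) bean maps 8-bit pixel values into the range [0, 1]. As a rough illustration of that mapping, and not the library's actual code, the plain-Java sketch below applies min-max scaling to single pixel values; the class PixelScaling is hypothetical.

```java
/** Hypothetical helper: min-max scaling of 8-bit pixel values into [min, max]. */
final class PixelScaling {

    /** Maps a pixel in 0..255 linearly onto [min, max]. */
    static double scale(int pixel, double min, double max) {
        if (pixel < 0 || pixel > 255) {
            throw new IllegalArgumentException("pixel must be in 0..255");
        }
        return pixel / 255.0 * (max - min) + min;
    }

    /** Scales every pixel of an array, returning a new array. */
    static double[] scaleAll(int[] pixels, double min, double max) {
        double[] scaled = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            scaled[i] = scale(pixels[i], min, max);
        }
        return scaled;
    }

    public static void main(String[] args) {
        int[] samples = {0, 51, 128, 255};
        double[] scaled = scaleAll(samples, 0.0, 1.0);
        for (int i = 0; i < samples.length; i++) {
            System.out.println(samples[i] + " -> " + scaled[i]);
        }
    }
}
```

The same scaling has to be applied to the training data as well; a model trained on raw 0..255 values and queried with 0..1 values will see inputs far outside what it learned from.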
(2) Image recognition utility class
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.InputStream;

@Component
public class ClockRecognitionUtil {
    @Autowired
    private ComputationGraph model;
    @Autowired
    private ImagePreProcessingScaler imageScaler;
    @Autowired
    private String[] labelMap;

    private static final int HEIGHT = 224;
    private static final int WIDTH = 224;
    private static final int CHANNELS = 3;

    // Recognize an image (takes an InputStream, returns the class name).
    // synchronized because DL4J networks are not safe for concurrent use from multiple threads.
    public synchronized String recognize(InputStream imageInputStream) throws Exception {
        // 1. Decode the image and resize it to 224x224 with 3 channels
        NativeImageLoader imageLoader = new NativeImageLoader(HEIGHT, WIDTH, CHANNELS);
        INDArray imageArray = imageLoader.asMatrix(imageInputStream);
        // 2. Preprocess (normalize, in place)
        imageScaler.transform(imageArray);
        // 3. Predict; the reshape makes the batch dimension explicit: [1, 3, 224, 224]
        INDArray[] outputs = model.output(imageArray.reshape(1, CHANNELS, HEIGHT, WIDTH));
        INDArray predictions = outputs[0]; // predicted probability distribution
        // 4. Index of the class with the highest probability
        int predictedIndex = predictions.argMax(1).getInt(0);
        // 5. Return the class name
        return labelMap[predictedIndex];
    }
}
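The utility above returns only the winning label. Since the model outputs a probability per class, it can be useful to return the probability too, so that callers can reject low-confidence answers. A minimal plain-Java sketch, with a hypothetical TopPrediction class and made-up probabilities:

```java
/** Hypothetical helper: the most likely label together with its probability. */
final class TopPrediction {
    final String label;
    final double probability;

    private TopPrediction(String label, double probability) {
        this.label = label;
        this.probability = probability;
    }

    /** Picks the index with the highest probability (first one wins on ties). */
    static TopPrediction from(double[] probabilities, String[] labels) {
        if (probabilities.length == 0 || probabilities.length != labels.length) {
            throw new IllegalArgumentException("need one label per probability");
        }
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return new TopPrediction(labels[best], probabilities[best]);
    }

    /** True when the top probability reaches the given threshold. */
    boolean isConfident(double threshold) {
        return probability >= threshold;
    }

    public static void main(String[] args) {
        String[] labels = {"analog clock", "digital clock", "smartwatch"};
        double[] softmax = {0.10, 0.70, 0.20}; // illustrative values only
        TopPrediction top = TopPrediction.from(softmax, labels);
        System.out.println(top.label + " (" + top.probability + "), confident at 0.6: "
                + top.isConfident(0.6));
    }
}
```

With DL4J, the probabilities for a single image could be read from the output array as a double[] before being passed to a helper like this; the threshold value is something to tune on your validation set.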
(3) HTTP controller
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

import java.io.InputStream;

@RestController
public class ClockRecognitionController {
    @Autowired
    private ClockRecognitionUtil recognitionUtil;

    // Image recognition endpoint (POST, accepts a MultipartFile)
    @PostMapping("/api/recognize-clock")
    public ResponseEntity<?> recognizeClock(@RequestParam("image") MultipartFile file) {
        if (file.isEmpty()) {
            return ResponseEntity.badRequest().body("Please upload an image file!");
        }
        // Read the image stream and run recognition; the stream is closed automatically
        try (InputStream inputStream = file.getInputStream()) {
            String result = recognitionUtil.recognize(inputStream);
            // Return the recognition result
            return ResponseEntity.ok().body("Result: " + result);
        } catch (Exception e) {
            e.printStackTrace();
            return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
                    .body("Recognition failed: " + e.getMessage());
        }
    }
}
(4) Application entry point
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class ClockRecognitionApplication {
    public static void main(String[] args) {
        SpringApplication.run(ClockRecognitionApplication.class, args);
    }
}

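III. Testing the Endpoint
1. Start the Spring Boot application
2. Call the endpoint (with Postman or curl)
- Method: POST
- URL: http://localhost:8080/api/recognize-clock
- Parameters: form-data, with key image and the clock image file as the value
- Example response (plain text rather than JSON, since the controller returns a String): Result: analog clock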
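The same request can also be sent from Java. The sketch below uses the JDK's built-in java.net.http client and assembles the multipart/form-data body by hand, with the field name image and the URL from the list above. The class RecognizeClockClient, the default file name clock.jpg, and the image/jpeg content type are assumptions for illustration.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical client: uploads one image to the recognition endpoint. */
final class RecognizeClockClient {

    /** Builds a multipart/form-data body that contains a single file part. */
    static byte[] multipartBody(String boundary, String field, String fileName,
                                String contentType, byte[] content) {
        byte[] head = ("--" + boundary + "\r\n"
                + "Content-Disposition: form-data; name=\"" + field
                + "\"; filename=\"" + fileName + "\"\r\n"
                + "Content-Type: " + contentType + "\r\n\r\n").getBytes(StandardCharsets.UTF_8);
        byte[] tail = ("\r\n--" + boundary + "--\r\n").getBytes(StandardCharsets.UTF_8);
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        out.write(head, 0, head.length);
        out.write(content, 0, content.length);
        out.write(tail, 0, tail.length);
        return out.toByteArray();
    }

    public static void main(String[] args) throws IOException, InterruptedException {
        // Image path: first argument, or a placeholder file name in the working directory
        Path image = Path.of(args.length > 0 ? args[0] : "clock.jpg");
        String boundary = "clock-upload-" + System.nanoTime();
        byte[] body = multipartBody(boundary, "image", image.getFileName().toString(),
                "image/jpeg", Files.readAllBytes(image));

        HttpRequest request = HttpRequest
                .newBuilder(URI.create("http://localhost:8080/api/recognize-clock"))
                .header("Content-Type", "multipart/form-data; boundary=" + boundary)
                .POST(HttpRequest.BodyPublishers.ofByteArray(body))
                .build();

        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}
```

The application has to be running on port 8080 for main to succeed; multipartBody itself has no network dependency and can be checked in isolation.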
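IV. Common Issues
- Out-of-memory errors during training: reduce the batch size (e.g. from 16 to 8); reduce the image size (e.g. from 224×224 to 128×128); increase the JVM heap (-Xms4g -Xmx8g). ND4J keeps its arrays off-heap, so the off-heap limit (-Dorg.bytedeco.javacpp.maxbytes) may need raising as well.
- Low recognition accuracy: make the dataset more diverse (clock images from different angles, lighting conditions, and backgrounds); train for more epochs (or use early stopping to avoid overfitting); tune the learning rate (e.g. lower it late in training).
- Model fails to load at Spring Boot startup: check that the model path is correct (load it through ClassPathResource so it is resolved from the classpath); confirm that the DL4J version matches the ND4J backend version.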
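For the out-of-memory item, a quick estimate shows why the batch size and image size matter so much. The sketch below only counts the float32 activations of one layer (4 bytes per value) and ignores gradients, parameters, and workspace overhead, so treat the numbers as a lower bound. The class ActivationMemory is hypothetical; the 32 channels correspond to the first convolution layer of the model above.

```java
/** Hypothetical helper: rough float32 activation size for one layer and one batch. */
final class ActivationMemory {

    /** Bytes needed for a [batch, channels, height, width] float32 array. */
    static long bytes(int batch, int channels, int height, int width) {
        return 4L * batch * channels * height * width;
    }

    /** Converts bytes to mebibytes for printing. */
    static double toMiB(long bytes) {
        return bytes / (1024.0 * 1024.0);
    }

    public static void main(String[] args) {
        long before = bytes(16, 32, 224, 224); // batch 16, conv1 output at 224x224
        long after = bytes(8, 32, 128, 128);   // batch 8, conv1 output at 128x128
        System.out.println("conv1 activations, batch 16 at 224x224: " + toMiB(before) + " MiB");
        System.out.println("conv1 activations, batch 8 at 128x128: " + toMiB(after) + " MiB");
    }
}
```

Applying both suggestions together cuts this one activation from 98 MiB to 16 MiB, roughly a sixfold reduction, and the other layers shrink in the same proportion.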
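The core of the solution is two parts, model training and endpoint deployment; adjust the classes, model architecture, and parameters to fit your needs.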


