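Want to add AI features to your application without learning Python? This article walks you through training a neural network model with Java and TensorFlow, from Iris flower classification to deployment inside a Spring Boot service. Complete code, Maven dependencies, and a list of pitfalls are included.
Why should Java developers care about TensorFlow?
Picture the following scenario
Suppose you are a Java backend developer at an e-commerce company. The product manager comes to you holding an AI research report:
“We want to add a feature to the order system: automatically classify each complaint by type (logistics, product quality, returns and exchanges) based on the complaint text the user enters. AI seems able to do anything these days, so can you get it done in three days?”
Your first reaction might be: “Doesn't that need Python? I'd have to learn PyTorch or TensorFlow, set up the environment, and then stand up a Flask service...”
The good news: you can do all of this directly in Java.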
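TensorFlow ships an official Java API (TensorFlow Java). With it you can train a model, export it, load it again, and run inference, all from Java code.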
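This article uses a classic example (Iris flower classification) to walk through that whole pipeline. Even if you have no neural network background, you can follow along and get the code running.
Getting started with TensorFlow Java
Maven dependency configuration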
    <?xml version="1.0" encoding="UTF-8"?>
    <project xmlns="http://maven.apache.org/POM/4.0.0"
             xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
             xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
        <modelVersion>4.0.0</modelVersion>

        <groupId>com.example</groupId>
        <artifactId>tensorflow-java-demo</artifactId>
        <version>1.0-SNAPSHOT</version>

        <properties>
            <maven.compiler.source>17</maven.compiler.source>
            <maven.compiler.target>17</maven.compiler.target>
            <tensorflow.version>0.5.0</tensorflow.version>
        </properties>

        <dependencies>
            <!-- TensorFlow Java core, including the native binaries -->
            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>tensorflow-core-platform</artifactId>
                <version>${tensorflow.version}</version>
            </dependency>
            <!-- Framework module: the Glorot initializer, MeanSquaredError loss and Adam
                 optimizer used later in this article live here, not in the core artifact -->
            <dependency>
                <groupId>org.tensorflow</groupId>
                <artifactId>tensorflow-framework</artifactId>
                <version>${tensorflow.version}</version>
            </dependency>
            <!-- Logging -->
            <dependency>
                <groupId>org.slf4j</groupId>
                <artifactId>slf4j-simple</artifactId>
                <version>2.0.9</version>
            </dependency>
        </dependencies>
    </project>
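Verifying the installation

    import org.tensorflow.TensorFlow;

    public class TestTensorFlow {
        public static void main(String[] args) {
            // Prints the version of the native TensorFlow runtime bundled with the artifact
            System.out.println("TensorFlow version: " + TensorFlow.version());
        }
    }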
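Expected output: a line of the form TensorFlow version: 2.x.y. The number is the version of the bundled native runtime, not the Maven artifact version; TensorFlow Java 0.5.0 is built against TensorFlow 2.10.1, so that is what it should print.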
Neural network concepts: a quick reference
💡 Tip: skim this table on a first read, then come back to it once you have seen the code.
| Concept | Corresponding code |
|---|---|
| Layer | hiddenLayer1Weights |
| Weight | Variable(initializer) |
| Bias | variable(fill(0.1f)) |
| Activation function | nn.relu() |
| Loss function | MeanSquaredError |
| Optimizer | Adam |
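To make the table less abstract, here is a minimal plain-Java sketch of what one layer computes: multiply the input by the weights, add the bias, apply an activation function. It does not use TensorFlow at all, and the class and method names (`DenseLayerSketch`, `dense`, `relu`, `softmax`) are made up for illustration.

```java
import java.util.Arrays;

class DenseLayerSketch {

    /** Computes x·W + b for a single input row; w is indexed as w[input][output]. */
    static float[] dense(float[] x, float[][] w, float[] b) {
        float[] out = new float[b.length];
        for (int j = 0; j < b.length; j++) {
            float sum = b[j];
            for (int i = 0; i < x.length; i++) {
                sum += x[i] * w[i][j];
            }
            out[j] = sum;
        }
        return out;
    }

    /** ReLU: negative values become 0, positive values pass through unchanged. */
    static float[] relu(float[] v) {
        float[] out = new float[v.length];
        for (int i = 0; i < v.length; i++) {
            out[i] = Math.max(0f, v[i]);
        }
        return out;
    }

    /** Softmax: turns raw scores into probabilities that sum to 1. */
    static float[] softmax(float[] v) {
        float max = v[0];
        for (float f : v) {
            max = Math.max(max, f);
        }
        float[] out = new float[v.length];
        float sum = 0f;
        for (int i = 0; i < v.length; i++) {
            out[i] = (float) Math.exp(v[i] - max); // subtracting the max avoids overflow
            sum += out[i];
        }
        for (int i = 0; i < v.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] x = {1f, 2f};
        float[][] w = {{1f, -1f}, {0.5f, -0.5f}};
        float[] b = {0.1f, 0.1f};
        float[] hidden = relu(dense(x, w, b));
        System.out.println("hidden  = " + Arrays.toString(hidden));
        System.out.println("softmax = " + Arrays.toString(softmax(hidden)));
    }
}
```

The TensorFlow code later in the article expresses the same arithmetic as graph operations (`matMul`, `add`, `relu`, `softmax`) instead of loops.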
Preparing the dataset: Iris flower classification
About the dataset
The Iris dataset is the “Hello World” of machine learning:
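| Property | Value |
|---|---|
| Samples | 150 |
| Features | 4 (sepal length, sepal width, petal length, petal width) |
| Classes | 3 (Setosa, Versicolour, Virginica) |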
Data entity classes

    public enum IrisSpecies {
        IRIS_SETOSA(0, "Iris-setosa"),
        IRIS_VERSICOLOUR(1, "Iris-versicolor"),
        IRIS_VIRGINICA(2, "Iris-virginica");
        // constructor and getters omitted
    }

    public class IrisDataLine {
        private final float sepalLength;
        private final float sepalWidth;
        private final float petalLength;
        private final float petalWidth;
        private final IrisSpecies irisSpecies;
        // constructor and getters omitted
    }
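The enum above omits its constructor, and the data-loading code calls a lookup method `IrisSpecies.getIrisSpecies(...)` that is never shown. One possible way to fill in those omitted parts is sketched below. The field names `index` and `label` and the case-insensitive matching are assumptions, not the original author's code.

```java
import java.util.Arrays;

enum IrisSpecies {
    IRIS_SETOSA(0, "Iris-setosa"),
    IRIS_VERSICOLOUR(1, "Iris-versicolor"),
    IRIS_VIRGINICA(2, "Iris-virginica");

    private final int index;    // assumed meaning: position in the 3-element output vector
    private final String label; // assumed meaning: label text as it appears in the CSV file

    IrisSpecies(int index, String label) {
        this.index = index;
        this.label = label;
    }

    int getIndex() {
        return index;
    }

    String getLabel() {
        return label;
    }

    /** Maps a CSV label such as "Iris-setosa" to the matching constant. */
    static IrisSpecies getIrisSpecies(String label) {
        return Arrays.stream(values())
                .filter(s -> s.label.equalsIgnoreCase(label.trim()))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Unknown species label: " + label));
    }
}

class IrisSpeciesDemo {
    public static void main(String[] args) {
        IrisSpecies species = IrisSpecies.getIrisSpecies("Iris-versicolor");
        System.out.println(species + " -> index " + species.getIndex());
    }
}
```

Throwing on an unknown label makes a malformed CSV row fail loudly during loading instead of silently producing a null species.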
Reading the data

    // Requires java.io.*, java.nio.charset.StandardCharsets, java.util.ArrayList and java.util.List
    private List<IrisDataLine> loadTrainingData() throws IOException {
        List<IrisDataLine> trainData = new ArrayList<>();
        try (InputStream is = getClass().getResourceAsStream("/iris.csv");
             BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
            reader.readLine(); // skip the header row
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.isBlank()) continue;
                String[] parts = line.split(",");
                trainData.add(new IrisDataLine(
                        Float.parseFloat(parts[0]), // sepal_length
                        Float.parseFloat(parts[1]), // sepal_width
                        Float.parseFloat(parts[2]), // petal_length
                        Float.parseFloat(parts[3]), // petal_width
                        IrisSpecies.getIrisSpecies(parts[4])
                ));
            }
        }
        return trainData;
    }
Building the neural network model
Network topology

    Input layer (4) → Hidden layer 1 (5) → Hidden layer 2 (4) → Output layer (3)
    Activation:       ReLU                 ReLU                 Softmax
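To get a feel for how small this network is, you can count its trainable parameters: each fully connected layer has one weight per input/output pair plus one bias per output. The sketch below does that arithmetic in plain Java; `TopologySketch` and `countParameters` are illustrative names.

```java
class TopologySketch {

    /** Counts weights plus biases of a fully connected network with the given layer widths. */
    static int countParameters(int... widths) {
        int total = 0;
        for (int i = 1; i < widths.length; i++) {
            total += widths[i - 1] * widths[i]; // weight matrix between layer i-1 and layer i
            total += widths[i];                 // bias vector of layer i
        }
        return total;
    }

    public static void main(String[] args) {
        // (4*5 + 5) + (5*4 + 4) + (4*3 + 3) = 25 + 24 + 15
        System.out.println(countParameters(4, 5, 4, 3)); // prints 64
    }
}
```

Sixty-four parameters is tiny, which is why this model trains in well under a second on a CPU.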
Configuration constants

    public class ModelConfig {
        public static final int INPUT_LAYER_WIDTH = 4;
        public static final int HIDDEN_LAYER_1_WIDTH = 5;
        public static final int HIDDEN_LAYER_2_WIDTH = 4;
        public static final int OUTPUT_LAYER_WIDTH = 3;
        public static final int TRAINING_EPOCHS = 4;
        public static final float LEARNING_RATE = 0.01f;
        public static final long RANDOM_SEED = 42L;
    }
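Building the graph

    import org.tensorflow.Graph;
    import org.tensorflow.framework.initializers.Glorot;
    import org.tensorflow.framework.initializers.Initializer;
    import org.tensorflow.framework.initializers.VarianceScaling.Distribution;
    import org.tensorflow.ndarray.Shape;
    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Placeholder;
    import org.tensorflow.op.core.Variable;
    import org.tensorflow.types.TFloat32;

    // Create the computation graph
    Graph graph = new Graph();
    Ops tf = Ops.create(graph);

    // Input placeholder: any number of rows, 4 features per row
    Placeholder<TFloat32> inputLayer = tf.withName("input_placeholder")
            .placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 4)));

    // Weight initializer (API as of TensorFlow Java 0.5.0; the shape is passed as int64, hence the L suffixes)
    Initializer<TFloat32> initializer = new Glorot<>(Distribution.NORMAL, RANDOM_SEED);

    // Hidden layer 1: 4 → 5
    Variable<TFloat32> h1Weights = tf.withName("hidden1_weights")
            .variable(initializer.call(tf, tf.array(4L, 5L), TFloat32.class));
    Variable<TFloat32> h1Biases = tf.withName("hidden1_biases")
            .variable(tf.fill(tf.array(5), tf.constant(0.1f)));
    var h1 = tf.nn.relu(tf.math.add(tf.linalg.matMul(inputLayer, h1Weights), h1Biases));

    // Hidden layer 2: 5 → 4
    Variable<TFloat32> h2Weights = tf.withName("hidden2_weights")
            .variable(initializer.call(tf, tf.array(5L, 4L), TFloat32.class));
    Variable<TFloat32> h2Biases = tf.withName("hidden2_biases")
            .variable(tf.fill(tf.array(4), tf.constant(0.1f)));
    var h2 = tf.nn.relu(tf.math.add(tf.linalg.matMul(h1, h2Weights), h2Biases));

    // Output layer: 4 → 3, softmax turns the scores into class probabilities
    Variable<TFloat32> outWeights = tf.withName("output_weights")
            .variable(initializer.call(tf, tf.array(4L, 3L), TFloat32.class));
    Variable<TFloat32> outBiases = tf.withName("output_biases")
            .variable(tf.fill(tf.array(3), tf.constant(0.1f)));
    var output = tf.withName("output_activation")
            .nn.softmax(tf.math.add(tf.linalg.matMul(h2, outWeights), outBiases));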
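Training the model
Training configuration

    // MeanSquaredError and Reduction come from org.tensorflow.framework.losses,
    // Adam from org.tensorflow.framework.optimizers
    try (Session session = new Session(graph)) {
        var lossFunction = new MeanSquaredError(Reduction.AUTO);
        // The optimizer needs the graph it will add its update operations to
        var optimizer = new Adam(graph, LEARNING_RATE);

        // Placeholder for the expected (one-hot encoded) output
        Placeholder<TFloat32> trainingOutput = tf.placeholder(TFloat32.class,
                Placeholder.shape(Shape.of(-1, 3)));

        var loss = lossFunction.call(tf, trainingOutput,
                graph.operation("output_activation").<TFloat32>output(0));
        var minimize = optimizer.minimize(loss);

        // training loop ...
    }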
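If the loss function feels like a black box, the following plain-Java sketch shows what mean squared error computes for a single sample: the average of the squared differences between the expected one-hot vector and the predicted probabilities. It is a hand-rolled illustration with made-up names (`MseSketch`, `meanSquaredError`), not the TensorFlow implementation, and it ignores batching.

```java
class MseSketch {

    /** Mean of the squared element-wise differences between two equally long vectors. */
    static float meanSquaredError(float[] expected, float[] predicted) {
        if (expected.length != predicted.length) {
            throw new IllegalArgumentException("Vectors must have the same length");
        }
        float sum = 0f;
        for (int i = 0; i < expected.length; i++) {
            float diff = expected[i] - predicted[i];
            sum += diff * diff;
        }
        return sum / expected.length;
    }

    public static void main(String[] args) {
        float[] expected = {1f, 0f, 0f};        // the sample really is class 0
        float[] confident = {0.9f, 0.05f, 0.05f};
        float[] unsure = {0.4f, 0.3f, 0.3f};
        System.out.println("confident prediction loss: " + meanSquaredError(expected, confident));
        System.out.println("unsure prediction loss:    " + meanSquaredError(expected, unsure));
    }
}
```

The closer the predicted probabilities are to the one-hot target, the smaller the loss, and the optimizer's job is to nudge the weights in the direction that shrinks it.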
Training loop

    List<IrisDataLine> trainData = loadTrainingData();
    for (int epoch = 0; epoch < TRAINING_EPOCHS; epoch++) {
        int correct = 0;
        for (IrisDataLine data : trainData) {
            // createInputTensor, createOutputTensor and predictSpecies are helper methods not shown here.
            // All three tensors are released when the try block exits.
            try (TFloat32 input = createInputTensor(data);
                 TFloat32 expected = createOutputTensor(data.getIrisSpecies());
                 TFloat32 out = (TFloat32) session.runner()
                         .addTarget(minimize)                 // run one optimizer step
                         .feed("input_placeholder", input)
                         .feed(trainingOutput, expected)
                         .fetch("output_activation")          // and fetch the prediction
                         .run()
                         .get(0)) {
                if (predictSpecies(out) == data.getIrisSpecies()) correct++;
            }
        }
        System.out.printf("Epoch %d: accuracy %.1f%% (%d/%d)%n",
                epoch, 100.0 * correct / trainData.size(), correct, trainData.size());
    }
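The loop relies on helpers that turn a data row into an input tensor and a species into an expected-output tensor. The underlying encoding can be sketched with plain arrays: the four measurements become one feature row, and the species index becomes a one-hot vector. The names `TrainingEncodingSketch`, `toFeatureRow` and `oneHot` are illustrative and this is not the article's actual helper code; wrapping the arrays in TensorFlow tensors is left out.

```java
import java.util.Arrays;

class TrainingEncodingSketch {

    /** Packs the four measurements into the order the input placeholder expects. */
    static float[] toFeatureRow(float sepalLength, float sepalWidth,
                                float petalLength, float petalWidth) {
        return new float[]{sepalLength, sepalWidth, petalLength, petalWidth};
    }

    /** One-hot encoding: 1 at the class index, 0 everywhere else. */
    static float[] oneHot(int classIndex, int numClasses) {
        if (classIndex < 0 || classIndex >= numClasses) {
            throw new IllegalArgumentException("Class index out of range: " + classIndex);
        }
        float[] encoded = new float[numClasses];
        encoded[classIndex] = 1f;
        return encoded;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(toFeatureRow(5.1f, 3.5f, 1.4f, 0.2f)));
        System.out.println(Arrays.toString(oneHot(1, 3))); // prints [0.0, 1.0, 0.0]
    }
}
```

The one-hot vector has the same width as the softmax output, so the loss function can compare the two element by element.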
Expected output:

    Epoch 0: accuracy 78.0% (117/150)
    Epoch 1: accuracy 80.7% (121/150)
    Epoch 2: accuracy 82.0% (123/150)
    Epoch 3: accuracy 82.7% (124/150)
Saving and exporting the model

    // Build the signature: which graph tensors are the model's inputs and outputs
    Signature signature = Signature.builder()
            .key(Signature.DEFAULT_KEY)
            .input("input_placeholder", graph.operation("input_placeholder").output(0))
            .output("output_activation", graph.operation("output_activation").output(0))
            .build();

    // Export the model in SavedModel format
    SessionFunction sessionFunction = SessionFunction.create(signature, session);
    SavedModelBundle.exporter("./models/iris_classifier")
            .withFunction(sessionFunction)
            .withTags(SavedModelBundle.DEFAULT_TAG)
            .export();
Directory structure after export:

    models/iris_classifier/
    ├── saved_model.pb
    └── variables/
        ├── variables.data-00000-of-00001
        └── variables.index
Loading the model and making predictions

    import org.tensorflow.SavedModelBundle;
    import org.tensorflow.Session;
    import org.tensorflow.Tensor;
    import org.tensorflow.ndarray.Shape;
    import org.tensorflow.types.TFloat32;

    public class IrisClassifier implements AutoCloseable {
        private final SavedModelBundle model;
        private final Session session;

        public IrisClassifier(String modelPath) {
            model = SavedModelBundle.load(modelPath, "serve");
            session = model.session();
        }

        /** Returns the probabilities for Setosa, Versicolour and Virginica, in that order. */
        public float[] predict(float sepalLength, float sepalWidth,
                               float petalLength, float petalWidth) {
            // Both the input tensor and the fetched output tensor are released when the try block exits
            try (TFloat32 input = Tensor.of(TFloat32.class, Shape.of(1, 4), data -> {
                     data.setFloat(sepalLength, 0, 0);
                     data.setFloat(sepalWidth, 0, 1);
                     data.setFloat(petalLength, 0, 2);
                     data.setFloat(petalWidth, 0, 3);
                 });
                 TFloat32 output = (TFloat32) session.runner()
                         .feed("input_placeholder", input)
                         .fetch("output_activation")
                         .run()
                         .get(0)) {
                return new float[]{
                        output.getFloat(0, 0),
                        output.getFloat(0, 1),
                        output.getFloat(0, 2)
                };
            }
        }

        @Override
        public void close() {
            if (session != null) session.close();
            if (model != null) model.close();
        }
    }
Spring Boot integration in practice
The model service bean

    @Service
    public class TensorFlowModelService {

        @Value("${tensorflow.model.path:./models/iris_classifier}")
        private String modelPath;

        private SavedModelBundle model;
        private Session session;

        // Load the model once at startup; the bean is a singleton
        @PostConstruct
        public void init() {
            model = SavedModelBundle.load(modelPath, "serve");
            session = model.session();
            System.out.println("✅ Model loaded successfully");
        }

        // Release native resources when the application shuts down
        @PreDestroy
        public void cleanup() {
            if (session != null) session.close();
            if (model != null) model.close();
        }

        public float[] predict(float sepalLength, float sepalWidth,
                               float petalLength, float petalWidth) {
            // see the code in the previous section
        }
    }
Controller

    @RestController
    @RequestMapping("/api/iris")
    @RequiredArgsConstructor
    public class IrisController {

        private final TensorFlowModelService modelService;

        @PostMapping("/predict")
        public IrisPredictResponse predict(@RequestBody IrisPredictRequest request) {
            float[] probs = modelService.predict(
                    request.getSepalLength(),
                    request.getSepalWidth(),
                    request.getPetalLength(),
                    request.getPetalWidth()
            );
            // getMaxIndex (not shown) returns the index of the highest probability
            int idx = getMaxIndex(probs);
            String[] species = {"Setosa", "Versicolour", "Virginica"};
            return new IrisPredictResponse(species[idx], probs[idx], probs);
        }
    }
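The controller calls `getMaxIndex`, which the article does not show. A minimal version could look like the sketch below; the wrapper class name `ArgMaxSketch` and the choice to reject empty input are assumptions.

```java
class ArgMaxSketch {

    /** Returns the index of the largest value; on ties the first occurrence wins. */
    static int getMaxIndex(float[] probs) {
        if (probs == null || probs.length == 0) {
            throw new IllegalArgumentException("probs must not be empty");
        }
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        String[] species = {"Setosa", "Versicolour", "Virginica"};
        float[] probs = {0.92f, 0.05f, 0.03f};
        System.out.println(species[getMaxIndex(probs)]); // prints Setosa
    }
}
```

Because the order of the `species` array has to match the order of the model's output vector, it is worth keeping both next to the enum indices defined during training.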
Testing the API

    curl -X POST http://localhost:8080/api/iris/predict \
      -H "Content-Type: application/json" \
      -d '{"sepalLength":5.1,"sepalWidth":3.5,"petalLength":1.4,"petalWidth":0.2}'

Response:

    {
      "predictedSpecies": "Setosa",
      "confidence": 0.92,
      "probabilities": [0.92, 0.05, 0.03]
    }
Using a pretrained model (EfficientDet)
Inspecting the model signature

    SavedModelBundle model = SavedModelBundle.load("./models/efficientdet-d0", "serve");
    System.out.println(model.signatures());

Output:

    Inputs:
      "input_tensor": dtype=DT_UINT8, shape=(1, -1, -1, 3)
    Outputs:
      "detection_boxes", "detection_scores", "detection_classes", "num_detections"
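Inference code

    public List<DetectionResult> detect(float[][][] image) {
        // createImageTensor (not shown) converts the image into a uint8 tensor of shape (1, height, width, 3)
        try (TUint8 input = createImageTensor(image)) {
            var result = session.runner()
                    .feed("input_tensor", input)
                    .fetch("detection_boxes")
                    .fetch("detection_scores")
                    .fetch("detection_classes")
                    .run();
            // parse the results ... (remember to close the fetched tensors afterwards)
        }
    }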
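The parsing step is elided above. A common first step is to keep only the detections whose score clears a threshold. The sketch below assumes the fetched tensors for the first image have already been copied into plain arrays and that class ids arrive as floats. The `Detection` class, the method name and the array layout are all hypothetical, so check them against the signature of the model you actually load.

```java
import java.util.ArrayList;
import java.util.List;

class DetectionFilterSketch {

    /** Hypothetical value class holding one detection. */
    static final class Detection {
        final int classId;
        final float score;
        final float[] box; // four box coordinates, copied as-is from the model output

        Detection(int classId, float score, float[] box) {
            this.classId = classId;
            this.score = score;
            this.box = box;
        }
    }

    /** Keeps the detections whose score is at least minScore, preserving their order. */
    static List<Detection> filterByScore(float[][] boxes, float[] scores,
                                         float[] classes, float minScore) {
        List<Detection> kept = new ArrayList<>();
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] >= minScore) {
                kept.add(new Detection((int) classes[i], scores[i], boxes[i]));
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        float[][] boxes = {{0.1f, 0.1f, 0.5f, 0.5f}, {0.2f, 0.3f, 0.4f, 0.6f}};
        float[] scores = {0.91f, 0.12f};
        float[] classes = {1f, 18f};
        for (Detection d : filterByScore(boxes, scores, classes, 0.5f)) {
            System.out.println("class " + d.classId + " score " + d.score);
        }
    }
}
```

The threshold is a trade-off: a higher value drops false positives but may also drop small or partially hidden objects.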
Choosing a stack, and summary
Java vs Python TensorFlow
Best practice: train in Python → export a SavedModel → load it and run inference in Java
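Key points checklist
## Training
- Define the network topology
- Configure the loss function and the optimizer
- Run the training loop and monitor accuracy
## Export
- Build the Signature
- Use SavedModelBundle.exporter()
## Inference
- Load the model: SavedModelBundle.load()
- Create the input Tensor
- session.runner().feed().fetch()
- Close the Tensors and the Session
## Spring Boot integration
- Load the model in @PostConstruct
- Release resources in @PreDestroy
- Expose the service as a singleton bean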
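Closing thoughts
In one sentence: Java + TensorFlow lets you embrace AI in the language you know best.
For Java developers, the TensorFlow Java API is a low-barrier entry point into AI. You do not have to switch to Python or stand up an extra microservice; you can call a neural network model directly from your own application.
Whether it is the classic Iris classification or state-of-the-art object detection, Java is fully capable of handling the core step of inference.
If this article opened the door to Java + AI for you, feel free to bookmark and share it.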