大数据深度学习逐步成为研究的主流趋势。这是《30 天吃掉那只 TensorFlow2.0》里面的其中一篇,介绍在 Spark 中调用训练好的 TensorFlow 模型进行预测的方法。本篇文章通过 TensorFlow for Java 在 Spark 中调用训练好的 TensorFlow 模型。利用 Spark 的分布式计算能力,从而可以让训练好的 TensorFlow 模型在成百上千的机器上分布式并行执行模型推断。
本案例以 TensorFlow 2.0 的 tf.keras 接口训练的线性模型为例进行演示。在本例基础上稍作修改则可以用 Spark 调用训练好的各种复杂的神经网络模型进行分布式模型推断。但实际上 TensorFlow 并不仅仅适合实现神经网络,其底层的计算图语言可以表达各种数值计算过程。利用其丰富的低阶 API,我们可以在 TensorFlow 2.0 上实现任意机器学习模型。结合 tf.Module 提供的便捷的封装功能,我们可以将训练好的模型导出成模型文件并在 Spark 上分布式调用执行。这无疑为我们的工程应用提供了巨大的想象空间。