Google在今年的I/O大会,发布数项TensorFlow与Keras深度学习工具更新,重点包括可让开发者能够简单存取预训练模型的模块化函式库,还推出可用于同步分布式模型运算的扩充套件DTensor,而借由新的JAX2TF API,开发者便能够在TensorFlow生态系中,使用JAX数值函式库编写的模型。

模块化库KerasCV与KerasNLP,可简化计算机视觉和自然语言处理预训练模型访问。 这两个新的模块化库,让开发者只要编写几行代码,就可在应用程序中整合图像分类或是文字生成等机器学习功能。 由于这两个函式库皆是Keras的一部分,而Keras在TensorFlow 2.0成为内置进阶API,因此开发者能够直接在TensorFlow中使用Keras,这也就代表KerasCV、KerasNLP与TensorFlow生态系可完全整合。

TensorFlow扩充套件DTensor通过组合并微调多种平行技术,以支持更大且高性能的模型训练。 以往机器学习开发人员可以通过数据平行技术扩展模型,将数据拆分之后,供水平扩展的模型实例训练使用,不过这种扩展训练方法有一个严重的限制,即是要求模型在单个硬件设备执行。
但随着模型越来越大,单一设备的运算能力可能不足以处理庞大的模型,因此开发者开始需要将模型扩展到更多硬件设备上执行。 也就是说训练庞大模型不仅需要数据平行性,还需要模型平行性,将模型分割成可以平行训练的分片。
而DTensor不只支持数据平行性,也提供模型平行性,通过结合这两种技术,更有效地扩展模型,同时DTensor也不受加速器类型的限制,支持TPU、GPU等各种运算装置。

Google也释出轻量级API JAX2TF,来加速机器学习研究生产化的速度。 Google开发的Python库JAX被大量用于高效数值运算上,同时JAX也支持硬件加速,能够在GPU或TPU上高速处理大型数据集和复杂运算,但要把JAX用于生产中仍不是一件直观简单的事。
而JAX2TF API的出现,是要让JAX能够更简单地进入TensorFlow生态系,使开发者可以将JAX模型部署到TensorFlow Serving服务器或是TFLite装置上,并在TensorFlow中继续训练JAX模型,甚至是将JAX模型和TensorFlow模型融合,以获得更大的灵活性。
除了以上更新,开发团队也预告,他们即将推出TensorFlow量化API,该API将会是TensorFlow 2的原生量化工具,能够在不影响模型质量的前提下,进一步缩小模型,并且提升模型执行速度。