掌握聚合最新动态了解行业最新趋势
API接口,开发服务,免费咨询服务

PyTorch和TensorFlow哪家强:九项对比读懂各自长项短板

本文经机器之心(微信公众号:almosthuman2014)授权转载,禁止二次转载。

这篇指南主要介绍了我找到的 PyTorch 和 TensorFlow 之间的不同之处。这篇文章的目的是帮助那些想要开始一个新项目或从一种深度学习框架切换到另一种框架的人。本文重点关注的是在设置训练组件和部署深度学习时的可编程性和灵活性。我不会深入到性能方面(速度/内存占用)的比较。

概要

PyTorch 更适用于研究、爱好者和小规模项目的快速原型开发。TensorFlow 更适合大规模部署,尤其是涉及跨平台和嵌入式部署时。

上手时间

获胜者:PyTorch

PyTorch 本质上是支持 GPU 的 NumPy 替代,配备了可用于构建和训练深度神经网络的更高级的功能。所以如果你熟悉 NumPy、Python 和常用的深度学习抽象(卷积层、循环层、SGD 等),那 PyTorch 就很容易学。

另一方面,则可以将 TensorFlow 看作是一种嵌入 Python 的编程语言。当你编写 TensorFlow 代码时,它会被 Python「编译」成图(graph),然后由 TensorFlow 执行引擎运行。我看到过有些 TensorFlow 新手难以理解这额外增加的间接一层工序。同样因为这个原因,TensorFlow 还有一些需要额外学习的概念,比如会话(session)、图、变量范围和占位符。要让基本的模型跑起来也需要更多样板代码。上手 TensorFlow 的时间肯定会比 PyTorch 长。

创建图和调试

获胜者:PyTorch

创建和运行计算图可能是这两个框架差别最大的地方。在 PyTorch 中,图结构是动态的,也就是说图是在运行时创建的。在 TensorFlow 中,图结构是静态的,也就是说图在「编译」之后再运行。举个简单例子,在 PyTorch 中,你可以使用标准的 Python 句法写一个 for 循环:

for _ in range(T):
    h = torch.matmul(W, h) + b

而且 T 可以在这段代码的执行之间改变。在 TensorFlow 中,这需要在构建图时使用控制流操作(control flow operations),比如 tf.while_loop。TensorFlow 确实有 dynamic_rnn 可用于更常见的结构,但创建自定义的动态计算也更加困难。

PyTorch 简单的图构建方式更容易理解,但也许更重要的是也更容易调试。调试 PyTorch 代码就跟调试 Python 代码一样。你可以使用 pdb,并且可以在任何地方设置断点。调试 TensorFlow 则没这么容易。它有两个选择,一是从会话中请求你想检查的变量,而是学会使用 TensorFlow 调试器(tfdbg)。

覆盖度

获胜者:TensorFlow

随着 PyTorch 的发展,我预计这两者之间的差距会缩小至零。但是,TensorFlow 仍然支持一些 PyTorch 并不支持的功能。PyTorch 目前还不具备的特性包括:

  • 沿维度方向的张量翻转(np.flip、 np.flipud、 np.fliplr)
  • 检查张量是否为 NaN 和无穷大(np.is_nan、np.is_inf)
  • 快速傅立叶变换(np.fft)

而 TensorFlow 支持所有这些。另外比起 PyTorch,TensorFlow 的 contrib 包也有远远更多更高级的函数和模型。

序列化(serialization)

获胜者:TensorFlow

在这两种框架中,保存和加载模型都很简单。PyTorch 有一个非常简单的 API,既可以保存模型的所有权重,也可以 pickle(加工)整个类。TensorFlow 的 Saver 对象也很容易使用,而且也为检查点提供了更多选择。

TensorFlow 在序列化方面的主要优势是整个计算图都可以保存为 protocol buffer。这既包括参数,也包括运算。然后这个图可以用其它支持的语言(C++、Java)加载。对于不支持 Python 的部署环境来说,这是非常重要的功能。而且理论上,这个功能也可以在你修改模型的源代码,但又想运行旧模型时为你提供帮助。


部署

获胜者:TensorFlow

对于小型服务器(比如 Flask 网页服务器)上的部署,两种框架都很简单。

TensorFlow 支持移动和嵌入式部署,而包括 PyTorch 在内的很多深度学习框架都没有这个能力。在 TensorFlow 上,要将模型部署到安卓或 iOS 上需要不小的工作量,但至少你不必使用 Java 或 C++ 重写你模型的整个推理部分。

对于高性能服务器上的部署,还有 TensorFlow Serving 可用。我还没体验过 TensorFlow Serving,所以我不能说它有哪些优缺点。对于严重依赖机器学习的服务,我猜想 TensorFlow Serving 可能就是继续使用 TensorFlow 的充分理由。除了性能方面的优势,TensorFlow Serving 的另一个重要特性是无需中断服务,就能实现模型的热插拔。Zendesk 的这篇博客文章介绍了使用 TensorFlow Serving 部署一个问答机器人的案例:https://medium.com/zendesk-engineering/how-zendesk-serves-tensorflow-models-in-production-751ee22f0f4b

文档

获胜者:平局

对于这两种框架,我都找到了我需要的一切。Python API 的文档做得很好,两个框架也都有足够多的示例和教程可以学习。

一个不太重要的麻烦是 PyTorch 的 C 语言库基本上没有文档。但是,只有当你编写自定义 C 语言扩展或为这个软件库贡献代码时才有用。

数据加载

获胜者:PyTorch

PyTorch 的数据加载 API 设计得很好。数据集、采样器和数据加载器的接口都是特定的。数据加载器可以接收一个数据集和一个采样器,并根据该采样器的调度得出数据集上的一个迭代器(iterator)。并行化数据加载很简单,只需为数据加载器传递一个 num_workers 参数即可。

我还没找到 TensorFlow 的非常有用的数据加载工具(读取器、队列、队列运行器等等)。部分原因是要将你想并行运行的所有预处理代码加入到 TensorFlow 图中并不总是那么简单直接(比如计算频谱图)。另外,TensorFlow 的 API 本身也更加冗长,学习起来也更难。

设备管理

获胜者:TensorFlow

TensorFlow 的设备管理的无缝性能非常好。通常你不需要指定任何东西,因为默认的设置就很好。比如说,TensorFlow 假设如果存在可用的 GPU,你就希望在 GPU 上运行。而在 PyTorch 中,你必须在启用了 CUDA 之后明确地将所有东西移到 GPU 上。

TensorFlow 设备管理的唯一缺陷是它会默认占用所有可用的 GPU 上的所有内存,即使真正用到的只有其中一个。但也有一种简单的解决方案,就是指定 CUDA_VISIBLE_DEVICES。有时候人们会忘记这一点,就会让 GPU 看起来很繁忙,尽管实际上它们啥也没干。

在使用 PyTorch 时,我发现我的代码需要更频繁地检查 CUDA 的可用性和更明确的设备管理。尤其是当编写可以在 CPU 和 GPU 上同时运行的代码时更是如此。另外,要将 GPU 上的 PyTorch Variable 等转换成 NumPy 数组也较为繁琐。

numpy_var = variable.cpu().data.numpy()

自定义扩展

获胜者:PyTorch

这两种框架都可以构建或绑定用 C、C++ 或 CUDA 写的扩展。TensorFlow 还是需要更多样板代码,尽管有人认为它能更简单清晰地支持多种类型和设备。在 PyTorch 中,你只需要简单地为每个 CPU 和 GPU 版本写一个接口和对应实现即可。这两种框架对扩展的编译都很直接,不需要下载 pip 安装之外的任何头文件或源代码。

关于 TensorBoard 的一点说明

TensorBoard 是一个用于可视化训练机器学习模型各个方面的工具。它是 TensorFlow 项目产出的最有用的功能之一。仅需在训练脚本中加入少许代码,你就可以查看任何模型的训练曲线和验证结果。TensorBoard 作为一个网页服务运行,可以尤其方便地可视化存储在 headless 节点上的结果。

这是我在使用 PyTorch 时也想继续使用的一个功能(或找到可替代的工具)。幸运的是,确实有这样的工具——至少有两个开源工具可以做到:

tensorboard_logger 库用起来甚至比 TensorFlow 中的 TensorBoard「summaries」还简单,尽管你需要在安装了 TensorBoard 后才能使用它。crayon 项目可以完全替代 TensorBoard,但需要更多设置(docker 是必需的前提)。

关于 Keras 的一点说明

Keras 是一种带有可配置的后端的更高层的 API。目前支持 TensorFlow、Theano 和 CNTK,尽管也许不久之后 PyTorch 也将加入这一名单。Keras 也作为 tf.contrib 通过 TensorFlow 提供。

尽管前面我没有讨论 Keras,但这个 API 的使用尤其简单。它是运行许多常用深度神经网络架构的最快方式。话虽如此,这个 API 并没有 PyTorch 或核心 TensorFlow 那么灵活。

关于 TensorFlow Fold 的一点说明

2017 年 2 月,谷歌发布了 TensorFlow Fold。这个库构建于 TensorFlow 之上,允许实现更动态的图构建。这个库的主要优势应该是动态批量化处理(dynamic batching)。动态批量化可以自动批量化处理不同规模的输入的计算(考虑一下解析树上的递归网络)。从可编程性上看,它的句法并没有 PyTorch 的那么简单,尽管考虑到批量化在一些情况下带来的性能提升,这样的成本也是值得的。

原文来自:机器之心

声明:所有来源为“聚合数据”的内容信息,未经本网许可,不得转载!如对内容有异议或投诉,请与我们联系。邮箱:marketing@think-land.com

  • 个人/企业涉诉查询

    通过企业关键词查询企业涉讼详情,如裁判文书、开庭公告、执行公告、失信公告、案件流程等等。

    通过企业关键词查询企业涉讼详情,如裁判文书、开庭公告、执行公告、失信公告、案件流程等等。

  • IP反查域名

    IP反查域名是通过IP查询相关联的域名信息的功能,它提供IP地址历史上绑定过的域名信息。

    IP反查域名是通过IP查询相关联的域名信息的功能,它提供IP地址历史上绑定过的域名信息。

  • 人脸卫士

    结合权威身份认证的精准人脸风险查询服务,提升人脸应用及身份认证生态的安全性。人脸风险情报库,覆盖范围广、准确性高,数据权威可靠。

    结合权威身份认证的精准人脸风险查询服务,提升人脸应用及身份认证生态的安全性。人脸风险情报库,覆盖范围广、准确性高,数据权威可靠。

  • 全国城市空气质量

    全国城市和站点空气质量查询,污染物浓度及空气质量分指数、空气质量指数、首要污染物及空气质量级别、健康指引及建议采取的措施等。

    全国城市和站点空气质量查询,污染物浓度及空气质量分指数、空气质量指数、首要污染物及空气质量级别、健康指引及建议采取的措施等。

  • 手机号防骚扰黑名单

    输入手机号和拦截等级,查看是否是风险号码

    输入手机号和拦截等级,查看是否是风险号码

0512-88869195
数 据 驱 动 未 来
Data Drives The Future