如何在 S3 上运行 TensorFlow
TensorFlow 支持读取和写入 S3。S3 是一种广泛使用的对象存储 API,并且可以用于数据会被多个参与者访问的情况下,例如在分布式训练中。
本文档将帮助你完成所需的设置,并提供相关的使用示例。
配置
你可以通过一系列环境变量来控制你的 TensorFlow 程序在 S3 上读写数据:
- AWS_REGION:S3 默认使用区域终端节点,区域通过
AWS_REGION
来设置。如果AWS_REGION
没有被指定,默认会使用us-east-1
。 - S3_ENDPOINT:重写
S3_ENDPOINT
来明确指定节点。 - S3_USE_HTTPS:S3 默认通过 HTTPS 访问,除非你设置了
S3_USE_HTTPS=0
。 - S3_VERIFY_SSL:如果你使用了 HTTPS,可以通过
S3_VERIFY_SSL=0
来关闭 SSL 校验。
要读取或写入不可公开访问的存储桶中的对象,必须通过以下方法之一提供 AWS 凭据:
- 本地的 AWS 证书文件中设置凭证,Linux、 macOS 和 Unix 存放在
~/.aws/credentials
下,Windows 存放在C:\Users\USERNAME\.aws\credentials
下。 - 设置
AWS_ACCESS_KEY_ID
和AWS_SECRET_ACCESS_KEY
环境变量。 - 如果 TensorFlow 部署在一个 EC2 实例上,创建一个 IAM 角色,并赋予该 EC2 实例使用该角色的权限。
示例设置
使用上述信息,我们可以通过设置以下环境变量来配置 TensorFlow 与 S3 端点进行通信:
bash
AWS_ACCESS_KEY_ID=XXXXX # 仅在连接到专用端点时才需要凭据
AWS_SECRET_ACCESS_KEY=XXXXX
AWS_REGION=us-east-1 # S3 存储区的区域,并不是必填的。默认为 us-east-1。
S3_ENDPOINT=s3.us-east-1.amazonaws.com # 要连接的 S3 API 端点。这是以 `主机名:端口` 格式指定的。
S3_USE_HTTPS=1 # 是否启用 HTTPS。0 为禁用。
S3_VERIFY_SSL=1 # 如果使用 HTTPS,这个属性用于控制是否启用 SSL。0 为禁用。
使用
设置完成后,TensorFlow 就可以通过多种方式与 S3 进行交互。在任何有 Tensorflow IO 功能的地方,都可以使用 S3 URL。
冒烟测试
要测试你的设置,新建一个文件:
python
from tensorflow.python.lib.io import file_io
print file_io.stat('s3://bucketname/path/')
正确的输出应该如下所示:
console
tensorflow.python.pywrap_tensorflow_internal.FileStatistics; proxy of <Swig Object of type 'tensorflow::FileStatistics *' at 0x10c2171b0>
读取数据
当读取数据时,将用于读写的文件路径改成 S3 路径。示例如下:
python
filenames = ["s3://bucketname/path/to/file1.tfrecord",
"s3://bucketname/path/to/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
Tensorflow 工具
许多 Tensorflow 工具,如 Tensorboard 或 TensorFlow serving,也能以 S3 URL 作为参数:
bash
tensorboard --logdir s3://bucketname/path/to/model/
tensorflow_model_server --port=9000 --model_name=model --model_base_path=s3://bucketname/path/to/model/export/
这将启用一个使用 S3 的端对端工作流并传输所需的所有的数据。
S3 端点实现
S3 是由 Amazon 发明的,但 S3 API 已经普及并且有许多实现。以下实现已通过基本兼容性测试: