TensorFlow/Keras: "Consider either turning off auto-sharding or switching the auto_shard_policy to DATA to shard this dataset"
When training a model in Keras/TensorFlow with this snippet:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
I get the following error/warning:
Consider either turning off auto-sharding or switching the auto_shard_policy to DATA to shard this dataset. You can do this by creating a new `tf.data.Options()` object then setting `options.experimental_distribute.auto_shard_policy = AutoShardPolicy.DATA` before applying the options object to the dataset via `dataset.with_options(options)`.
2020-12-16 17:12:20.885741: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:127] None of the MLIR optimization passes are enabled (registered 2)
2020-12-16 17:12:20.905570: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 3593105000 Hz
Epoch 1/40
Any help is appreciated.
Answer
This warning message is new as of TensorFlow 2.4.0. Although it hints at the solution, it presupposes that your data is a tf.data.Dataset object. Previously there was no strict requirement to feed data in this form (e.g. NumPy arrays were fine), but it now appears to be expected by the distribution strategies (e.g. tf.distribute.MirroredStrategy()). Either way, there seems to be no way to avoid TensorFlow's latest console spew without wrapping your data in a Dataset object.
So, assuming your current code looks something like this:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = ...  # awesome model definition

train_x, train_y = np.array(...), np.array(...)
val_x, val_y = np.array(...), np.array(...)

batch_size = 32
model.fit(train_x, train_y, batch_size=batch_size, validation_data=(val_x, val_y))
it needs to be changed to look like this:
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = ...  # awesome model definition

train_x, train_y = np.array(...), np.array(...)
val_x, val_y = np.array(...), np.array(...)

# Wrap data in Dataset objects.
train_data = tf.data.Dataset.from_tensor_slices((train_x, train_y))
val_data = tf.data.Dataset.from_tensor_slices((val_x, val_y))

# The batch size must now be set on the Dataset objects.
batch_size = 32
train_data = train_data.batch(batch_size)
val_data = val_data.batch(batch_size)

# Disable AutoShard.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

train_data = train_data.with_options(options)
val_data = val_data.with_options(options)

model.fit(train_data, validation_data=val_data)
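Alternatively, if you want to keep sharding enabled, the warning itself suggests switching the policy to DATA rather than turning it off. A minimal sketch of that variant, using small random NumPy arrays as stand-in data:

```python
import numpy as np
import tensorflow as tf

# Stand-in data: 8 samples of 4 features each.
x = np.random.rand(8, 4).astype("float32")
y = np.random.randint(0, 2, size=(8,)).astype("int32")

# Wrap in a Dataset and batch it, as before.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)

# Shard by data (each worker reads the full input and keeps its slice)
# instead of disabling sharding entirely.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.DATA
)
dataset = dataset.with_options(options)
```

With DATA sharding every worker processes a distinct subset of each batch, which is the behavior the warning recommends when the input is not file-based.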
Note that if you don't set the batch size on the Dataset objects, you'll get a cryptic error like this one:
File "/usr/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/distribute.py", line 496, in get_static_batch_dim
return output_shape.dims[0].value
IndexError: list index out of range
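A quick way to check whether a Dataset has been batched is to inspect its element_spec: an unbatched Dataset has no leading batch dimension, while a batched one gains a leading (usually unknown) dimension. A small sketch:

```python
import numpy as np
import tensorflow as tf

x = np.zeros((8, 4), dtype="float32")
ds = tf.data.Dataset.from_tensor_slices(x)

# Unbatched: each element is a single row of shape (4,).
print(ds.element_spec.shape)

ds = ds.batch(2)
# Batched: a leading batch dimension (None, since the last
# batch may be smaller) now appears: (None, 4).
print(ds.element_spec.shape)
```

This is why the IndexError above complains about `output_shape.dims[0]`: without batching there is no dimension 0 for the distribution code to read the static batch size from.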