将tensorflow数据集记录分成多条记录

我有一个未批处理的tensorflow数据集,如下所示:

ds = ...
for record in ds.take(3):
    print('data shape={}'.format(record['data'].shape))

-> data shape=(512, 512, 87)
-> data shape=(512, 512, 277)
-> data shape=(512, 512, 133)

我想将数据以深度为 5 的块形式提供给我的网络。在上面的示例中,形状 (512, 512, 87) 的张量将被划分为 17 个形状 (512, 512, 5) 的张量。tensor[:,:, 85:87]应丢弃矩阵 ( )的最后 2 行。

例如:

chunked_ds = ...
for record in chunked_ds.take(1):
    print('chunked data shape={}'.format(record['data'].shape))

-> chunked data shape=(512, 512, 5)

我怎样才能从dschunked_dstf.data.Dataset.window()看起来像我需要的,但我无法让它工作。

以上是将tensorflow数据集记录分成多条记录的全部内容。
THE END
分享
二维码
< <上一篇
下一篇>>