Why do I get the error "slice index 0 of dimension 0 out of bounds" in TensorFlow for Java?

I have a model with the following signature, and I am trying to invoke it using TensorFlow for Java:

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['jpegbase64_bytes'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: Placeholder:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['predictions'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 256)
        name: model/global_average_pooling2d/Mean:0
  Method name is: tensorflow/serving/predict

My code to invoke the model looks like this:

float[] predict(byte[] imageBytes) {
    try (Tensor result = SavedModelBundle.load("model.pb", "serve").session().runner()
            .feed("myinput", 0, TString.tensorOfBytes(NdArrays.scalarOfObject(imageBytes)))
            .fetch("myoutput")
            .run()
            .get(0)) {
        float[] buffer = new float[256];
        FloatNdArray floatNdArray = FloatDenseNdArray.create(RawDataBufferFactory.create(buffer, false),
                Shape.of(1, description.getNumFeatures()));
        ((TFloat32) result).copyTo(floatNdArray);
        return buffer;
    }
}

However, this throws the following error:

slice index 0 of dimension 0 out of bounds.
     [[{{node map/TensorArrayUnstack/strided_slice}}]]
org.tensorflow.exceptions.TFInvalidArgumentException: slice index 0 of dimension 0 out of bounds.
     [[{{node map/TensorArrayUnstack/strided_slice}}]]
    at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:87)
    at org.tensorflow.Session.run(Session.java:691)
    at org.tensorflow.Session.access$100(Session.java:72)
    at org.tensorflow.Session$Runner.runHelper(Session.java:381)
    at org.tensorflow.Session$Runner.run(Session.java:329)
    at com.mridang.myapp.ImageModel.predict(ImageModel.java:69)
    ...
    ...
    ...
    ...

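From what I understand, the model expects a dense string tensor (the signature declares the input with shape (-1), that is, rank 1), whereas the tensor I feed is built with scalarOfObject and is not one. I found an answer on Stack Overflow, "slice index 0 of dimension 0 out of bounds using Java API", but it seems to relate to a very old version of TensorFlow.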
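To make that suspicion concrete, here is a minimal plain-Java sketch of what I think is going on. It assumes that the `map/TensorArrayUnstack/strided_slice` node reads element 0 of the input's shape vector to find the batch size; that is my reading of the node name, not something I have confirmed in the graph. `RankCheck` and `sizeOfDimension0` are hypothetical names for illustration, and nothing here calls TensorFlow.

```java
// Plain-Java illustration only; no TensorFlow classes are used here.
final class RankCheck {

    // Hypothetical stand-in for a slice that reads the size of dimension 0
    // from a tensor's shape vector.
    static long sizeOfDimension0(long[] shape) {
        if (shape.length == 0) {
            // A scalar has an empty shape vector, so index 0 does not exist.
            throw new IllegalArgumentException(
                    "slice index 0 of dimension 0 out of bounds.");
        }
        return shape[0];
    }

    public static void main(String[] args) {
        long[] scalarShape = {};   // rank 0: a single string, no batch dimension
        long[] vectorShape = {1};  // rank 1: a batch of one string, matches shape (-1)

        // The rank-1 shape has a dimension 0 to read.
        System.out.println(sizeOfDimension0(vectorShape));

        // The rank-0 shape does not, which mirrors the error message above.
        try {
            sizeOfDimension0(scalarShape);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

If this reading is right, the input would need to be a rank-1 tensor holding one element (for example, one built with `NdArrays.vectorOfObjects` rather than `NdArrays.scalarOfObject`), though I have not verified that against this model.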

I am using these dependencies:

implementation group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.3.1'
implementation group: 'org.tensorflow', name: 'tensorflow-framework', version: '0.3.1'
