diff --git a/tensorflow_serving/util/json_tensor.cc b/tensorflow_serving/util/json_tensor.cc index 82532b12e38..d8ada0d5b37 100644 --- a/tensorflow_serving/util/json_tensor.cc +++ b/tensorflow_serving/util/json_tensor.cc @@ -346,13 +346,25 @@ Status AddValueToTensor(const rapidjson::Value& val, DataType dtype, // `val` can be scalar or list or list of lists with arbitrary nesting. If a // scalar (non array) is passed, we do not add dimension info to shape (as // scalars do not have a dimension). -void GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) { - if (!val.IsArray()) return; - const auto size = val.Size(); - shape->add_dim()->set_size(size); - if (size > 0) { - GetDenseTensorShape(val[0], shape); +constexpr int kMaxTensorRank = 254; + +Status GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) { + const rapidjson::Value* curr = &val; + int depth = 0; + while (curr->IsArray()) { + if (++depth > kMaxTensorRank) { + return errors::InvalidArgument( + "Tensor rank exceeds maximum allowed rank: ", kMaxTensorRank); + } + const auto size = curr->Size(); + shape->add_dim()->set_size(size); + if (size > 0) { + curr = &((*curr)[0]); + } else { + break; + } } + return OkStatus(); } bool IsValBase64Object(const rapidjson::Value& val) { @@ -392,6 +404,10 @@ Status JsonDecodeBase64Object(const rapidjson::Value& val, Status FillTensorProto(const rapidjson::Value& val, int level, DataType dtype, int* val_count, TensorProto* tensor) { const auto rank = tensor->tensor_shape().dim_size(); + if (rank > kMaxTensorRank) { + return errors::InvalidArgument( + "Tensor rank ", rank, " exceeds maximum allowed rank ", kMaxTensorRank); + } if (!val.IsArray()) { // DOM tree for a (dense) tensor will always have all values // at same (leaf) level equal to the rank of the tensor. @@ -453,7 +469,7 @@ Status AddInstanceItem(const rapidjson::Value& item, const string& name, const auto dtype = tensorinfo_map.at(name).dtype(); auto* tensor = &(*tensor_map)[name]; tensor->mutable_tensor_shape()->Clear(); - GetDenseTensorShape(item, tensor->mutable_tensor_shape()); + TF_RETURN_IF_ERROR(GetDenseTensorShape(item, tensor->mutable_tensor_shape())); TF_RETURN_IF_ERROR( FillTensorProto(item, 0 /* level */, dtype, &size, tensor)); if (!size_map->count(name)) { @@ -623,7 +639,7 @@ Status FillTensorMapFromInputsMap( auto* tensor = &(*tensor_map)[tensorinfo_map.begin()->first]; tensor->set_dtype(tensorinfo_map.begin()->second.dtype()); - GetDenseTensorShape(val, tensor->mutable_tensor_shape()); + TF_RETURN_IF_ERROR(GetDenseTensorShape(val, tensor->mutable_tensor_shape())); int unused_size = 0; TF_RETURN_IF_ERROR(FillTensorProto(val, 0 /* level */, tensor->dtype(), &unused_size, tensor)); @@ -639,7 +655,7 @@ Status FillTensorMapFromInputsMap( auto* tensor = &(*tensor_map)[name]; tensor->set_dtype(dtype); tensor->mutable_tensor_shape()->Clear(); - GetDenseTensorShape(item->value, tensor->mutable_tensor_shape()); + TF_RETURN_IF_ERROR(GetDenseTensorShape(item->value, tensor->mutable_tensor_shape())); int unused_size = 0; TF_RETURN_IF_ERROR(FillTensorProto(item->value, 0 /* level */, dtype, &unused_size, tensor)); diff --git a/tensorflow_serving/util/json_tensor_test.cc b/tensorflow_serving/util/json_tensor_test.cc index c117da26c54..750a0cf7903 100644 --- a/tensorflow_serving/util/json_tensor_test.cc +++ b/tensorflow_serving/util/json_tensor_test.cc @@ -118,6 +118,25 @@ TEST(JsontensorTest, DeeplyNestedMalformed) { EXPECT_THAT(status.message(), HasSubstr("key must be a string value")); } +TEST(JsontensorTest, DeeplyNestedTensorValueExceedsMaxRank) { + TensorInfoMap infomap; + ASSERT_TRUE( + TextFormat::ParseFromString("dtype: DT_INT32", &infomap["default"])); + + PredictRequest req; + JsonPredictRequestFormat format; + std::string json_req = R"({"instances":)"; + int depth = 300; // exceeds kMaxTensorRank (254) + json_req.append(depth, '['); + json_req.append("1"); + json_req.append(depth, ']'); + json_req.append("}"); + auto status = + FillPredictRequestFromJson(json_req, getmap(infomap), &req, &format); + ASSERT_FALSE(status.ok()); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); +} + TEST(JsontensorTest, MixedInputForFloatTensor) { TensorInfoMap infomap; ASSERT_TRUE(