Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions tensorflow_serving/util/json_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,25 @@ Status AddValueToTensor(const rapidjson::Value& val, DataType dtype,
// `val` can be scalar or list or list of lists with arbitrary nesting. If a
// scalar (non array) is passed, we do not add dimension info to shape (as
// scalars do not have a dimension).
void GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) {
if (!val.IsArray()) return;
const auto size = val.Size();
shape->add_dim()->set_size(size);
if (size > 0) {
GetDenseTensorShape(val[0], shape);
constexpr int kMaxTensorRank = 254;

Status GetDenseTensorShape(const rapidjson::Value& val, TensorShapeProto* shape) {
const rapidjson::Value* curr = &val;
int depth = 0;
while (curr->IsArray()) {
if (++depth > kMaxTensorRank) {
return errors::InvalidArgument(
"Tensor rank exceeds maximum allowed rank: ", kMaxTensorRank);
}
const auto size = curr->Size();
shape->add_dim()->set_size(size);
if (size > 0) {
curr = &((*curr)[0]);
} else {
break;
}
}
return OkStatus();
}

bool IsValBase64Object(const rapidjson::Value& val) {
Expand Down Expand Up @@ -392,6 +404,10 @@ Status JsonDecodeBase64Object(const rapidjson::Value& val,
Status FillTensorProto(const rapidjson::Value& val, int level, DataType dtype,
int* val_count, TensorProto* tensor) {
const auto rank = tensor->tensor_shape().dim_size();
if (rank > kMaxTensorRank) {
return errors::InvalidArgument(
"Tensor rank ", rank, " exceeds maximum allowed rank ", kMaxTensorRank);
}
if (!val.IsArray()) {
// DOM tree for a (dense) tensor will always have all values
// at same (leaf) level equal to the rank of the tensor.
Expand Down Expand Up @@ -453,7 +469,7 @@ Status AddInstanceItem(const rapidjson::Value& item, const string& name,
const auto dtype = tensorinfo_map.at(name).dtype();
auto* tensor = &(*tensor_map)[name];
tensor->mutable_tensor_shape()->Clear();
GetDenseTensorShape(item, tensor->mutable_tensor_shape());
TF_RETURN_IF_ERROR(GetDenseTensorShape(item, tensor->mutable_tensor_shape()));
TF_RETURN_IF_ERROR(
FillTensorProto(item, 0 /* level */, dtype, &size, tensor));
if (!size_map->count(name)) {
Expand Down Expand Up @@ -623,7 +639,7 @@ Status FillTensorMapFromInputsMap(

auto* tensor = &(*tensor_map)[tensorinfo_map.begin()->first];
tensor->set_dtype(tensorinfo_map.begin()->second.dtype());
GetDenseTensorShape(val, tensor->mutable_tensor_shape());
TF_RETURN_IF_ERROR(GetDenseTensorShape(val, tensor->mutable_tensor_shape()));
int unused_size = 0;
TF_RETURN_IF_ERROR(FillTensorProto(val, 0 /* level */, tensor->dtype(),
&unused_size, tensor));
Expand All @@ -639,7 +655,7 @@ Status FillTensorMapFromInputsMap(
auto* tensor = &(*tensor_map)[name];
tensor->set_dtype(dtype);
tensor->mutable_tensor_shape()->Clear();
GetDenseTensorShape(item->value, tensor->mutable_tensor_shape());
TF_RETURN_IF_ERROR(GetDenseTensorShape(item->value, tensor->mutable_tensor_shape()));
int unused_size = 0;
TF_RETURN_IF_ERROR(FillTensorProto(item->value, 0 /* level */, dtype,
&unused_size, tensor));
Expand Down
19 changes: 19 additions & 0 deletions tensorflow_serving/util/json_tensor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,25 @@ TEST(JsontensorTest, DeeplyNestedMalformed) {
EXPECT_THAT(status.message(), HasSubstr("key must be a string value"));
}

TEST(JsontensorTest, DeeplyNestedTensorValueExceedsMaxRank) {
TensorInfoMap infomap;
ASSERT_TRUE(
TextFormat::ParseFromString("dtype: DT_INT32", &infomap["default"]));

PredictRequest req;
JsonPredictRequestFormat format;
std::string json_req = R"({"instances":)";
int depth = 300; // exceeds kMaxTensorRank (254)
json_req.append(depth, '[');
json_req.append("1");
json_req.append(depth, ']');
json_req.append("}");
auto status =
FillPredictRequestFromJson(json_req, getmap(infomap), &req, &format);
ASSERT_FALSE(status.ok());
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument);
}

TEST(JsontensorTest, MixedInputForFloatTensor) {
TensorInfoMap infomap;
ASSERT_TRUE(
Expand Down