diff --git a/.agents/repo-reference.md b/.agents/repo-reference.md index 8df1d30b8e..144a432528 100644 --- a/.agents/repo-reference.md +++ b/.agents/repo-reference.md @@ -49,6 +49,18 @@ Apache Fory is a multi-language serialization framework with multiple wire forma - `foryc schema.fdl --lang --output ` - Never edit generated code manually. Update the source schema or IDL and regenerate. - Protocol changes must update `docs/specification/**` and the relevant cross-language tests. +- Remote struct `TypeDef` or `TypeMeta` schema limits are resource controls on cold metadata + cache-miss parse/publish paths only. They must not change wire format, registration, dynamic type + loading, unknown-type behavior, deserialization policy, schema-evolution semantics, or metadata + cache-hit/generated-reader hot paths. Count a remote schema version only after the schema-specific + read state has been successfully built and the owning metadata cache can publish it; failed or + incompatible metadata must not consume the limit. +- Remote struct metadata body and field-count limits are also cold-path resource controls. + `maxTypeMetaBytes` limits one received TypeDef or TypeMeta body excluding the 8-byte header and + extended-size varint; `maxTypeFields` limits one received struct metadata body's field count + (Java native TypeDef counts total fields across class layers). Check these before body + copy/decompression and before field-list allocation, and never add cache-hit or generated-reader + hot-path work for them. ## Runtime Map diff --git a/AGENTS.md b/AGENTS.md index 02e44f69a3..c95bd29eca 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -145,6 +145,9 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - If Swift xlang behavior changes, run `org.apache.fory.xlang.SwiftXlangTest` too. - For performance regressions or optimizations, profile or otherwise measure the current branch and a fresh `apache/main` baseline before changing code; optimize the measured hotspot, not guessed code. - When comparing benchmark results against `apache/main`, use a separate sibling worktree named `fory-benchmark-baseline` by default. Before creating a new worktree, check whether `../fory-benchmark-baseline` already exists and reuse it to avoid repeated benchmark dependency rebuilds. Always fetch and sync that baseline worktree to the latest `apache/main` before measuring it, and store benchmark result files under that worktree so older runs remain available as reference data. Treat stored benchmark results as historical references, not truth, because machine load and benchmark variance change over time. Create a different baseline worktree only when explicitly requested. +- Before benchmarking a checked-out version, install or build the required Fory packages for that version, such as the Java artifacts, Python package, and the target runtime package needed by the benchmark. +- Run old/new benchmark comparisons for only one language at a time. Do not batch multiple language comparisons into one long benchmark session, because longer sessions increase machine-state drift and variance. +- Treat a same-benchmark gap greater than 1% as a likely hot-path regression until repeated focused measurements or profiling prove it is noise. - Do not change protocol behavior, benchmark payloads, or public APIs solely to manufacture performance wins. - For performance work, run the relevant benchmark immediately after each change and report the command plus before/after numbers. - For performance-optimization rounds, append the hypothesis, change, benchmark command, before/after numbers, and keep/revert decision to `tasks/perf_optimization_rounds.md`. diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index d471c39074..665c1d2ed4 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -52,6 +52,19 @@ struct Config { /// When enabled, avoids duplicating shared objects and handles cycles. bool track_ref = true; + /// Maximum accepted field count in one received struct TypeMeta. + uint32_t max_type_fields = 512; + + /// Maximum accepted body size in one received TypeMeta. + uint32_t max_type_meta_bytes = 4096; + + /// Maximum accepted remote struct schema versions for one logical type. + uint32_t max_schema_versions_per_type = 10; + + /// Maximum accepted average remote struct schema versions across logical + /// types. The effective global minimum remains 8192 schemas. + uint32_t max_average_schema_versions_per_type = 3; + /// Default constructor with sensible defaults Config() = default; }; diff --git a/cpp/fory/serialization/context.cc b/cpp/fory/serialization/context.cc index 95b2453205..015e1dbfdd 100644 --- a/cpp/fory/serialization/context.cc +++ b/cpp/fory/serialization/context.cc @@ -22,6 +22,8 @@ #include "fory/serialization/type_resolver.h" #include "fory/thirdparty/MurmurHash3.h" #include "fory/type/type.h" +#include +#include namespace fory { namespace serialization { @@ -477,8 +479,71 @@ ReadContext::read_enum_type_info(uint32_t base_type_id) { return Unexpected(Error::type_mismatch(type_id, base_type_id)); } -// Maximum number of parsed type defs to cache (avoid OOM from malicious input) -static constexpr size_t k_max_parsed_num_type_defs = 8192; +static constexpr size_t k_min_remote_struct_schema_limit = 8192; + +Result +ReadContext::check_remote_struct_schema_limit(const TypeMeta &type_meta) { + const auto type_id = static_cast(type_meta.type_id); + switch (type_id) { + case TypeId::STRUCT: + case TypeId::COMPATIBLE_STRUCT: + case TypeId::NAMED_STRUCT: + case TypeId::NAMED_COMPATIBLE_STRUCT: + break; + default: + return std::string(); + } + + std::string key; + if (type_meta.register_by_name) { + key.reserve(type_meta.namespace_str.size() + type_meta.type_name.size() + + 2); + key.push_back('n'); + key.append(type_meta.namespace_str); + key.push_back('\0'); + key.append(type_meta.type_name); + } else { + key = "i" + std::to_string(type_meta.user_type_id); + } + + auto *entry = remote_schema_versions_by_type_.find(key); + const uint32_t versions_for_type = entry == nullptr ? 0 : entry->second; + if (FORY_PREDICT_FALSE(versions_for_type >= + config_->max_schema_versions_per_type)) { + return Unexpected(Error::invalid_data( + "Remote struct schema versions for one type exceeded " + "max_schema_versions_per_type=" + + std::to_string(config_->max_schema_versions_per_type))); + } + + const size_t accepted_type_count = + remote_schema_versions_by_type_.size() + (entry == nullptr ? 1 : 0); + const size_t global_limit = std::max( + k_min_remote_struct_schema_limit, + accepted_type_count * + static_cast(config_->max_average_schema_versions_per_type)); + if (FORY_PREDICT_FALSE(total_accepted_schema_versions_ >= global_limit)) { + return Unexpected(Error::invalid_data( + "Remote struct schema versions exceeded global limit from " + "max_average_schema_versions_per_type=" + + std::to_string(config_->max_average_schema_versions_per_type))); + } + + return key; +} + +void ReadContext::record_remote_struct_schema(const std::string &type_key) { + if (type_key.empty()) { + return; + } + auto *entry = remote_schema_versions_by_type_.find(type_key); + if (entry == nullptr) { + remote_schema_versions_by_type_[type_key] = 1; + } else { + ++entry->second; + } + ++total_accepted_schema_versions_; +} Result ReadContext::read_type_meta() { Error error; @@ -531,8 +596,12 @@ Result ReadContext::read_type_meta() { } // Not in cache - parse the TypeMeta - FORY_TRY(parsed_meta, - TypeMeta::from_bytes_with_header(*buffer_, meta_header)); + const uint32_t type_def_start = + buffer_->reader_index() - static_cast(sizeof(int64_t)); + FORY_TRY(parsed_meta, TypeMeta::from_bytes_with_header( + *buffer_, meta_header, config_->max_type_fields, + config_->max_type_meta_bytes)); + const uint32_t type_def_end = buffer_->reader_index(); // Find local TypeInfo to get field_id mapping (optional for schema evolution) const TypeInfo *local_type_info = nullptr; @@ -557,6 +626,24 @@ Result ReadContext::read_type_meta() { } } + if (local_type_info) { + const auto &local_type_def = local_type_info->type_def; + const size_t remote_type_def_size = + static_cast(type_def_end - type_def_start); + if (local_type_def.size() == remote_type_def_size && + std::memcmp(local_type_def.data(), buffer_->data() + type_def_start, + remote_type_def_size) == 0) { + parsed_type_infos_[meta_header] = local_type_info; + has_last_meta_header_ = true; + last_meta_header_ = meta_header; + last_meta_type_info_ = local_type_info; + reading_type_infos_.push_back(local_type_info); + return local_type_info; + } + } + + FORY_TRY(remote_schema_key, check_remote_struct_schema_limit(*parsed_meta)); + // Create TypeInfo with field_ids assigned auto type_info = std::make_unique(); if (local_type_info) { @@ -583,21 +670,13 @@ Result ReadContext::read_type_meta() { type_info->type_meta = std::move(parsed_meta); } - // get raw pointer before moving into storage - const TypeInfo *raw_ptr = type_info.get(); - - // Store in primary storage - if (parsed_type_infos_.size() < k_max_parsed_num_type_defs) { - cached_type_infos_.push_back(std::move(type_info)); - raw_ptr = cached_type_infos_.back().get(); - parsed_type_infos_[meta_header] = raw_ptr; - has_last_meta_header_ = true; - last_meta_header_ = meta_header; - last_meta_type_info_ = raw_ptr; - } else { - owned_reading_type_infos_.push_back(std::move(type_info)); - raw_ptr = owned_reading_type_infos_.back().get(); - } + cached_type_infos_.push_back(std::move(type_info)); + const TypeInfo *raw_ptr = cached_type_infos_.back().get(); + parsed_type_infos_[meta_header] = raw_ptr; + has_last_meta_header_ = true; + last_meta_header_ = meta_header; + last_meta_type_info_ = raw_ptr; + record_remote_struct_schema(remote_schema_key); reading_type_infos_.push_back(raw_ptr); return raw_ptr; @@ -677,7 +756,6 @@ void ReadContext::reset() { ref_reader_.reset(); } reading_type_infos_.clear(); - owned_reading_type_infos_.clear(); current_dyn_depth_ = 0; if (meta_string_table_active_) { meta_string_table_.reset(); diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 60e7974d06..b5b47ef2f4 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -31,6 +31,7 @@ #include "fory/util/result.h" #include +#include #include namespace fory { @@ -39,6 +40,7 @@ namespace serialization { // Forward declarations class TypeResolver; class ReadContext; +class TypeMeta; /// RAII helper to automatically decrease dynamic depth when leaving scope. /// Used for tracking nested polymorphic type deserialization depth. @@ -656,6 +658,10 @@ class ReadContext { inline const Config &config() const { return *config_; } private: + FORY_NOINLINE Result + check_remote_struct_schema_limit(const TypeMeta &type_meta); + void record_remote_struct_schema(const std::string &type_key); + // Error state - accumulated during deserialization, checked at the end Error error_; @@ -666,14 +672,14 @@ class ReadContext { uint32_t current_dyn_depth_; // Meta sharing state (for compatible mode) - // Per-message storage for TypeInfo objects not cached across messages. - std::vector> owned_reading_type_infos_; // Persistent cache storage for TypeInfo objects keyed by meta header. std::vector> cached_type_infos_; - // Index-based access (pointers to owned_reading_type_infos_ or type_resolver) + // Index-based access (pointers to cached_type_infos_ or type_resolver) std::vector reading_type_infos_; // Cache by meta_header (pointers to cached_type_infos_) fory::flat_hash_map parsed_type_infos_; + fory::flat_hash_map remote_schema_versions_by_type_; + size_t total_accepted_schema_versions_ = 0; // Fast path for repeated type meta headers. int64_t last_meta_header_ = 0; const TypeInfo *last_meta_type_info_ = nullptr; diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 3361cb621d..246288e5dd 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -109,6 +109,37 @@ class ForyBuilder { return *this; } + /// Set maximum accepted field count in one received struct TypeMeta. + ForyBuilder &max_type_fields(uint32_t max_fields) { + FORY_CHECK(max_fields > 0) << "max_type_fields must be positive"; + config_.max_type_fields = max_fields; + return *this; + } + + /// Set maximum accepted body size in one received TypeMeta. + ForyBuilder &max_type_meta_bytes(uint32_t max_bytes) { + FORY_CHECK(max_bytes > 0) << "max_type_meta_bytes must be positive"; + config_.max_type_meta_bytes = max_bytes; + return *this; + } + + /// Set maximum accepted remote struct schema versions for one logical type. + ForyBuilder &max_schema_versions_per_type(uint32_t max_versions) { + FORY_CHECK(max_versions > 0) + << "max_schema_versions_per_type must be positive"; + config_.max_schema_versions_per_type = max_versions; + return *this; + } + + /// Set maximum accepted average remote struct schema versions across logical + /// types. The effective global minimum remains 8192 schemas. + ForyBuilder &max_average_schema_versions_per_type(uint32_t max_versions) { + FORY_CHECK(max_versions > 0) + << "max_average_schema_versions_per_type must be positive"; + config_.max_average_schema_versions_per_type = max_versions; + return *this; + } + /// Provide a custom type resolver instance. ForyBuilder &type_resolver(std::shared_ptr resolver) { type_resolver_ = std::move(resolver); diff --git a/cpp/fory/serialization/serialization_test.cc b/cpp/fory/serialization/serialization_test.cc index 840a522a66..f1a7b3c91f 100644 --- a/cpp/fory/serialization/serialization_test.cc +++ b/cpp/fory/serialization/serialization_test.cc @@ -957,6 +957,63 @@ TEST(SerializationTest, RegistrationByNameFailureDoesNotLeakTypeInfo) { EXPECT_EQ(dotted_type_name.error().code(), ErrorCode::Invalid); } +static std::vector make_remote_type_meta(const std::string &type_name, + const std::string &field) { + std::vector fields; + fields.emplace_back( + field, FieldType(static_cast(TypeId::VARINT32), false)); + TypeMeta meta = + TypeMeta::from_fields(static_cast(TypeId::NAMED_STRUCT), + "example", type_name, true, 0, std::move(fields)); + auto bytes = meta.to_bytes(); + EXPECT_TRUE(bytes.ok()) << "TypeMeta serialization failed: " + << bytes.error().to_string(); + return bytes.value(); +} + +static Result +append_and_read_type_meta(ReadContext &ctx, const std::vector &bytes) { + Buffer buffer; + buffer.write_var_uint32(0); + buffer.write_bytes(bytes.data(), static_cast(bytes.size())); + ctx.reset(); + ctx.attach(buffer); + auto result = ctx.read_type_meta(); + ctx.detach(); + return result; +} + +TEST(SerializationTest, RemoteSchemaLimitRejectsExtraVersions) { + Config config; + config.compatible = true; + config.max_schema_versions_per_type = 1; + ReadContext ctx(config, std::make_unique()); + + auto first = append_and_read_type_meta( + ctx, make_remote_type_meta("Unknown", "first_value")); + ASSERT_TRUE(first.ok()) << first.error().to_string(); + + auto second = append_and_read_type_meta( + ctx, make_remote_type_meta("Unknown", "second_value")); + EXPECT_FALSE(second.ok()); + ASSERT_FALSE(second.ok()); + EXPECT_EQ(second.error().code(), ErrorCode::InvalidData); +} + +TEST(SerializationTest, RemoteSchemaLimitKeepsUnknownTypesSeparate) { + Config config; + config.compatible = true; + config.max_schema_versions_per_type = 1; + ReadContext ctx(config, std::make_unique()); + + auto first = append_and_read_type_meta( + ctx, make_remote_type_meta("UnknownA", "value")); + ASSERT_TRUE(first.ok()) << first.error().to_string(); + auto second = append_and_read_type_meta( + ctx, make_remote_type_meta("UnknownB", "value")); + EXPECT_TRUE(second.ok()) << second.error().to_string(); +} + TEST(SerializationTest, TypeMetaRejectsOverConsumedDeclaredSize) { TypeMeta meta = TypeMeta::from_fields(static_cast(TypeId::STRUCT), "", "S", @@ -1021,6 +1078,52 @@ TEST(SerializationTest, TypeMetaHeaderUses52BitMetadataHash) { parsed.value()->get_hash()); } +TEST(SerializationTest, TypeMetaRejectsMaxTypeFields) { + std::vector fields; + fields.emplace_back( + "first", FieldType(static_cast(TypeId::VARINT32), false)); + fields.emplace_back( + "second", FieldType(static_cast(TypeId::VARINT32), false)); + TypeMeta meta = TypeMeta::from_fields(static_cast(TypeId::STRUCT), + "", "S", false, 1, std::move(fields)); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + std::vector bytes = bytes_result.value(); + Buffer buffer(bytes); + Error error; + int64_t header = 0; + buffer.read_bytes(&header, sizeof(header), error); + ASSERT_TRUE(error.ok()) << error.to_string(); + + auto parsed = TypeMeta::from_bytes_with_header(buffer, header, 1, 4096); + ASSERT_FALSE(parsed.ok()); + EXPECT_NE(parsed.error().to_string().find("max_type_fields"), + std::string::npos); +} + +TEST(SerializationTest, TypeMetaRejectsMaxTypeMetaBytes) { + std::vector fields; + fields.emplace_back( + "value", FieldType(static_cast(TypeId::VARINT32), false)); + TypeMeta meta = TypeMeta::from_fields(static_cast(TypeId::STRUCT), + "", "S", false, 1, std::move(fields)); + auto bytes_result = meta.to_bytes(); + ASSERT_TRUE(bytes_result.ok()) + << "TypeMeta serialization failed: " << bytes_result.error().to_string(); + std::vector bytes = bytes_result.value(); + Buffer buffer(bytes); + Error error; + int64_t header = 0; + buffer.read_bytes(&header, sizeof(header), error); + ASSERT_TRUE(error.ok()) << error.to_string(); + + auto parsed = TypeMeta::from_bytes_with_header(buffer, header, 512, 1); + ASSERT_FALSE(parsed.ok()); + EXPECT_NE(parsed.error().to_string().find("max_type_meta_bytes"), + std::string::npos); +} + TEST(SerializationTest, TypeMetaRejectsBodyOnlyHeaderHash) { TypeMeta meta = TypeMeta::from_fields(static_cast(TypeId::STRUCT), "", "S", diff --git a/cpp/fory/serialization/type_resolver.cc b/cpp/fory/serialization/type_resolver.cc index 45db2d6541..49e2bca39e 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -50,6 +50,8 @@ constexpr int8_t NUM_HASH_BITS = 52; constexpr uint32_t TYPE_META_HASH_SHIFT = 64 - NUM_HASH_BITS; constexpr uint64_t TYPE_META_HASH_BITS_MASK = ~uint64_t{0} << TYPE_META_HASH_SHIFT; +constexpr uint32_t kDefaultMaxTypeFields = 512; +constexpr uint32_t kDefaultMaxTypeMetaBytes = 4096; // ============================================================================ // FieldType Implementation @@ -450,6 +452,30 @@ read_type_meta_size(Buffer &buffer, uint64_t header, size_t *header_size) { return static_cast(meta_size); } +inline Result +check_type_meta_body_size(uint32_t meta_size, uint32_t max_type_meta_bytes) { + if (FORY_PREDICT_FALSE(meta_size > max_type_meta_bytes)) { + return Unexpected(Error::invalid_data( + "Type metadata body size " + std::to_string(meta_size) + + " exceeds max_type_meta_bytes " + std::to_string(max_type_meta_bytes) + + ". The data may be malicious. If the data is not malicious, please " + "increase max_type_meta_bytes.")); + } + return Result(); +} + +inline Result check_type_meta_fields(size_t num_fields, + uint32_t max_type_fields) { + if (FORY_PREDICT_FALSE(num_fields > max_type_fields)) { + return Unexpected(Error::invalid_data( + "Type metadata field count " + std::to_string(num_fields) + + " exceeds max_type_fields " + std::to_string(max_type_fields) + + ". The data may be malicious. If the data is not malicious, please " + "increase max_type_fields.")); + } + return Result(); +} + inline Result validate_type_meta_hash(Buffer &buffer, uint32_t body_start, uint32_t meta_size, @@ -635,6 +661,8 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { uint64_t header_bits = static_cast(header); FORY_RETURN_IF_ERROR(validate_type_meta_header(header_bits)); FORY_TRY(meta_size, read_type_meta_size(buffer, header_bits, &header_size)); + FORY_RETURN_IF_ERROR( + check_type_meta_body_size(meta_size, kDefaultMaxTypeMetaBytes)); int64_t meta_hash = static_cast(header_bits >> TYPE_META_HASH_SHIFT); uint32_t body_start = static_cast(start_pos + header_size); if (FORY_PREDICT_FALSE(!buffer.ensure_readable(meta_size, error))) { @@ -671,6 +699,8 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { } num_fields += extra; } + FORY_RETURN_IF_ERROR( + check_type_meta_fields(num_fields, kDefaultMaxTypeFields)); } else { if (FORY_PREDICT_FALSE((meta_header & NON_STRUCT_RESERVED_BITS_MASK) != 0)) { @@ -752,10 +782,14 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { } Result, Error> -TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { +TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header, + uint32_t max_type_fields, + uint32_t max_type_meta_bytes) { uint64_t header_bits = static_cast(header); FORY_RETURN_IF_ERROR(validate_type_meta_header(header_bits)); FORY_TRY(meta_size, read_type_meta_size(buffer, header_bits, nullptr)); + FORY_RETURN_IF_ERROR( + check_type_meta_body_size(meta_size, max_type_meta_bytes)); int64_t meta_hash = static_cast(header_bits >> TYPE_META_HASH_SHIFT); uint32_t start_pos = buffer.reader_index(); @@ -795,6 +829,7 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { } num_fields += extra; } + FORY_RETURN_IF_ERROR(check_type_meta_fields(num_fields, max_type_fields)); } else { if (FORY_PREDICT_FALSE((meta_header & NON_STRUCT_RESERVED_BITS_MASK) != 0)) { @@ -1770,13 +1805,11 @@ TypeResolver::build_final_type_resolver() { // Update the TypeInfo in place partial_ptr->type_def = std::move(type_def); - - // Parse the serialized TypeMeta back to create unique_ptr - Buffer buffer(partial_ptr->type_def.data(), - static_cast(partial_ptr->type_def.size()), false); - buffer.writer_index(static_cast(partial_ptr->type_def.size())); - FORY_TRY(parsed_meta, TypeMeta::from_bytes(buffer, nullptr)); - partial_ptr->type_meta = std::move(parsed_meta); + uint64_t header_bits = 0; + std::memcpy(&header_bits, partial_ptr->type_def.data(), + sizeof(header_bits)); + meta.hash = static_cast(header_bits >> TYPE_META_HASH_SHIFT); + partial_ptr->type_meta = std::make_unique(std::move(meta)); } // Clear partial_type_infos in the final resolver since they're all completed diff --git a/cpp/fory/serialization/type_resolver.h b/cpp/fory/serialization/type_resolver.h index dcf358b130..15b39be920 100644 --- a/cpp/fory/serialization/type_resolver.h +++ b/cpp/fory/serialization/type_resolver.h @@ -220,8 +220,6 @@ class FieldInfo { // TypeMeta - Complete type metadata (for schema evolution) // ============================================================================ -constexpr size_t MAX_PARSED_NUM_TYPE_DEFS = 8192; - /// Type metadata containing all field information /// Used for schema evolution to compare remote and local type schemas class TypeMeta { @@ -258,7 +256,9 @@ class TypeMeta { /// @param buffer Source buffer (positioned after header) /// @param header Pre-read 8-byte header static Result, Error> - from_bytes_with_header(Buffer &buffer, int64_t header); + from_bytes_with_header(Buffer &buffer, int64_t header, + uint32_t max_type_fields = 512, + uint32_t max_type_meta_bytes = 4096); /// skip type meta in buffer without parsing static Result skip_bytes_for_validated_header(Buffer &buffer, diff --git a/csharp/src/Fory/Config.cs b/csharp/src/Fory/Config.cs index 2d8d94896b..550945e956 100644 --- a/csharp/src/Fory/Config.cs +++ b/csharp/src/Fory/Config.cs @@ -27,12 +27,41 @@ internal Config( bool trackRef, bool compatible, bool checkStructVersion, - int maxDepth) + int maxDepth, + int maxTypeFields, + int maxTypeMetaBytes, + int maxSchemaVersionsPerType, + int maxAverageSchemaVersionsPerType) { + if (maxDepth <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxDepth), "MaxDepth must be greater than 0."); + } + if (maxTypeFields <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTypeFields), "MaxTypeFields must be greater than 0."); + } + if (maxTypeMetaBytes <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxTypeMetaBytes), "MaxTypeMetaBytes must be greater than 0."); + } + if (maxSchemaVersionsPerType <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxSchemaVersionsPerType), "MaxSchemaVersionsPerType must be greater than 0."); + } + if (maxAverageSchemaVersionsPerType <= 0) + { + throw new ArgumentOutOfRangeException(nameof(maxAverageSchemaVersionsPerType), "MaxAverageSchemaVersionsPerType must be greater than 0."); + } + TrackRef = trackRef; Compatible = compatible; CheckStructVersion = checkStructVersion; MaxDepth = maxDepth; + MaxTypeFields = maxTypeFields; + MaxTypeMetaBytes = maxTypeMetaBytes; + MaxSchemaVersionsPerType = maxSchemaVersionsPerType; + MaxAverageSchemaVersionsPerType = maxAverageSchemaVersionsPerType; } /// @@ -54,6 +83,26 @@ internal Config( /// Gets the maximum allowed nesting depth for dynamic object payload reads. /// public int MaxDepth { get; } + + /// + /// Gets the maximum accepted field count in one received struct TypeMeta. + /// + public int MaxTypeFields { get; } + + /// + /// Gets the maximum accepted body size in one received TypeMeta. + /// + public int MaxTypeMetaBytes { get; } + + /// + /// Gets the maximum accepted remote schema versions for one struct type. + /// + public int MaxSchemaVersionsPerType { get; } + + /// + /// Gets the average remote schema version limit across accepted struct types. + /// + public int MaxAverageSchemaVersionsPerType { get; } } /// @@ -65,6 +114,10 @@ public sealed class ForyBuilder private bool? _compatible; private bool _checkStructVersion; private int _maxDepth = 20; + private int _maxTypeFields = 512; + private int _maxTypeMetaBytes = 4096; + private int _maxSchemaVersionsPerType = 10; + private int _maxAverageSchemaVersionsPerType = 3; /// /// Enables or disables reference tracking for shared and circular object graphs. @@ -116,6 +169,62 @@ public ForyBuilder MaxDepth(int value) return this; } + /// + /// Sets the maximum accepted field count in one received struct TypeMeta. + /// + public ForyBuilder MaxTypeFields(int value) + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "MaxTypeFields must be greater than 0."); + } + + _maxTypeFields = value; + return this; + } + + /// + /// Sets the maximum accepted body size in one received TypeMeta. + /// + public ForyBuilder MaxTypeMetaBytes(int value) + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "MaxTypeMetaBytes must be greater than 0."); + } + + _maxTypeMetaBytes = value; + return this; + } + + /// + /// Sets the maximum accepted remote schema versions for one struct type. + /// + public ForyBuilder MaxSchemaVersionsPerType(int value) + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "MaxSchemaVersionsPerType must be greater than 0."); + } + + _maxSchemaVersionsPerType = value; + return this; + } + + /// + /// Sets the average remote schema version limit across accepted struct types. + /// + public ForyBuilder MaxAverageSchemaVersionsPerType(int value) + { + if (value <= 0) + { + throw new ArgumentOutOfRangeException(nameof(value), "MaxAverageSchemaVersionsPerType must be greater than 0."); + } + + _maxAverageSchemaVersionsPerType = value; + return this; + } + private Config BuildConfig() { bool compatible = _compatible ?? true; @@ -125,7 +234,11 @@ private Config BuildConfig() trackRef: _trackRef, compatible: compatible, checkStructVersion: compatible ? false : _checkStructVersion, - maxDepth: _maxDepth); + maxDepth: _maxDepth, + maxTypeFields: _maxTypeFields, + maxTypeMetaBytes: _maxTypeMetaBytes, + maxSchemaVersionsPerType: _maxSchemaVersionsPerType, + maxAverageSchemaVersionsPerType: _maxAverageSchemaVersionsPerType); } /// diff --git a/csharp/src/Fory/Fory.cs b/csharp/src/Fory/Fory.cs index f193d4b923..9bbafd1775 100644 --- a/csharp/src/Fory/Fory.cs +++ b/csharp/src/Fory/Fory.cs @@ -44,10 +44,7 @@ internal Fory(Config config) _readContext = new ReadContext( new ByteReader(Array.Empty()), _typeResolver, - Config.TrackRef, - Config.Compatible, - Config.CheckStructVersion, - Config.MaxDepth); + Config); } /// diff --git a/csharp/src/Fory/ReadContext.cs b/csharp/src/Fory/ReadContext.cs index 2192bb55cf..f9cde8844a 100644 --- a/csharp/src/Fory/ReadContext.cs +++ b/csharp/src/Fory/ReadContext.cs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License. +using System.Buffers.Binary; + namespace Apache.Fory; public sealed class ReadContext { - private const int MaxParsedTypeMetaEntries = 8192; + private const int MinRemoteStructSchemaLimit = 8192; private readonly ReusableArray _readTypeMetas = new(); private readonly Dictionary _cachedTypeMetasByHeader = []; + private readonly Dictionary _remoteSchemaVersionsByType = []; private TypeMeta? _firstReadTypeMeta; private bool _hasFirstReadTypeMeta; private ulong _lastMetaHeader; @@ -40,27 +43,24 @@ public sealed class ReadContext internal Type? _cachedTypeMetaType; internal TypeMeta? _cachedTypeMeta; internal int _currentDynamicReadDepth; + private readonly Config _config; + private int _totalAcceptedSchemaVersions; public ReadContext( ByteReader reader, TypeResolver typeResolver, - bool trackRef, - bool compatible = false, - bool checkStructVersion = false, - int maxDynamicReadDepth = 20) + Config config) { - if (maxDynamicReadDepth <= 0) - { - throw new ArgumentOutOfRangeException(nameof(maxDynamicReadDepth), "MaxDepth must be greater than 0."); - } + ArgumentNullException.ThrowIfNull(config); Reader = reader; TypeResolver = typeResolver; - TrackRef = trackRef; - Compatible = compatible; - CheckStructVersion = checkStructVersion; + TrackRef = config.TrackRef; + Compatible = config.Compatible; + CheckStructVersion = config.CheckStructVersion; RefReader = new RefReader(); - _maxDynamicReadDepth = maxDynamicReadDepth; + _maxDynamicReadDepth = config.MaxDepth; + _config = config; } public ByteReader Reader { get; private set; } @@ -164,15 +164,66 @@ internal void CacheReadTypeMeta(ulong header, TypeMeta typeMeta) return; } - if (_cachedTypeMetasByHeader.Count >= MaxParsedTypeMetaEntries) - { - return; - } - + object? typeKey = CheckRemoteStructSchemaLimit(typeMeta); _lastMetaHeader = header; _lastTypeMeta = typeMeta; _hasLastMetaHeader = true; _cachedTypeMetasByHeader.TryAdd(header, typeMeta); + RecordRemoteStructSchema(typeKey); + } + + [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.NoInlining)] + private object? CheckRemoteStructSchemaLimit(TypeMeta typeMeta) + { + if (typeMeta.TypeId is not uint typeId || + typeId is not ((uint)TypeId.Struct or + (uint)TypeId.CompatibleStruct or + (uint)TypeId.NamedStruct or + (uint)TypeId.NamedCompatibleStruct)) + { + return null; + } + + object typeKey = typeMeta.RegisterByName + ? $"{typeMeta.NamespaceName.Value}\0{typeMeta.TypeName.Value}" + : typeMeta.UserTypeId!.Value; + _remoteSchemaVersionsByType.TryGetValue(typeKey, out int versionsForType); + int maxSchemaVersionsPerType = _config.MaxSchemaVersionsPerType; + if (versionsForType >= maxSchemaVersionsPerType) + { + throw new InvalidDataException( + $"Remote schema version limit exceeded for type {typeKey}: {versionsForType} >= {maxSchemaVersionsPerType}. " + + "Increase MaxSchemaVersionsPerType if this peer legitimately sends many schema versions for one type."); + } + + int acceptedStructTypeCount = versionsForType == 0 + ? _remoteSchemaVersionsByType.Count + 1 + : _remoteSchemaVersionsByType.Count; + int maxAverageSchemaVersionsPerType = _config.MaxAverageSchemaVersionsPerType; + long globalLimit = Math.Max( + MinRemoteStructSchemaLimit, + (long)acceptedStructTypeCount * maxAverageSchemaVersionsPerType); + if (_totalAcceptedSchemaVersions >= globalLimit) + { + throw new InvalidDataException( + $"Remote schema version limit exceeded: {_totalAcceptedSchemaVersions} schemas for {acceptedStructTypeCount} " + + $"accepted struct types exceeds the average limit {maxAverageSchemaVersionsPerType}. Increase " + + "MaxAverageSchemaVersionsPerType if this peer legitimately sends many schema versions across many types."); + } + + return typeKey; + } + + private void RecordRemoteStructSchema(object? typeKey) + { + if (typeKey is null) + { + return; + } + + _remoteSchemaVersionsByType.TryGetValue(typeKey, out int versionsForType); + _remoteSchemaVersionsByType[typeKey] = versionsForType + 1; + _totalAcceptedSchemaVersions++; } internal MetaString? GetReadMetaString(int index) @@ -185,20 +236,11 @@ internal void AppendReadMetaString(MetaString value) _readMetaStrings.Add(value); } - internal TypeMeta ReadTypeMeta() + internal TypeMeta ReadTypeMeta(TypeInfo? exactLocal = null) { - uint indexMarker = Reader.ReadVarUInt32(); - bool isRef = (indexMarker & 1) == 1; - int index = checked((int)(indexMarker >> 1)); - if (isRef) + if (TryReadTypeMetaRef(out int index, out TypeMeta typeMeta)) { - TypeMeta? cached = GetReadTypeMeta(index); - if (cached is null) - { - throw new InvalidDataException($"unknown type meta ref index {index}"); - } - - return cached; + return typeMeta; } ulong header = Reader.ReadUInt64(); @@ -212,13 +254,89 @@ internal TypeMeta ReadTypeMeta() return cachedTypeMeta; } + if (exactLocal is not null && TryReadExactLocalTypeMeta(header, exactLocal, out TypeMeta localTypeMeta)) + { + StoreReadTypeMeta(localTypeMeta, index); + return localTypeMeta; + } + Reader.MoveBack(sizeof(ulong)); - TypeMeta typeMeta = TypeMeta.Decode(Reader); - StoreReadTypeMeta(typeMeta, index); + typeMeta = DecodeReadTypeMeta(); CacheReadTypeMeta(header, typeMeta); + StoreReadTypeMeta(typeMeta, index); return typeMeta; } + internal bool TryReadTypeMetaRef(out int index, out TypeMeta typeMeta) + { + uint indexMarker = Reader.ReadVarUInt32(); + bool isRef = (indexMarker & 1) == 1; + index = checked((int)(indexMarker >> 1)); + if (isRef) + { + TypeMeta? cached = GetReadTypeMeta(index); + if (cached is null) + { + throw new InvalidDataException($"unknown type meta ref index {index}"); + } + + typeMeta = cached; + return true; + } + + typeMeta = null!; + return false; + } + + [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.NoInlining)] + internal bool TryReadExactLocalTypeMeta(ulong header, TypeInfo exactLocal, out TypeMeta typeMeta) + { + TypeInfo.TypeMetaCacheEntry local = exactLocal.GetTypeMetaCacheEntry(TrackRef); + byte[] encoded = local.EncodedBytes; + if (encoded.Length < sizeof(ulong) || + BinaryPrimitives.ReadUInt64LittleEndian(encoded) != header) + { + typeMeta = null!; + return false; + } + + int bodyBytes = encoded.Length - sizeof(ulong); + TypeMeta.CheckEncodedBodySize(encoded, _config.MaxTypeMetaBytes); + Reader.CheckBound(bodyBytes); + int start = Reader.Cursor - sizeof(ulong); + if (!Reader.Storage.AsSpan(start, encoded.Length).SequenceEqual(encoded)) + { + typeMeta = null!; + return false; + } + + Reader.Skip(bodyBytes); + typeMeta = local.TypeMeta; + return true; + } + + [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.NoInlining)] + internal bool TryUseExactLocalTypeMeta(int start, int end, TypeInfo exactLocal, out TypeMeta typeMeta) + { + TypeInfo.TypeMetaCacheEntry local = exactLocal.GetTypeMetaCacheEntry(TrackRef); + byte[] encoded = local.EncodedBytes; + if (end - start != encoded.Length || + !Reader.Storage.AsSpan(start, encoded.Length).SequenceEqual(encoded)) + { + typeMeta = null!; + return false; + } + + typeMeta = local.TypeMeta; + return true; + } + + [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.NoInlining)] + internal TypeMeta DecodeReadTypeMeta() + { + return TypeMeta.Decode(Reader, _config.MaxTypeFields, _config.MaxTypeMetaBytes); + } + internal void StoreTypeMeta(Type type, TypeMeta typeMeta) { ulong typeKey = TypeMapKey.Get(type); diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs index 31ea8e29aa..b45b30726c 100644 --- a/csharp/src/Fory/TypeMeta.cs +++ b/csharp/src/Fory/TypeMeta.cs @@ -394,6 +394,9 @@ public override int GetHashCode() public sealed class TypeMeta : IEquatable { + private const int DefaultMaxTypeFields = 512; + private const int DefaultMaxTypeMetaBytes = 4096; + private bool _hasAssignedFieldIds; public TypeMeta( @@ -486,14 +489,20 @@ public byte[] Encode() public static TypeMeta Decode(byte[] bytes) { - return Decode(new ByteReader(bytes)); + return Decode(new ByteReader(bytes), DefaultMaxTypeFields, DefaultMaxTypeMetaBytes); } public static TypeMeta Decode(ByteReader reader) + { + return Decode(reader, DefaultMaxTypeFields, DefaultMaxTypeMetaBytes); + } + + internal static TypeMeta Decode(ByteReader reader, int maxTypeFields, int maxTypeMetaBytes) { ulong header = reader.ReadUInt64(); ValidateGlobalHeader(header); int metaSize = ReadBodySize(reader, header); + CheckBodySize(metaSize, maxTypeMetaBytes); byte[] encodedBody = reader.ReadBytes(metaSize); ByteReader bodyReader = new(encodedBody); byte metaHeader = bodyReader.ReadUInt8(); @@ -516,6 +525,7 @@ public static TypeMeta Decode(ByteReader reader) { numFields += (int)bodyReader.ReadVarUInt32(); } + CheckFieldCount(numFields, maxTypeFields); } else { @@ -570,6 +580,15 @@ public static TypeMeta Decode(ByteReader reader) header >> TypeMetaConstants.TypeMetaHashShift); } + internal static void CheckEncodedBodySize(byte[] encoded, int maxTypeMetaBytes) + { + ByteReader reader = new(encoded); + ulong header = reader.ReadUInt64(); + ValidateGlobalHeader(header); + int metaSize = ReadBodySize(reader, header); + CheckBodySize(metaSize, maxTypeMetaBytes); + } + internal static void ValidateAndSkipBody(ByteReader reader, ulong header) { ValidateGlobalHeader(header); @@ -608,6 +627,24 @@ private static int ReadBodySize(ByteReader reader, ulong header) return metaSize; } + private static void CheckBodySize(int metaSize, int maxTypeMetaBytes) + { + if (metaSize > maxTypeMetaBytes) + { + throw new InvalidDataException( + $"Type metadata body size {metaSize} exceeds MaxTypeMetaBytes {maxTypeMetaBytes}. The data may be malicious. If the data is not malicious, please increase MaxTypeMetaBytes."); + } + } + + private static void CheckFieldCount(int numFields, int maxTypeFields) + { + if (numFields > maxTypeFields) + { + throw new InvalidDataException( + $"Type metadata field count {numFields} exceeds MaxTypeFields {maxTypeFields}. The data may be malicious. If the data is not malicious, please increase MaxTypeFields."); + } + } + internal static void SkipBody(ByteReader reader, ulong header) { reader.Skip(ReadBodySize(reader, header)); diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs index 8313fed758..84a2109c9d 100644 --- a/csharp/src/Fory/TypeResolver.cs +++ b/csharp/src/Fory/TypeResolver.cs @@ -697,20 +697,14 @@ private void ReadTypeInfoCore(Type type, TypeInfo typeInfo, ReadContext context) !info.RegisterByName && typeId == TypeId.CompatibleStruct) { - TypeMeta remoteTypeMeta = context.ReadTypeMeta(); - if (!HasValidatedTypeMeta(info, remoteTypeMeta)) - { - ValidateTypeMeta( - remoteTypeMeta, - info, - info.UserTypeKind.Value, - info.RegisterByName, - context.Compatible, - typeId); - remoteTypeMeta.EnsureAssignedFieldIds(TypeMetaFields(info, context.TrackRef)); - SetValidatedTypeMeta(info, remoteTypeMeta); - } - + TypeMeta remoteTypeMeta = ReadCheckedTypeMeta( + info, + info.UserTypeKind.Value, + info.RegisterByName, + context.Compatible, + typeId, + assignFieldIds: true, + context); context.StoreTypeMeta(type, remoteTypeMeta); return; } @@ -729,20 +723,14 @@ private void ReadTypeInfoCore(Type type, TypeInfo typeInfo, ReadContext context) case TypeId.CompatibleStruct: case TypeId.NamedCompatibleStruct: { - TypeMeta remoteTypeMeta = context.ReadTypeMeta(); - if (!HasValidatedTypeMeta(info, remoteTypeMeta)) - { - ValidateTypeMeta( - remoteTypeMeta, - info, - declaredKind, - registerByName, - compatible, - typeId); - remoteTypeMeta.EnsureAssignedFieldIds(TypeMetaFields(info, context.TrackRef)); - SetValidatedTypeMeta(info, remoteTypeMeta); - } - + TypeMeta remoteTypeMeta = ReadCheckedTypeMeta( + info, + declaredKind, + registerByName, + compatible, + typeId, + assignFieldIds: true, + context); context.StoreTypeMeta(type, remoteTypeMeta); return; } @@ -781,19 +769,14 @@ private void ReadNamedTypeInfo( { if (compatible) { - TypeMeta remoteTypeMeta = context.ReadTypeMeta(); - if (!HasValidatedTypeMeta(typeInfo, remoteTypeMeta)) - { - ValidateTypeMeta( - remoteTypeMeta, - typeInfo, - declaredKind, - registerByName, - compatible, - wireTypeId); - SetValidatedTypeMeta(typeInfo, remoteTypeMeta); - } - + _ = ReadCheckedTypeMeta( + typeInfo, + declaredKind, + registerByName, + compatible, + wireTypeId, + assignFieldIds: false, + context); return; } @@ -818,6 +801,109 @@ private void ReadNamedTypeInfo( } } + private TypeMeta ReadCheckedTypeMeta( + TypeInfo typeInfo, + UserTypeKind declaredKind, + bool registerByName, + bool compatible, + TypeId wireTypeId, + bool assignFieldIds, + ReadContext context) + { + if (context.TryReadTypeMetaRef(out int index, out TypeMeta remoteTypeMeta)) + { + ValidateCheckedTypeMeta( + remoteTypeMeta, + typeInfo, + declaredKind, + registerByName, + compatible, + wireTypeId, + assignFieldIds, + context); + return remoteTypeMeta; + } + + ulong header = context.Reader.ReadUInt64(); + if (context.TryGetCachedReadTypeMeta(header, out remoteTypeMeta)) + { + // Checked cache hits are already parsed and hash-validated. Keep this + // path body-skip only; schema limits are owned by cold cache misses. + TypeMeta.SkipBody(context.Reader, header); + context.StoreReadTypeMeta(remoteTypeMeta, index); + ValidateCheckedTypeMeta( + remoteTypeMeta, + typeInfo, + declaredKind, + registerByName, + compatible, + wireTypeId, + assignFieldIds, + context); + return remoteTypeMeta; + } + + if (context.TryReadExactLocalTypeMeta(header, typeInfo, out remoteTypeMeta)) + { + ValidateCheckedTypeMeta( + remoteTypeMeta, + typeInfo, + declaredKind, + registerByName, + compatible, + wireTypeId, + assignFieldIds, + context); + context.StoreReadTypeMeta(remoteTypeMeta, index); + return remoteTypeMeta; + } + + context.Reader.MoveBack(sizeof(ulong)); + remoteTypeMeta = context.DecodeReadTypeMeta(); + ValidateCheckedTypeMeta( + remoteTypeMeta, + typeInfo, + declaredKind, + registerByName, + compatible, + wireTypeId, + assignFieldIds, + context); + context.CacheReadTypeMeta(header, remoteTypeMeta); + context.StoreReadTypeMeta(remoteTypeMeta, index); + return remoteTypeMeta; + } + + private void ValidateCheckedTypeMeta( + TypeMeta remoteTypeMeta, + TypeInfo typeInfo, + UserTypeKind declaredKind, + bool registerByName, + bool compatible, + TypeId wireTypeId, + bool assignFieldIds, + ReadContext context) + { + if (HasValidatedTypeMeta(typeInfo, remoteTypeMeta)) + { + return; + } + + ValidateTypeMeta( + remoteTypeMeta, + typeInfo, + declaredKind, + registerByName, + compatible, + wireTypeId); + if (assignFieldIds) + { + remoteTypeMeta.EnsureAssignedFieldIds(TypeMetaFields(typeInfo, context.TrackRef)); + } + + SetValidatedTypeMeta(typeInfo, remoteTypeMeta); + } + internal static TypeId ResolveWireTypeId( UserTypeKind declaredKind, bool registerByName, @@ -873,15 +959,8 @@ TypeId.CompatibleStruct or private object? ReadRegisteredValue( TypeInfo typeInfo, - ReadContext context, - TypeMeta? typeMeta) + ReadContext context) { - if (typeMeta is not null) - { - typeMeta.EnsureAssignedFieldIds(TypeMetaFields(typeInfo, context.TrackRef)); - context.StoreTypeMeta(typeInfo.Type, typeMeta); - } - return ReadObject(typeInfo, context, RefMode.None, false); } @@ -898,7 +977,7 @@ internal TypeInfo ReadAnyTypeInfo(ReadContext context) { case TypeId.CompatibleStruct: case TypeId.NamedCompatibleStruct: - return ResolveAnyTypeInfoFromMeta(wireTypeId, context.ReadTypeMeta(), context.Compatible); + return ReadAnyTypeMetaInfo(wireTypeId, context.Compatible, context); case TypeId.NamedEnum: case TypeId.NamedStruct: case TypeId.NamedExt: @@ -918,7 +997,7 @@ private TypeInfo ReadNamedAnyTypeInfo(TypeId wireTypeId, ReadContext context) { if (context.Compatible) { - return ResolveAnyTypeInfoFromMeta(wireTypeId, context.ReadTypeMeta(), compatible: true); + return ReadAnyTypeMetaInfo(wireTypeId, compatible: true, context); } MetaString namespaceName = ReadMetaString( @@ -932,6 +1011,49 @@ private TypeInfo ReadNamedAnyTypeInfo(TypeId wireTypeId, ReadContext context) return ResolveAnyUserTypeInfo(wireTypeId, namespaceName.Value, typeName.Value, compatible: false); } + private TypeInfo ReadAnyTypeMetaInfo(TypeId wireTypeId, bool compatible, ReadContext context) + { + if (context.TryReadTypeMetaRef(out int index, out TypeMeta typeMeta)) + { + TypeInfo typeInfo = ResolveAnyTypeInfoFromMeta(wireTypeId, typeMeta, compatible); + typeMeta.EnsureAssignedFieldIds(TypeMetaFields(typeInfo, context.TrackRef)); + context.StoreTypeMeta(typeInfo.Type, typeMeta); + return typeInfo; + } + + int typeMetaStart = context.Reader.Cursor; + ulong header = context.Reader.ReadUInt64(); + if (context.TryGetCachedReadTypeMeta(header, out typeMeta)) + { + // Header-cache hits stay body-skip only. Type resolution and field binding + // below are existing Any resolution work, not schema-limit checks. + TypeMeta.SkipBody(context.Reader, header); + context.StoreReadTypeMeta(typeMeta, index); + TypeInfo typeInfo = ResolveAnyTypeInfoFromMeta(wireTypeId, typeMeta, compatible); + typeMeta.EnsureAssignedFieldIds(TypeMetaFields(typeInfo, context.TrackRef)); + context.StoreTypeMeta(typeInfo.Type, typeMeta); + return typeInfo; + } + + context.Reader.MoveBack(sizeof(ulong)); + typeMeta = context.DecodeReadTypeMeta(); + int typeMetaEnd = context.Reader.Cursor; + TypeInfo resolvedInfo = ResolveAnyTypeInfoFromMeta(wireTypeId, typeMeta, compatible); + if (context.TryUseExactLocalTypeMeta(typeMetaStart, typeMetaEnd, resolvedInfo, out TypeMeta localTypeMeta)) + { + localTypeMeta.EnsureAssignedFieldIds(TypeMetaFields(resolvedInfo, context.TrackRef)); + context.StoreReadTypeMeta(localTypeMeta, index); + context.StoreTypeMeta(resolvedInfo.Type, localTypeMeta); + return resolvedInfo.WithWireTypeInfo(wireTypeId, localTypeMeta); + } + + typeMeta.EnsureAssignedFieldIds(TypeMetaFields(resolvedInfo, context.TrackRef)); + context.CacheReadTypeMeta(header, typeMeta); + context.StoreReadTypeMeta(typeMeta, index); + context.StoreTypeMeta(resolvedInfo.Type, typeMeta); + return resolvedInfo; + } + internal object? ReadAnyValue(TypeInfo typeInfo, ReadContext context) { TypeId wireTypeId = typeInfo.WireTypeId @@ -969,10 +1091,10 @@ private TypeInfo ReadNamedAnyTypeInfo(TypeId wireTypeId, ReadContext context) case TypeId.NamedUnion: case TypeId.CompatibleStruct: case TypeId.NamedCompatibleStruct: - return ReadNestedRegisteredValue(typeInfo, context, typeInfo.GetTypeMeta()); + return ReadNestedRegisteredValue(typeInfo, context); case TypeId.Enum: case TypeId.NamedEnum: - return ReadRegisteredValue(typeInfo, context, typeInfo.GetTypeMeta()); + return ReadRegisteredValue(typeInfo, context); case TypeId.None: return null; default: @@ -998,11 +1120,10 @@ private object ReadNestedAnyMap(ReadContext context) private object? ReadNestedRegisteredValue( TypeInfo typeInfo, - ReadContext context, - TypeMeta? typeMeta) + ReadContext context) { context.IncreaseReadDepth(); - object? value = ReadRegisteredValue(typeInfo, context, typeMeta); + object? value = ReadRegisteredValue(typeInfo, context); context.DecreaseReadDepth(); return value; } diff --git a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs index b8c06f695d..d18ffa708d 100644 --- a/csharp/tests/Fory.Tests/ForyRuntimeTests.cs +++ b/csharp/tests/Fory.Tests/ForyRuntimeTests.cs @@ -1118,7 +1118,7 @@ public void MacroFieldOrderFollowsForyRules() short first = reader.ReadInt16(); long second = reader.ReadVarInt64(); int third = reader.ReadVarInt32(); - ReadContext tailContext = new(reader, new TypeResolver(), false, false); + ReadContext tailContext = new(reader, new TypeResolver(), fory.Config); string fourth = tailContext.TypeResolver.GetSerializer().ReadData(tailContext); Assert.Equal(value.B, first); @@ -2157,6 +2157,7 @@ public void Union2UsesZeroBasedWireCaseIds() { TypeResolver resolver = new(); Serializer> serializer = resolver.GetSerializer>(); + Config config = ForyRuntime.Builder().TrackRef(true).Build().Config; ByteWriter firstWriter = new(); WriteContext firstWrite = new(firstWriter, resolver, trackRef: true); @@ -2165,7 +2166,7 @@ public void Union2UsesZeroBasedWireCaseIds() Assert.Equal(0u, firstReader.ReadVarUInt32()); Union2 firstDecoded = - serializer.ReadData(new ReadContext(new ByteReader(firstWriter.ToArray()), resolver, trackRef: true)); + serializer.ReadData(new ReadContext(new ByteReader(firstWriter.ToArray()), resolver, config)); Assert.Equal(0, firstDecoded.Index); Assert.Equal("hello", firstDecoded.GetT1()); @@ -2176,7 +2177,7 @@ public void Union2UsesZeroBasedWireCaseIds() Assert.Equal(1u, secondReader.ReadVarUInt32()); Union2 secondDecoded = - serializer.ReadData(new ReadContext(new ByteReader(secondWriter.ToArray()), resolver, trackRef: true)); + serializer.ReadData(new ReadContext(new ByteReader(secondWriter.ToArray()), resolver, config)); Assert.Equal(1, secondDecoded.Index); Assert.Equal(42L, secondDecoded.GetT2()); @@ -2857,7 +2858,8 @@ private static (ulong Encoding, string Decoded) WriteAndReadString(string value) int byteLength = checked((int)(header >> 2)); Assert.Equal(payload.Length - headerReader.Cursor, byteLength); - ReadContext readContext = new(new ByteReader(payload), resolver, trackRef: false, compatible: false); + Config config = ForyRuntime.Builder().Compatible(false).Build().Config; + ReadContext readContext = new(new ByteReader(payload), resolver, config); string decoded = StringSerializer.ReadString(readContext); Assert.Equal(0, readContext.Reader.Remaining); return (encoding, decoded); diff --git a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs index 78b57e8ef8..2e97aa7aa7 100644 --- a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs +++ b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs @@ -206,7 +206,8 @@ public void FieldSkipperSkipsTimePayloads(TypeId typeId) writer.WriteUInt8(0xA5); ByteReader reader = new(writer.ToArray()); - ReadContext context = new(reader, new TypeResolver(), trackRef: false); + Config config = ForyRuntime.Builder().Compatible(false).Build().Config; + ReadContext context = new(reader, new TypeResolver(), config); FieldSkipper.SkipFieldValue(context, new TypeMetaFieldType((uint)typeId, nullable: false)); @@ -525,26 +526,108 @@ public void DeserializeRejectsUnsupportedRootHeaderBits() } [Fact] - public void TypeMetaHeaderCacheStopsPublishingAtCapacity() + public void TypeMetaSchemaLimitRejectsExtraVersions() { - ReadContext context = new(new ByteReader(Array.Empty()), new TypeResolver(), trackRef: false); - TypeMeta typeMeta = new( - (uint)TypeId.Struct, - 901, - MetaString.Empty('.', '_'), - MetaString.Empty('$', '_'), - registerByName: false, - []); + Config config = ForyRuntime.Builder() + .Compatible(false) + .MaxSchemaVersionsPerType(1) + .Build() + .Config; + ReadContext context = new(new ByteReader(Array.Empty()), new TypeResolver(), config); + TypeMeta first = RemoteStructTypeMeta(901, "first"); + TypeMeta second = RemoteStructTypeMeta(901, "second"); - for (ulong header = 1; header <= 8192; header++) - { - context.CacheReadTypeMeta(header, typeMeta); - } + ReadAndCacheTypeMeta(context, first); - Assert.True(context.TryGetCachedReadTypeMeta(8192, out _)); - context.CacheReadTypeMeta(8193, typeMeta); + Assert.Throws(() => ReadAndCacheTypeMeta(context, second)); + } - Assert.False(context.TryGetCachedReadTypeMeta(8193, out _)); + [Fact] + public void TypeMetaFieldLimitRejectsLargeStruct() + { + Config config = ForyRuntime.Builder() + .Compatible(false) + .MaxTypeFields(1) + .Build() + .Config; + ReadContext context = new(new ByteReader(Array.Empty()), new TypeResolver(), config); + TypeMeta typeMeta = RemoteStructTypeMeta(901, "first", "second"); + + InvalidDataException exception = + Assert.Throws(() => ReadAndCacheTypeMeta(context, typeMeta)); + Assert.Contains("MaxTypeFields", exception.Message, StringComparison.Ordinal); + } + + [Fact] + public void TypeMetaBodyLimitRejectsLargeMetadata() + { + Config config = ForyRuntime.Builder() + .Compatible(false) + .MaxTypeMetaBytes(1) + .Build() + .Config; + ReadContext context = new(new ByteReader(Array.Empty()), new TypeResolver(), config); + + InvalidDataException exception = + Assert.Throws(() => ReadAndCacheTypeMeta(context, RemoteStructTypeMeta(901, "value"))); + Assert.Contains("MaxTypeMetaBytes", exception.Message, StringComparison.Ordinal); + } + + [Fact] + public void TypeMetaSchemaLimitKeepsUnknownTypesSeparate() + { + Config config = ForyRuntime.Builder() + .Compatible(false) + .MaxSchemaVersionsPerType(1) + .Build() + .Config; + ReadContext context = new(new ByteReader(Array.Empty()), new TypeResolver(), config); + TypeMeta first = RemoteStructTypeMeta(901, "value"); + TypeMeta second = RemoteStructTypeMeta(902, "value"); + + TypeMeta firstRead = ReadAndCacheTypeMeta(context, first); + TypeMeta secondRead = ReadAndCacheTypeMeta(context, second); + + Assert.True(context.TryGetCachedReadTypeMeta(EncodedTypeMetaHeader(firstRead), out _)); + Assert.True(context.TryGetCachedReadTypeMeta(EncodedTypeMetaHeader(secondRead), out _)); + } + + [Fact] + public void FailedAnyTypeMetaDoesNotConsumeLimit() + { + TypeResolver resolver = new(); + resolver.Register(typeof(CustomPayload), 901); + Config config = ForyRuntime.Builder() + .Compatible(true) + .MaxSchemaVersionsPerType(1) + .Build() + .Config; + ReadContext context = new(new ByteReader(Array.Empty()), resolver, config); + + Assert.Throws(() => + ReadAnyTypeInfo(context, resolver, RemoteCompatibleStructTypeMeta(901, "Id", MapType()))); + TypeMeta accepted = ReadAndCacheTypeMeta(context, RemoteStructTypeMeta(901, "second")); + + Assert.True(context.TryGetCachedReadTypeMeta(EncodedTypeMetaHeader(accepted), out _)); + } + + [Fact] + public void ExactAnyTypeMetaIsFree() + { + TypeResolver resolver = new(); + resolver.Register(typeof(CustomPayload), 901); + Config config = ForyRuntime.Builder() + .Compatible(true) + .MaxSchemaVersionsPerType(1) + .Build() + .Config; + ReadContext context = new(new ByteReader(Array.Empty()), resolver, config); + TypeMeta remote = ReadAndCacheTypeMeta(context, RemoteStructTypeMeta(901, "remote")); + TypeMeta exact = resolver.GetTypeInfo(typeof(CustomPayload)).GetTypeMetaCacheEntry(trackRef: false).TypeMeta; + + TypeInfo typeInfo = ReadAnyTypeInfo(context, resolver, exact); + + Assert.True(context.TryGetCachedReadTypeMeta(EncodedTypeMetaHeader(remote), out _)); } [Fact] @@ -566,13 +649,94 @@ public void TypeMetaHeaderCacheHitSkipsCurrentBodySize() writer.WriteBytes(new byte[0xff]); writer.WriteUInt8(0x7b); - ReadContext context = new(new ByteReader(writer.ToArray()), new TypeResolver(), trackRef: false); + Config config = ForyRuntime.Builder().Compatible(false).Build().Config; + ReadContext context = new(new ByteReader(writer.ToArray()), new TypeResolver(), config); context.CacheReadTypeMeta(header, typeMeta); Assert.Same(typeMeta, context.ReadTypeMeta()); Assert.Equal(0x7b, context.Reader.ReadUInt8()); } + private static TypeMeta RemoteStructTypeMeta(uint userTypeId, string fieldName) + { + return RemoteStructTypeMeta(userTypeId, [fieldName]); + } + + private static TypeMeta RemoteStructTypeMeta(uint userTypeId, params string[] fieldNames) + { + TypeMetaFieldInfo[] fields = new TypeMetaFieldInfo[fieldNames.Length]; + for (int i = 0; i < fieldNames.Length; i++) + { + fields[i] = new TypeMetaFieldInfo(null, fieldNames[i], new TypeMetaFieldType((uint)TypeId.Int32, nullable: false)); + } + + return new TypeMeta( + (uint)TypeId.Struct, + userTypeId, + MetaString.Empty('.', '_'), + MetaString.Empty('$', '_'), + registerByName: false, + fields); + } + + private static TypeMeta RemoteCompatibleStructTypeMeta(uint userTypeId, string fieldName) + { + return RemoteCompatibleStructTypeMeta( + userTypeId, + fieldName, + new TypeMetaFieldType((uint)TypeId.Int32, nullable: false)); + } + + private static TypeMeta RemoteCompatibleStructTypeMeta( + uint userTypeId, + string fieldName, + TypeMetaFieldType fieldType) + { + return new TypeMeta( + (uint)TypeId.CompatibleStruct, + userTypeId, + MetaString.Empty('.', '_'), + MetaString.Empty('$', '_'), + registerByName: false, + [new TypeMetaFieldInfo(null, fieldName, fieldType)]); + } + + private static TypeMetaFieldType MapType() + { + return new TypeMetaFieldType( + (uint)TypeId.Map, + nullable: false, + trackRef: false, + [ + new TypeMetaFieldType((uint)TypeId.String, nullable: false), + new TypeMetaFieldType((uint)TypeId.Int32, nullable: false), + ]); + } + + private static TypeMeta ReadAndCacheTypeMeta(ReadContext context, TypeMeta typeMeta) + { + ByteWriter writer = new(); + writer.WriteVarUInt32(0); + writer.WriteBytes(typeMeta.Encode()); + context.ResetFor(new ByteReader(writer.ToArray())); + return context.ReadTypeMeta(); + } + + private static TypeInfo ReadAnyTypeInfo(ReadContext context, TypeResolver resolver, TypeMeta typeMeta) + { + ByteWriter writer = new(); + writer.WriteUInt8((byte)TypeId.CompatibleStruct); + writer.WriteVarUInt32(0); + writer.WriteBytes(typeMeta.Encode()); + context.ResetFor(new ByteReader(writer.ToArray())); + return resolver.ReadAnyTypeInfo(context); + } + + private static ulong EncodedTypeMetaHeader(TypeMeta typeMeta) + { + return BitConverter.ToUInt64(typeMeta.Encode(), 0); + } + [Fact] public void DynamicAnyRejectsUnknownUserTypeId() { diff --git a/csharp/tests/Fory.Tests/StringSerializerTests.cs b/csharp/tests/Fory.Tests/StringSerializerTests.cs index 1b9ca4bf88..f0e97e45d7 100644 --- a/csharp/tests/Fory.Tests/StringSerializerTests.cs +++ b/csharp/tests/Fory.Tests/StringSerializerTests.cs @@ -74,7 +74,8 @@ public void StringSerializerRejectsOddUtf16Payload() writer.WriteVarUInt36Small((3UL << 2) | Utf16); writer.WriteBytes([0x61, 0x00, 0x62]); - ReadContext context = new(new ByteReader(writer.ToArray()), new TypeResolver(), trackRef: false, compatible: false); + Config config = Fory.Builder().Compatible(false).Build().Config; + ReadContext context = new(new ByteReader(writer.ToArray()), new TypeResolver(), config); Assert.Throws(() => StringSerializer.ReadString(context)); } @@ -84,7 +85,8 @@ public void StringSerializerRejectsUnknownEncoding() ByteWriter writer = new(); writer.WriteVarUInt36Small(3UL); - ReadContext context = new(new ByteReader(writer.ToArray()), new TypeResolver(), trackRef: false, compatible: false); + Config config = Fory.Builder().Compatible(false).Build().Config; + ReadContext context = new(new ByteReader(writer.ToArray()), new TypeResolver(), config); Assert.Throws(() => StringSerializer.ReadString(context)); } @@ -119,7 +121,8 @@ private static (ulong Encoding, string Decoded, int HeaderBytes, int ByteLength) int byteLength = checked((int)(header >> 2)); Assert.Equal(payload.Length - headerReader.Cursor, byteLength); - ReadContext readContext = new(new ByteReader(payload), resolver, trackRef: false, compatible: false); + Config config = Fory.Builder().Compatible(false).Build().Config; + ReadContext readContext = new(new ByteReader(payload), resolver, config); string decoded = StringSerializer.ReadString(readContext); Assert.Equal(0, readContext.Reader.Remaining); return (encoding, decoded, headerReader.Cursor, byteLength); diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index 2ab848e57f..fd155500b1 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -24,6 +24,10 @@ final class Config { /// Default maximum nesting depth for a single serialization or /// deserialization operation. static const int defaultMaxDepth = 256; + static const int defaultMaxTypeFields = 512; + static const int defaultMaxTypeMetaBytes = 4096; + static const int defaultMaxSchemaVersionsPerType = 10; + static const int defaultMaxAverageSchemaVersionsPerType = 3; /// Enables compatible struct encoding and decoding. /// @@ -39,6 +43,19 @@ final class Config { /// Maximum allowed read or write nesting depth. final int maxDepth; + /// Maximum accepted field count in one received struct TypeDef. + final int maxTypeFields; + + /// Maximum accepted body size in one received TypeDef. + final int maxTypeMetaBytes; + + /// Maximum accepted remote struct schema versions for one logical type. + final int maxSchemaVersionsPerType; + + /// Maximum accepted average remote struct schema versions across logical + /// types. + final int maxAverageSchemaVersionsPerType; + /// Creates an immutable configuration object. /// /// Invalid numeric limits fail fast. When [compatible] is `true`, @@ -47,6 +64,21 @@ final class Config { this.compatible = true, bool checkStructVersion = true, this.maxDepth = defaultMaxDepth, + this.maxTypeFields = defaultMaxTypeFields, + this.maxTypeMetaBytes = defaultMaxTypeMetaBytes, + this.maxSchemaVersionsPerType = defaultMaxSchemaVersionsPerType, + this.maxAverageSchemaVersionsPerType = + defaultMaxAverageSchemaVersionsPerType, }) : checkStructVersion = compatible ? false : checkStructVersion, - assert(maxDepth > 0, 'maxDepth must be positive'); + assert(maxDepth > 0, 'maxDepth must be positive'), + assert(maxTypeFields > 0, 'maxTypeFields must be positive'), + assert(maxTypeMetaBytes > 0, 'maxTypeMetaBytes must be positive'), + assert( + maxSchemaVersionsPerType > 0, + 'maxSchemaVersionsPerType must be positive', + ), + assert( + maxAverageSchemaVersionsPerType > 0, + 'maxAverageSchemaVersionsPerType must be positive', + ); } diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index 6daaa072e6..56c22feff5 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -57,11 +57,20 @@ final class Fory { bool compatible = true, bool checkStructVersion = true, int maxDepth = Config.defaultMaxDepth, + int maxTypeFields = Config.defaultMaxTypeFields, + int maxTypeMetaBytes = Config.defaultMaxTypeMetaBytes, + int maxSchemaVersionsPerType = Config.defaultMaxSchemaVersionsPerType, + int maxAverageSchemaVersionsPerType = + Config.defaultMaxAverageSchemaVersionsPerType, }) { final config = Config( compatible: compatible, checkStructVersion: checkStructVersion, maxDepth: maxDepth, + maxTypeFields: maxTypeFields, + maxTypeMetaBytes: maxTypeMetaBytes, + maxSchemaVersionsPerType: maxSchemaVersionsPerType, + maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType, ); _readBuffer = Buffer(); _writeBuffer = Buffer(); diff --git a/dart/packages/fory/lib/src/meta/type_meta.dart b/dart/packages/fory/lib/src/meta/type_meta.dart index 7d77e17252..71e75b2a7f 100644 --- a/dart/packages/fory/lib/src/meta/type_meta.dart +++ b/dart/packages/fory/lib/src/meta/type_meta.dart @@ -102,8 +102,6 @@ final class TypeHeader { } final class ParsedTypeMetaCache { - static const int maxEntries = 8192; - final LinkedHashMap _entries = LinkedHashMap(); Int64? _lastHeader; @@ -131,10 +129,6 @@ final class ParsedTypeMetaCache { _lastResolved = resolved; return; } - if (_entries.length >= maxEntries) { - return; - } - _entries[header.value] = resolved; _lastHeader = header.value; _lastResolved = resolved; @@ -147,7 +141,8 @@ final class WireTypeMetaEncoder { WireTypeMeta typeMetaFor(Config config, TypeInfo resolvedType) { final wireTypeId = wireTypeIdFor(config, resolvedType); - final writesTypeDef = wireTypeId == TypeIds.compatibleStruct || + final writesTypeDef = + wireTypeId == TypeIds.compatibleStruct || wireTypeId == TypeIds.namedCompatibleStruct || (config.compatible && (wireTypeId == TypeIds.namedEnum || @@ -222,15 +217,14 @@ final class WireTypeMetaDecoder { int wireTypeId, EncodedMetaString namespace, EncodedMetaString typeName, - ) resolveUserByEncodedNameCached, + ) + resolveUserByEncodedNameCached, required TypeInfo? Function(int wireTypeId) expectedNamedType, required WireTypeMeta Function() readTypeDef, - required EncodedMetaString Function([ - EncodedMetaString? expected, - ]) readPackageMetaString, - required EncodedMetaString Function([ - EncodedMetaString? expected, - ]) readTypeNameMetaString, + required EncodedMetaString Function([EncodedMetaString? expected]) + readPackageMetaString, + required EncodedMetaString Function([EncodedMetaString? expected]) + readTypeNameMetaString, }) { final wireTypeId = buffer.readVarUint32Small7(); if (_isBuiltinWireType(wireTypeId)) { diff --git a/dart/packages/fory/lib/src/resolver/type_resolver.dart b/dart/packages/fory/lib/src/resolver/type_resolver.dart index 3ad6486316..3568f7123a 100644 --- a/dart/packages/fory/lib/src/resolver/type_resolver.dart +++ b/dart/packages/fory/lib/src/resolver/type_resolver.dart @@ -242,6 +242,8 @@ List _validateFieldInfos(Iterable fields) { } final class TypeResolver { + static const int _minRemoteStructSchemaLimit = 8192; + final Config config; final WireTypeMetaEncoder _wireTypeMetaEncoder = const WireTypeMetaEncoder(); final WireTypeMetaDecoder _wireTypeMetaDecoder = const WireTypeMetaDecoder(); @@ -273,6 +275,8 @@ final class TypeResolver { _internedEncodedMetaStrings = <_EncodedMetaStringKey, EncodedMetaString>{}; final Map _initialTypeMetaBytes = LinkedHashMap.identity(); + final Map _remoteSchemaVersionsByType = {}; + int _totalAcceptedSchemaVersions = 0; TypeResolver(this.config); @@ -791,22 +795,16 @@ final class TypeResolver { return sharedTypes[marker >>> 1]; } final header = TypeHeader(buffer.readInt64()); - final expectedTypeDef = expected.typeDef; - if (expectedTypeDef == null || expectedTypeDef.header != header.value) { - final cached = _parsedTypeMetaCache.lookup(header); - if (cached != null) { - header.skipRemaining(buffer); - sharedTypes.add(cached); - return cached; - } - final resolved = _readTypeDefWithHeader(buffer, header); - _parsedTypeMetaCache.remember(header, resolved); - sharedTypes.add(resolved); - return resolved; + final cached = _parsedTypeMetaCache.lookup(header); + if (cached != null) { + header.skipRemaining(buffer); + sharedTypes.add(cached); + return cached; } - header.skipRemaining(buffer); - sharedTypes.add(expected); - return expected; + final typeDefStart = bufferReaderIndex(buffer) - 8; + final resolved = _readTypeDefWithHeader(buffer, header, typeDefStart); + sharedTypes.add(resolved); + return resolved; } Uint8List _encodeInitialTypeDefMeta(int wireTypeId, TypeDef typeDef) { @@ -1119,14 +1117,6 @@ final class TypeResolver { return wireTypeMetaForResolved(sharedTypes[index]); } final header = TypeHeader(buffer.readInt64()); - final expectedTypeDef = expectedType?.typeDef; - if (expectedTypeDef != null && expectedTypeDef.header == header.value) { - // Header-cache hits intentionally skip without rehashing. Entries reach this cache only - // after a successful TypeDef parse and 52-bit metadata-hash validation. - header.skipRemaining(buffer); - sharedTypes.add(expectedType!); - return wireTypeMetaForResolved(expectedType); - } final cached = _parsedTypeMetaCache.lookup(header); if (cached != null) { // Header-cache hits intentionally skip without rehashing. Entries reach this cache only @@ -1135,17 +1125,30 @@ final class TypeResolver { sharedTypes.add(cached); return wireTypeMetaForResolved(cached); } - final resolved = _readTypeDefWithHeader(buffer, header); - _parsedTypeMetaCache.remember(header, resolved); + final typeDefStart = bufferReaderIndex(buffer) - 8; + final resolved = _readTypeDefWithHeader(buffer, header, typeDefStart); sharedTypes.add(resolved); return wireTypeMetaForResolved(resolved); } - TypeInfo _readTypeDefWithHeader(Buffer buffer, TypeHeader header) { + @pragma('vm:never-inline') + TypeInfo _readTypeDefWithHeader( + Buffer buffer, + TypeHeader header, + int typeDefStart, + ) { header.validateGlobal(); final metaSize = header.readMetaSize(buffer); + if (metaSize > config.maxTypeMetaBytes) { + throw StateError( + 'Type metadata body size $metaSize exceeds maxTypeMetaBytes ' + '${config.maxTypeMetaBytes}. The data may be malicious. If the data ' + 'is not malicious, please increase maxTypeMetaBytes.', + ); + } buffer.checkReadableBytes(metaSize); final metaBody = buffer.readBytes(metaSize); + final typeDefEnd = bufferReaderIndex(buffer); final metaBytes = Buffer.wrap(metaBody); final classHeader = metaBytes.readUint8(); final isStruct = (classHeader & typeDefStructFlag) != 0; @@ -1165,6 +1168,13 @@ final class TypeResolver { if (fieldCount == typeDefSmallFieldCountThreshold) { fieldCount += metaBytes.readVarUint32Small7(); } + if (fieldCount > config.maxTypeFields) { + throw StateError( + 'Type metadata field count $fieldCount exceeds maxTypeFields ' + '${config.maxTypeFields}. The data may be malicious. If the data ' + 'is not malicious, please increase maxTypeFields.', + ); + } } else { if ((classHeader & 0x70) != 0) { throw StateError('Invalid TypeDef kind header.'); @@ -1180,6 +1190,11 @@ final class TypeResolver { if (!byName) { userTypeId = metaBytes.readVarUint32(); } + if (fieldCount > metaBytes.readableBytes) { + throw StateError( + 'Type metadata field count exceeds available body bytes.', + ); + } final fields = []; for (var i = 0; i < fieldCount; i += 1) { fields.add(_readTypeDefField(metaBytes)); @@ -1205,19 +1220,32 @@ final class TypeResolver { throw StateError('TypeDef kind does not match registered type metadata.'); } if (resolved.kind != RegistrationKind.struct) { + _parsedTypeMetaCache.remember(header, resolved); + return resolved; + } + final localTypeDef = resolved.typeDef; + if (localTypeDef != null && + _matchesEncodedTypeDef( + buffer, + typeDefStart, + typeDefEnd, + localTypeDef.encoded, + )) { + _parsedTypeMetaCache.remember(header, resolved); return resolved; } + final remoteSchemaKey = _checkRemoteStructSchemaLimit( + typeId: typeId, + userTypeId: userTypeId, + resolved: resolved, + ); final remoteTypeDef = TypeDef( evolving: true, fields: List.unmodifiable(fields), header: header.value, encoded: Uint8List(0), ); - final localTypeDef = resolved.typeDef; - if (localTypeDef != null && _sameTypeDef(localTypeDef, remoteTypeDef)) { - return resolved; - } - return TypeInfo( + final remoteResolved = TypeInfo( type: resolved.type, kind: resolved.kind, typeId: resolved.typeId, @@ -1234,6 +1262,76 @@ final class TypeResolver { typeDef: resolved.typeDef, remoteTypeDef: remoteTypeDef, ); + remoteResolved.structSerializer?.validateCompatibleTypeInfo(remoteResolved); + _parsedTypeMetaCache.remember(header, remoteResolved); + _recordRemoteStructSchema(remoteSchemaKey); + return remoteResolved; + } + + bool _matchesEncodedTypeDef( + Buffer buffer, + int typeDefStart, + int typeDefEnd, + Uint8List encoded, + ) { + if (typeDefEnd - typeDefStart != encoded.length) { + return false; + } + final bytes = buffer.toBytes(); + for (var i = 0; i < encoded.length; i += 1) { + if (bytes[typeDefStart + i] != encoded[i]) { + return false; + } + } + return true; + } + + @pragma('vm:never-inline') + String? _checkRemoteStructSchemaLimit({ + required int typeId, + required int? userTypeId, + required TypeInfo resolved, + }) { + if (!_isStructTypeDefKind(typeId)) { + return null; + } + final key = + userTypeId != null + ? 'i$userTypeId' + : 'n${resolved.namespace ?? ''}\u0000${resolved.typeName ?? ''}'; + final versionsForType = _remoteSchemaVersionsByType[key] ?? 0; + if (versionsForType >= config.maxSchemaVersionsPerType) { + throw StateError( + 'Remote struct schema versions for one type exceeded ' + 'maxSchemaVersionsPerType=${config.maxSchemaVersionsPerType}.', + ); + } + final acceptedStructTypeCount = + versionsForType == 0 + ? _remoteSchemaVersionsByType.length + 1 + : _remoteSchemaVersionsByType.length; + final averageLimit = + acceptedStructTypeCount * config.maxAverageSchemaVersionsPerType; + final globalLimit = + averageLimit > _minRemoteStructSchemaLimit + ? averageLimit + : _minRemoteStructSchemaLimit; + if (_totalAcceptedSchemaVersions >= globalLimit) { + throw StateError( + 'Remote struct schema versions exceeded global limit from ' + 'maxAverageSchemaVersionsPerType=${config.maxAverageSchemaVersionsPerType}.', + ); + } + return key; + } + + void _recordRemoteStructSchema(String? key) { + if (key == null) { + return; + } + final versionsForType = _remoteSchemaVersionsByType[key] ?? 0; + _remoteSchemaVersionsByType[key] = versionsForType + 1; + _totalAcceptedSchemaVersions += 1; } EncodedMetaString _readTypeDefName( @@ -1708,39 +1806,6 @@ final class TypeResolver { return Object.hash(wireTypeId, namespace.hash, typeName.hash); } - bool _sameTypeDef(TypeDef left, TypeDef right) { - if (left.evolving != right.evolving || - left.fields.length != right.fields.length) { - return false; - } - for (var index = 0; index < left.fields.length; index += 1) { - final leftField = left.fields[index]; - final rightField = right.fields[index]; - if (leftField.identifier != rightField.identifier || - leftField.id != rightField.id || - !_sameFieldType(leftField.fieldType, rightField.fieldType)) { - return false; - } - } - return true; - } - - bool _sameFieldType(FieldType left, FieldType right) { - if (left.typeId != right.typeId || - left.nullable != right.nullable || - left.ref != right.ref || - left.dynamic != right.dynamic || - left.arguments.length != right.arguments.length) { - return false; - } - for (var index = 0; index < left.arguments.length; index += 1) { - if (!_sameFieldType(left.arguments[index], right.arguments[index])) { - return false; - } - } - return true; - } - Type _builtinTypeForFieldType(FieldType fieldType) { switch (fieldType.typeId) { case TypeIds.int64: diff --git a/dart/packages/fory/lib/src/serializer/struct_serializer.dart b/dart/packages/fory/lib/src/serializer/struct_serializer.dart index beb0d2c1f9..5d90fbe298 100644 --- a/dart/packages/fory/lib/src/serializer/struct_serializer.dart +++ b/dart/packages/fory/lib/src/serializer/struct_serializer.dart @@ -138,6 +138,13 @@ final class StructSerializer extends Serializer { return value; } + @pragma('vm:never-inline') + void validateCompatibleTypeInfo(TypeInfo resolved) { + if (resolved.remoteTypeDef != null) { + _compatibleReadLayoutForResolved(resolved); + } + } + @pragma('vm:prefer-inline') CompatibleStructReadLayout _compatibleReadLayoutForResolved( TypeInfo resolved, diff --git a/dart/packages/fory/test/xlang_protocol_test.dart b/dart/packages/fory/test/xlang_protocol_test.dart index 99a2d63a08..f8ff823c49 100644 --- a/dart/packages/fory/test/xlang_protocol_test.dart +++ b/dart/packages/fory/test/xlang_protocol_test.dart @@ -17,9 +17,15 @@ * under the License. */ +import 'dart:collection'; import 'dart:typed_data'; import 'package:fory/fory.dart'; +import 'package:fory/src/codegen/generated_registry.dart'; +import 'package:fory/src/codegen/generated_support.dart'; +import 'package:fory/src/context/meta_string_reader.dart'; +import 'package:fory/src/context/meta_string_writer.dart'; +import 'package:fory/src/meta/type_def.dart'; import 'package:fory/src/meta/type_meta.dart'; import 'package:fory/src/resolver/type_resolver.dart'; import 'package:fory/src/util/hash_util.dart'; @@ -38,6 +44,110 @@ final class _CacheTestSerializer extends Serializer { void write(WriteContext context, Object? value) {} } +final class _SchemaLocal {} + +final class _SchemaRemoteA {} + +final class _SchemaRemoteB {} + +final class _SchemaRemoteC {} + +const _intFieldType = GeneratedFieldType( + type: int, + typeId: TypeIds.int32, + nullable: false, + ref: false, + dynamic: false, + arguments: [], +); + +const _mapFieldType = GeneratedFieldType( + type: Map, + typeId: TypeIds.map, + nullable: false, + ref: false, + dynamic: false, + arguments: [ + GeneratedFieldType( + type: String, + typeId: TypeIds.string, + nullable: false, + ref: false, + dynamic: false, + arguments: [], + ), + GeneratedFieldType( + type: int, + typeId: TypeIds.int32, + nullable: false, + ref: false, + dynamic: false, + arguments: [], + ), + ], +); + +GeneratedFieldInfo _generatedField(String name) => GeneratedFieldInfo( + name: name, + identifier: name, + id: null, + fieldType: _intFieldType, +); + +GeneratedFieldInfo _generatedMapField(String name) => GeneratedFieldInfo( + name: name, + identifier: name, + id: null, + fieldType: _mapFieldType, +); + +void _rememberSchema(Type type, List fields) { + GeneratedTypeCatalog.remember( + type, + GeneratedTypeEntry( + kind: GeneratedTypeKind.struct, + serializerFactory: () => const _CacheTestSerializer(), + evolving: true, + needsRootRef: false, + usesNestedTypeDefinitions: false, + fields: fields.map((field) => field.toFieldInfo()).toList(), + ), + ); +} + +Uint8List _typeMetaBytes( + Type type, + String name, + List fields, +) { + final resolver = TypeResolver(const Config()); + _rememberSchema(type, fields); + final parts = name.split('.'); + resolver.registerGenerated( + type, + namespace: parts.first, + typeName: parts.last, + ); + final resolved = resolver.resolveUserByName(parts.first, parts.last); + final buffer = Buffer(); + resolver.writeTypeMeta( + buffer, + resolved, + typeDef: resolved.typeDef, + typeDefIds: LinkedHashMap.identity(), + metaStringWriter: MetaStringWriter(), + ); + return buffer.toBytes(); +} + +void _readTypeMeta(TypeResolver resolver, Uint8List bytes) { + resolver.readTypeMeta( + Buffer.wrap(bytes), + sharedTypes: [], + metaStringReader: MetaStringReader(resolver), + ); +} + void main() { group('xlang protocol regressions', () { test('deserializes NONE wire values as null', () { @@ -51,17 +161,18 @@ void main() { test('deserializes FLOAT16_ARRAY wire values', () { final fory = Fory(); final bytes = Uint8List.fromList( - fory.serialize( - Uint16List.fromList([0x3c00, 0xc000, 0x7e00]), - ), + fory.serialize(Uint16List.fromList([0x3c00, 0xc000, 0x7e00])), ); bytes[2] = TypeIds.float16Array; final values = fory.deserialize(bytes); expect( - Uint16List.view(values.buffer, values.offsetInBytes, values.length) - .toList(), + Uint16List.view( + values.buffer, + values.offsetInBytes, + values.length, + ).toList(), orderedEquals([0x3c00, 0xc000, 0x7e00]), ); }); @@ -121,7 +232,7 @@ void main() { ); }); - test('parsed TypeDef cache stops publishing at capacity', () { + test('parsed TypeDef cache publishes beyond old implementation floor', () { const resolved = TypeInfo( type: Object, kind: RegistrationKind.builtin, @@ -140,18 +251,116 @@ void main() { remoteTypeDef: null, ); final cache = ParsedTypeMetaCache(); - for (var i = 0; i < ParsedTypeMetaCache.maxEntries; i++) { + const oldImplementationFloor = 8192; + for (var i = 0; i < oldImplementationFloor; i++) { cache.remember(TypeHeader(Int64(i)), resolved); } expect( - cache.lookup(TypeHeader(Int64(ParsedTypeMetaCache.maxEntries - 1))), + cache.lookup(TypeHeader(Int64(oldImplementationFloor - 1))), same(resolved), ); - final uncached = TypeHeader(Int64(ParsedTypeMetaCache.maxEntries)); - cache.remember(uncached, resolved); + final aboveOldFloor = TypeHeader(Int64(oldImplementationFloor)); + cache.remember(aboveOldFloor, resolved); + + expect(cache.lookup(aboveOldFloor), same(resolved)); + }); + + test('remote schema limit rejects extra versions', () { + const name = 'example.Unknown'; + final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + _rememberSchema(_SchemaLocal, []); + reader.registerGenerated( + _SchemaLocal, + namespace: 'example', + typeName: 'Unknown', + ); + final first = _typeMetaBytes(_SchemaRemoteA, name, [ + _generatedField('firstValue'), + ]); + final second = _typeMetaBytes(_SchemaRemoteB, name, [ + _generatedField('secondValue'), + ]); + + _readTypeMeta(reader, first); + + expect(() => _readTypeMeta(reader, second), throwsA(isA())); + }); + + test('type meta field limit rejects large struct', () { + final reader = TypeResolver(const Config(maxTypeFields: 1)); + final bytes = _typeMetaBytes( + _SchemaRemoteA, + 'example.TooManyFields', + [ + _generatedField('firstValue'), + _generatedField('secondValue'), + ], + ); + + expect(() => _readTypeMeta(reader, bytes), throwsA(isA())); + }); + + test('type meta body limit rejects large metadata', () { + final reader = TypeResolver(const Config(maxTypeMetaBytes: 1)); + final bytes = _typeMetaBytes( + _SchemaRemoteA, + 'example.LargeTypeMeta', + [_generatedField('value')], + ); + + expect(() => _readTypeMeta(reader, bytes), throwsA(isA())); + }); + + test('remote schema limit keeps unknown types separate', () { + final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + _rememberSchema(_SchemaLocal, []); + reader.registerGenerated( + _SchemaLocal, + namespace: 'example', + typeName: 'UnknownA', + ); + _rememberSchema(_SchemaRemoteC, []); + reader.registerGenerated( + _SchemaRemoteC, + namespace: 'example', + typeName: 'UnknownB', + ); + final first = _typeMetaBytes( + _SchemaRemoteA, + 'example.UnknownA', + [_generatedField('firstValue')], + ); + final second = _typeMetaBytes( + _SchemaRemoteB, + 'example.UnknownB', + [_generatedField('secondValue')], + ); + + _readTypeMeta(reader, first); + _readTypeMeta(reader, second); + }); + + test('failed remote schema does not consume schema limit', () { + const name = 'example.Accepted'; + final reader = TypeResolver(const Config(maxSchemaVersionsPerType: 1)); + _rememberSchema(_SchemaLocal, [ + _generatedField('value'), + ]); + reader.registerGenerated( + _SchemaLocal, + namespace: 'example', + typeName: 'Accepted', + ); + final invalid = _typeMetaBytes(_SchemaRemoteA, name, [ + _generatedMapField('value'), + ]); + final valid = _typeMetaBytes(_SchemaRemoteB, name, [ + _generatedField('extraValue'), + ]); - expect(cache.lookup(uncached), isNull); + expect(() => _readTypeMeta(reader, invalid), throwsA(isA())); + _readTypeMeta(reader, valid); }); test('validates parsed TypeDef body hash before caching', () { diff --git a/docs/guide/cpp/configuration.md b/docs/guide/cpp/configuration.md index 478e82f8fd..766bd433e7 100644 --- a/docs/guide/cpp/configuration.md +++ b/docs/guide/cpp/configuration.md @@ -115,6 +115,57 @@ This limits the maximum depth for nested polymorphic object serialization (e.g., - **Increase**: For legitimate deeply nested data structures - **Decrease**: For stricter security requirements or shallow data structures +### max_schema_versions_per_type(uint32_t) + +Set the maximum accepted remote struct schema versions for one logical type on +metadata cache misses. + +```cpp +auto fory = Fory::builder() + .max_schema_versions_per_type(10) + .build(); +``` + +**Default:** `10` + +### max_type_fields(uint32_t) + +Set the maximum fields accepted in one received remote struct metadata body. + +```cpp +auto fory = Fory::builder() + .max_type_fields(512) + .build(); +``` + +**Default:** `512` + +### max_type_meta_bytes(uint32_t) + +Set the maximum encoded body bytes accepted for one received TypeDef body, +excluding the 8-byte header and any extended-size varint. + +```cpp +auto fory = Fory::builder() + .max_type_meta_bytes(4096) + .build(); +``` + +**Default:** `4096` + +### max_average_schema_versions_per_type(uint32_t) + +Set the average accepted remote struct schema versions across accepted remote +struct types. The effective global floor is `8192` schemas. + +```cpp +auto fory = Fory::builder() + .max_average_schema_versions_per_type(3) + .build(); +``` + +**Default:** `3` + ### check_struct_version(bool) Enable/disable struct version checking. @@ -150,13 +201,17 @@ auto fory = Fory::builder().build_thread_safe(); // Returns ThreadSafeFory ## Configuration Summary -| Option | Description | Default | -| ---------------------------- | --------------------------------------- | ------- | -| `xlang(bool)` | Use xlang mode | `true` | -| `compatible(bool)` | Enable schema evolution | `true` | -| `track_ref(bool)` | Enable reference tracking | `true` | -| `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | -| `check_struct_version(bool)` | Enable struct version checking | `false` | +| Option | Description | Default | +| ------------------------------------------------ | -------------------------------------------------- | ------- | +| `xlang(bool)` | Use xlang mode | `true` | +| `compatible(bool)` | Enable schema evolution | `true` | +| `track_ref(bool)` | Enable reference tracking | `true` | +| `max_dyn_depth(uint32_t)` | Maximum nesting depth for dynamic types | `5` | +| `max_type_fields(uint32_t)` | Max fields in one received struct metadata body | `512` | +| `max_type_meta_bytes(uint32_t)` | Max encoded bytes in one received metadata body | `4096` | +| `max_schema_versions_per_type(uint32_t)` | Max remote schema versions for one struct type | `10` | +| `max_average_schema_versions_per_type(uint32_t)` | Average remote schema versions across struct types | `3` | +| `check_struct_version(bool)` | Enable struct version checking | `false` | ## Security @@ -166,6 +221,8 @@ Security-related configuration: - Use `check_struct_version(true)` with `compatible(false)` for intentional same-schema payloads. - Keep `max_dyn_depth(...)` as low as your model permits to reject unexpectedly deep polymorphic graphs. +- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a + trusted peer sends larger metadata or many schema versions. - Prefer concrete fields over broad polymorphic fields for untrusted input. ## Related Topics diff --git a/docs/guide/csharp/configuration.md b/docs/guide/csharp/configuration.md index a158a6d5e3..4bf450cf23 100644 --- a/docs/guide/csharp/configuration.md +++ b/docs/guide/csharp/configuration.md @@ -35,12 +35,16 @@ ThreadSafeFory threadSafe = Fory.Builder().BuildThreadSafe(); `Fory.Builder().Build()` uses: -| Option | Default | Description | -| -------------------- | ------- | -------------------------------------------- | -| `TrackRef` | `false` | Reference tracking disabled | -| `Compatible` | `true` | Compatible schema-evolution metadata enabled | -| `CheckStructVersion` | `false` | Struct schema hash checks disabled | -| `MaxDepth` | `20` | Max dynamic nesting depth | +| Option | Default | Description | +| --------------------------------- | ------- | -------------------------------------------------- | +| `TrackRef` | `false` | Reference tracking disabled | +| `Compatible` | `true` | Compatible schema-evolution metadata enabled | +| `CheckStructVersion` | `false` | Struct schema hash checks disabled | +| `MaxDepth` | `20` | Max dynamic nesting depth | +| `MaxTypeFields` | `512` | Max fields in one received struct metadata body | +| `MaxTypeMetaBytes` | `4096` | Max encoded bytes in one received metadata body | +| `MaxSchemaVersionsPerType` | `10` | Max remote schema versions for one struct type | +| `MaxAverageSchemaVersionsPerType` | `3` | Average remote schema versions across struct types | ## Builder Options @@ -92,6 +96,49 @@ Fory fory = Fory.Builder() `value` must be greater than `0`. +### `MaxTypeFields(int value)` + +Sets the maximum fields accepted in one received remote struct metadata body. + +```csharp +Fory fory = Fory.Builder() + .MaxTypeFields(512) + .Build(); +``` + +### `MaxTypeMetaBytes(int value)` + +Sets the maximum encoded body bytes accepted for one received TypeMeta body, +excluding the 8-byte header and any extended-size varint. + +```csharp +Fory fory = Fory.Builder() + .MaxTypeMetaBytes(4096) + .Build(); +``` + +### `MaxSchemaVersionsPerType(int value)` + +Sets the maximum accepted remote struct schema versions for one logical type on +metadata cache misses. + +```csharp +Fory fory = Fory.Builder() + .MaxSchemaVersionsPerType(10) + .Build(); +``` + +### `MaxAverageSchemaVersionsPerType(int value)` + +Sets the average accepted remote struct schema versions across accepted remote +struct types. The effective global floor is `8192` schemas. + +```csharp +Fory fory = Fory.Builder() + .MaxAverageSchemaVersionsPerType(3) + .Build(); +``` + ## Common Configurations ### Compatible service @@ -127,6 +174,8 @@ Security-related configuration: - Register only the expected types before deserializing untrusted payloads. - Use `CheckStructVersion(true)` with `Compatible(false)` for intentional same-schema payloads. - Set `MaxDepth(...)` to reject unexpectedly deep dynamic object graphs. +- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a + trusted peer sends larger metadata or many schema versions. - Prefer generated or registered concrete models over broad dynamic fields for untrusted input. ## Related Topics diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index f84ff49c2a..f2586556a0 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -34,6 +34,10 @@ final fory = Fory(); // customize limits while keeping default compatible mode final fory = Fory( maxDepth: 512, + maxTypeFields: 512, + maxTypeMetaBytes: 4096, + maxSchemaVersionsPerType: 10, + maxAverageSchemaVersionsPerType: 3, ); ``` @@ -82,13 +86,38 @@ Limits how deeply nested an object graph can be. Increase this if you have legit final fory = Fory(maxDepth: 128); ``` +### Remote schema metadata limits + +Compatible mode accepts remote struct metadata on metadata cache misses. These +limits bound retained schema-specific read state: + +```dart +final fory = Fory( + maxTypeFields: 512, + maxTypeMetaBytes: 4096, + maxSchemaVersionsPerType: 10, + maxAverageSchemaVersionsPerType: 3, +); +``` + +- `maxTypeFields` limits fields in one received struct metadata body. +- `maxTypeMetaBytes` limits encoded body bytes in one received TypeMeta body, excluding the 8-byte + header and any extended-size varint. +- `maxSchemaVersionsPerType` limits accepted schema versions for one logical struct type. +- `maxAverageSchemaVersionsPerType` limits the average across accepted remote struct types. The + effective global floor is `8192` schemas. + ## Defaults -| Option | Default | -| -------------------- | ------- | -| `compatible` | `true` | -| `checkStructVersion` | `false` | -| `maxDepth` | 256 | +| Option | Default | +| --------------------------------- | ------- | +| `compatible` | `true` | +| `checkStructVersion` | `false` | +| `maxDepth` | 256 | +| `maxTypeFields` | 512 | +| `maxTypeMetaBytes` | 4096 | +| `maxSchemaVersionsPerType` | 10 | +| `maxAverageSchemaVersionsPerType` | 3 | ## Xlang Notes @@ -105,6 +134,8 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` to reject unexpectedly deep payload shapes. +- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a + trusted peer sends larger metadata or many schema versions. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. ## Related Topics diff --git a/docs/guide/go/configuration.md b/docs/guide/go/configuration.md index eb1aae5846..5f62ac1ed2 100644 --- a/docs/guide/go/configuration.md +++ b/docs/guide/go/configuration.md @@ -33,12 +33,16 @@ f := fory.New(fory.WithXlang(true)) Default settings: -| Option | Default | Description | -| ---------- | ------- | -------------------------------------------- | -| TrackRef | false | Reference tracking disabled | -| MaxDepth | 20 | Maximum nesting depth | -| IsXlang | true | Xlang mode enabled | -| Compatible | true | Compatible schema-evolution metadata enabled | +| Option | Default | Description | +| ------------------------------- | ------- | -------------------------------------------------- | +| TrackRef | false | Reference tracking disabled | +| MaxDepth | 20 | Maximum nesting depth | +| IsXlang | true | Xlang mode enabled | +| Compatible | true | Compatible schema-evolution metadata enabled | +| MaxTypeFields | 512 | Max fields in one received struct metadata body | +| MaxTypeMetaBytes | 4096 | Max encoded bytes in one received metadata body | +| MaxSchemaVersionsPerType | 10 | Max remote schema versions for one struct type | +| MaxAverageSchemaVersionsPerType | 3 | Average remote schema versions across struct types | ### With Options @@ -47,6 +51,10 @@ f := fory.New( fory.WithXlang(true), fory.WithTrackRef(true), fory.WithMaxDepth(10), + fory.WithMaxTypeFields(512), + fory.WithMaxTypeMetaBytes(4096), + fory.WithMaxSchemaVersionsPerType(10), + fory.WithMaxAverageSchemaVersionsPerType(3), ) ``` @@ -119,6 +127,41 @@ f := fory.New(fory.WithMaxDepth(30)) - Protects against deeply nested, recursive structures or malicious data - Serialization fails with error when exceeded +### WithMaxTypeFields + +Set the maximum fields accepted in one received remote struct metadata body: + +```go +f := fory.New(fory.WithMaxTypeFields(512)) +``` + +### WithMaxTypeMetaBytes + +Set the maximum encoded body bytes accepted for one received TypeDef body, +excluding the 8-byte header and any extended-size varint: + +```go +f := fory.New(fory.WithMaxTypeMetaBytes(4096)) +``` + +### WithMaxSchemaVersionsPerType + +Set the maximum accepted remote struct schema versions for one logical type on +metadata cache misses: + +```go +f := fory.New(fory.WithMaxSchemaVersionsPerType(10)) +``` + +### WithMaxAverageSchemaVersionsPerType + +Set the average accepted remote struct schema versions across accepted remote +struct types. The effective global floor is `8192` schemas: + +```go +f := fory.New(fory.WithMaxAverageSchemaVersionsPerType(3)) +``` + ### WithXlang Select the wire mode: @@ -346,6 +389,8 @@ Security-related configuration: - Register only the expected structs before deserializing untrusted data. - Use `WithMaxDepth(...)` to reject unexpectedly deep payloads. +- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a + trusted peer sends larger metadata or many schema versions. - Prefer concrete struct fields over broad `any` or interface-typed fields for untrusted input. ## Related Topics diff --git a/docs/guide/java/configuration.md b/docs/guide/java/configuration.md index a42da9f949..02fa66fcef 100644 --- a/docs/guide/java/configuration.md +++ b/docs/guide/java/configuration.md @@ -38,6 +38,10 @@ This page documents all configuration options available through `ForyBuilder`. | `registerGuavaTypes` | Whether to pre-register Guava types such as `RegularImmutableMap`/`RegularImmutableList`. These types are not public API, but seem pretty stable. | `true` | | `requireClassRegistration` | Disabling may allow unknown classes to be deserialized, potentially causing security risks. | `true` | | `maxDepth` | Set max depth for deserialization, when depth exceeds, an exception will be thrown. This can be used to refuse deserialization DDOS attack. | `50` | +| `maxTypeFields` | Maximum fields accepted in one received remote struct metadata body. | `512` | +| `maxTypeMetaBytes` | Maximum encoded body bytes accepted for one received TypeDef or TypeMeta body, excluding the 8-byte header and any extended-size varint. | `4096` | +| `maxSchemaVersionsPerType` | Maximum accepted remote struct schema versions for one logical type before Fory rejects additional remote metadata on cache misses. | `10` | +| `maxAverageSchemaVersionsPerType` | Average accepted remote struct schema versions across all accepted remote struct types. The effective global floor is `8192` schemas. | `3` | | `suppressClassRegistrationWarnings` | Whether to suppress class registration warnings. The warnings can be used for security audit, but may be annoying, this suppression will be enabled by default. | `true` | | `metaShareEnabled` | Enables or disables meta share mode. | `true` if compatible mode is enabled, otherwise false. | | `scopedMetaShareEnabled` | Scoped meta share focuses on a single serialization process. Metadata created or identified during this process is exclusive to it and is not shared with by other serializations. | `true` if compatible mode is enabled, otherwise false. | @@ -93,6 +97,13 @@ Security-related options: - `requireClassRegistration(true)` restricts deserialization to registered classes. - `withMaxDepth(...)` rejects unexpectedly deep object graphs. +- `withMaxTypeFields(...)` and `withMaxTypeMetaBytes(...)` bound the field count + and encoded body size of one received remote metadata body before Fory builds + read state from it. +- `withMaxSchemaVersionsPerType(...)` and + `withMaxAverageSchemaVersionsPerType(...)` bound remote struct metadata cache + misses without changing registration, dynamic loading, or schema-evolution + semantics. - `withDeserializeUnknownClass(false)` avoids materializing unknown classes from metadata. - `checkJdkClassSerializable(true)` keeps the JDK serializability check for `java.*` classes. - Class registration warnings can be useful during security audits; use diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index e4a0fb3bd9..09f5770ec6 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -43,18 +43,26 @@ const fory = new Fory({ ref: true, compatible: true, maxDepth: 100, + maxTypeFields: 512, + maxTypeMetaBytes: 4096, + maxSchemaVersionsPerType: 10, + maxAverageSchemaVersionsPerType: 3, hps, }); ``` -| Option | Default | Description | -| -------------------------- | ------- | ------------------------------------------------------------------------------------- | -| `ref` | `false` | Enable reference tracking for shared or circular object graphs | -| `compatible` | `true` | Allow field additions/removals without breaking existing messages | -| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | -| `useSliceString` | `false` | Optional string-reading optimization for Node.js. Leave at default unless benchmarked | -| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | -| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | +| Option | Default | Description | +| --------------------------------- | ------- | ------------------------------------------------------------------------------------- | +| `ref` | `false` | Enable reference tracking for shared or circular object graphs | +| `compatible` | `true` | Allow field additions/removals without breaking existing messages | +| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | +| `maxTypeFields` | `512` | Maximum fields accepted in one received remote struct metadata body | +| `maxTypeMetaBytes` | `4096` | Maximum encoded body bytes accepted for one received TypeMeta body | +| `maxSchemaVersionsPerType` | `10` | Maximum accepted remote struct schema versions for one logical type | +| `maxAverageSchemaVersionsPerType` | `3` | Average accepted remote struct schema versions across accepted remote struct types | +| `useSliceString` | `false` | Optional string-reading optimization for Node.js. Leave at default unless benchmarked | +| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | +| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | ## Reference Tracking @@ -102,6 +110,11 @@ Security-related configuration: - Register only the expected schemas before deserializing untrusted payloads. - Set `maxDepth` for the maximum nesting depth your service accepts. +- Keep `maxTypeFields` and `maxTypeMetaBytes` at their defaults unless the data + is not malicious and a trusted peer sends larger remote metadata. +- Keep `maxSchemaVersionsPerType` and + `maxAverageSchemaVersionsPerType` at their defaults unless the data is not + malicious and a trusted peer sends many remote schema versions. - Prefer explicit `Type.struct(...)` schemas over `Type.any()` for untrusted input. - Pass `hps` only from the official package version you deploy with Fory. diff --git a/docs/guide/kotlin/configuration.md b/docs/guide/kotlin/configuration.md index 8c40d6ea6e..061f14a7f8 100644 --- a/docs/guide/kotlin/configuration.md +++ b/docs/guide/kotlin/configuration.md @@ -135,6 +135,8 @@ and any untrusted payload source: val fory = ForyKotlin.builder() .requireClassRegistration(true) .withMaxDepth(50) + .withMaxTypeFields(512) + .withMaxTypeMetaBytes(4096) .build() ``` @@ -142,5 +144,8 @@ Security-related configuration: - Keep `requireClassRegistration(true)` and register application classes or generated modules. - Use `withMaxDepth(...)` to reject unexpectedly deep object graphs. +- Keep `withMaxTypeFields(...)`, `withMaxTypeMetaBytes(...)`, and the remote schema-version limits + at their defaults unless the data is not malicious and a trusted peer sends larger metadata or + many schema versions. - Follow [Java Configuration](../java/configuration.md#security) for allow-listing and unknown-class controls. diff --git a/docs/guide/python/configuration.md b/docs/guide/python/configuration.md index 02c61eb113..9acc991cd8 100644 --- a/docs/guide/python/configuration.md +++ b/docs/guide/python/configuration.md @@ -36,6 +36,10 @@ class Fory: strict: bool = True, compatible: Optional[bool] = None, max_depth: int = 50, + max_type_fields: int = 512, + max_type_meta_bytes: int = 4096, + max_schema_versions_per_type: int = 10, + max_average_schema_versions_per_type: int = 3, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -55,17 +59,21 @@ class ThreadSafeFory: ## Parameters -| Parameter | Type | Default | Description | -| ----------------- | ------------------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `xlang` | `bool` | `True` | Use xlang mode. Set `False` for Python native mode. | -| `ref` | `bool` | `False` | Enable reference tracking for shared/circular references. Disable for better performance if your data has no shared references. | -| `strict` | `bool` | `True` | Require type registration for security. Keep this enabled for production unless a policy owns trust decisions. | -| `compatible` | `bool \| None` | `None` | Schema evolution mode. `None` enables compatible mode in both xlang and native mode. Set `False` only when every reader and writer uses the same schema. | -| `max_depth` | `int` | `50` | Maximum deserialization depth for security, preventing stack overflow attacks. | -| `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | -| `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | -| `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | -| `fory_factory` | `Callable \| None` | `None` | `ThreadSafeFory` factory hook. When set, `ThreadSafeFory` creates instances via this callback; otherwise it forwards `**kwargs` to `Fory` construction. | +| Parameter | Type | Default | Description | +| -------------------------------------- | ------------------------------- | ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `xlang` | `bool` | `True` | Use xlang mode. Set `False` for Python native mode. | +| `ref` | `bool` | `False` | Enable reference tracking for shared/circular references. Disable for better performance if your data has no shared references. | +| `strict` | `bool` | `True` | Require type registration for security. Keep this enabled for production unless a policy owns trust decisions. | +| `compatible` | `bool \| None` | `None` | Schema evolution mode. `None` enables compatible mode in both xlang and native mode. Set `False` only when every reader and writer uses the same schema. | +| `max_depth` | `int` | `50` | Maximum deserialization depth for security, preventing stack overflow attacks. | +| `max_type_fields` | `int` | `512` | Maximum fields accepted in one received remote struct metadata body. | +| `max_type_meta_bytes` | `int` | `4096` | Maximum encoded body bytes accepted for one received TypeDef body, excluding the 8-byte header and any extended-size varint. | +| `max_schema_versions_per_type` | `int` | `10` | Maximum accepted remote struct schema versions for one logical type on metadata cache misses. | +| `max_average_schema_versions_per_type` | `int` | `3` | Average accepted remote struct schema versions across accepted remote struct types. The effective global floor is `8192` schemas. | +| `policy` | `DeserializationPolicy \| None` | `None` | Deserialization policy used for security checks. Strongly recommended when `strict=False`. | +| `field_nullable` | `bool` | `False` | Treat dataclass fields as nullable by default. | +| `meta_compressor` | `Any` | `None` | Optional metadata compressor used for compatible-mode metadata encoding. | +| `fory_factory` | `Callable \| None` | `None` | `ThreadSafeFory` factory hook. When set, `ThreadSafeFory` creates instances via this callback; otherwise it forwards `**kwargs` to `Fory` construction. | ## Key Methods @@ -185,6 +193,10 @@ fory = pyfory.Fory( ref=False, strict=True, max_depth=50, + max_type_fields=512, + max_type_meta_bytes=4096, + max_schema_versions_per_type=10, + max_average_schema_versions_per_type=3, ) fory.register(UserModel, name="example.User") @@ -204,6 +216,16 @@ fory = pyfory.Fory( ) ``` +Remote struct metadata is also limited on metadata cache misses: + +- `max_type_fields` limits the number of fields accepted in one received struct metadata body. +- `max_type_meta_bytes` limits the encoded body bytes accepted for one received TypeDef body. +- `max_schema_versions_per_type` limits accepted remote schema versions for one logical type. +- `max_average_schema_versions_per_type` limits the average across accepted remote struct types. + +These limits do not change `strict`, `policy`, dynamic loading, unknown-class handling, or +schema-evolution semantics. + ### DeserializationPolicy When `strict=False` is necessary, use `DeserializationPolicy` to restrict the dynamic types and diff --git a/docs/guide/rust/configuration.md b/docs/guide/rust/configuration.md index 4d93c53056..d52127e8e6 100644 --- a/docs/guide/rust/configuration.md +++ b/docs/guide/rust/configuration.md @@ -88,6 +88,28 @@ let fory = Fory::builder().max_dyn_depth(10).build(); // Allow up to 10 levels Note: Static data types (non-dynamic types) are secure by nature and not subject to depth limits, as their structure is known at compile time. +### Remote Schema Metadata Limits + +Compatible mode accepts remote struct metadata on metadata cache misses. These +limits bound retained schema-specific read state: + +```rust +let fory = Fory::builder() + .max_type_fields(512) + .max_type_meta_bytes(4096) + .max_schema_versions_per_type(10) + .max_average_schema_versions_per_type(3) + .build(); +``` + +- `max_type_fields` defaults to `512` and limits fields in one received struct metadata body. +- `max_type_meta_bytes` defaults to `4096` and limits encoded body bytes in one received TypeDef or + TypeMeta body, excluding the 8-byte header and any extended-size varint. +- `max_schema_versions_per_type` defaults to `10` and limits accepted schema versions for one + logical struct type. +- `max_average_schema_versions_per_type` defaults to `3` and limits the average across accepted + remote struct types. The effective global floor is `8192` schemas. + ### Explicit Xlang Examples Set `.xlang(true)` explicitly for xlang serialization examples: @@ -122,11 +144,15 @@ let fory = Fory::builder() ## Configuration Summary -| Option | Description | Default | -| -------------------- | --------------------------------------- | ------- | -| `compatible(bool)` | Enable schema evolution | `true` | -| `xlang(bool)` | Use xlang mode | `true` | -| `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | +| Option | Description | Default | +| --------------------------------------------- | -------------------------------------------------- | ------- | +| `compatible(bool)` | Enable schema evolution | `true` | +| `xlang(bool)` | Use xlang mode | `true` | +| `max_dyn_depth(u32)` | Maximum nesting depth for dynamic types | `5` | +| `max_type_fields(usize)` | Max fields in one received struct metadata body | `512` | +| `max_type_meta_bytes(usize)` | Max encoded bytes in one received metadata body | `4096` | +| `max_schema_versions_per_type(usize)` | Max remote schema versions for one struct type | `10` | +| `max_average_schema_versions_per_type(usize)` | Average remote schema versions across struct types | `3` | ## Compatible Mode @@ -143,6 +169,8 @@ Security-related configuration: - Register application structs and trait-object implementations before deserializing untrusted payloads. - Use `max_dyn_depth(...)` to reject unexpectedly deep dynamic object graphs. +- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a + trusted peer sends larger metadata or many schema versions. - Prefer concrete typed fields over `dyn Any` or broad trait-object fields for untrusted input. ## Related Topics diff --git a/docs/guide/scala/configuration.md b/docs/guide/scala/configuration.md index d78e648d28..397657d873 100644 --- a/docs/guide/scala/configuration.md +++ b/docs/guide/scala/configuration.md @@ -179,6 +179,8 @@ and any untrusted payload source: val fory = ForyScala.builder() .requireClassRegistration(true) .withMaxDepth(50) + .withMaxTypeFields(512) + .withMaxTypeMetaBytes(4096) .build() ``` @@ -186,5 +188,8 @@ Security-related configuration: - Keep `requireClassRegistration(true)` and register application classes or generated modules. - Use `withMaxDepth(...)` to reject unexpectedly deep object graphs. +- Keep `withMaxTypeFields(...)`, `withMaxTypeMetaBytes(...)`, and the remote schema-version limits + at their defaults unless the data is not malicious and a trusted peer sends larger metadata or + many schema versions. - Follow [Java Configuration](../java/configuration.md#security) for allow-listing and unknown-class controls. diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index 85c1e33257..5e74039cb8 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -31,6 +31,10 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int + public let maxTypeFields: Int + public let maxTypeMetaBytes: Int + public let maxSchemaVersionsPerType: Int + public let maxAverageSchemaVersionsPerType: Int } ``` @@ -86,10 +90,25 @@ let fory = Fory(compatible: false, checkClassVersion: true) ### Size and Depth Limits -`maxDepth` bounds decoded payload nesting depth. +`maxDepth` bounds decoded payload nesting depth. Remote struct metadata is also +limited on metadata cache misses: + +- `maxTypeFields` defaults to `512` and limits fields in one received struct metadata body. +- `maxTypeMetaBytes` defaults to `4096` and limits encoded body bytes in one received TypeMeta body, + excluding the 8-byte header and any extended-size varint. +- `maxSchemaVersionsPerType` defaults to `10` and limits accepted schema versions for one logical + struct type. +- `maxAverageSchemaVersionsPerType` defaults to `3` and limits the average across accepted remote + struct types. The effective global floor is `8192` schemas. ```swift -let fory = Fory(maxDepth: 5) +let fory = Fory( + maxDepth: 5, + maxTypeFields: 512, + maxTypeMetaBytes: 4096, + maxSchemaVersionsPerType: 10, + maxAverageSchemaVersionsPerType: 3 +) ``` ## Recommended Presets @@ -121,3 +140,5 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkClassVersion` with `compatible: false` for intentional same-schema payloads. - Set `maxDepth` for the largest nesting depth your service accepts. +- Keep the remote schema metadata limits at their defaults unless the data is not malicious and a + trusted peer sends larger metadata or many schema versions. diff --git a/docs/guide/xlang/serialization.md b/docs/guide/xlang/serialization.md index 338f58fcb6..b2a8c61b4d 100644 --- a/docs/guide/xlang/serialization.md +++ b/docs/guide/xlang/serialization.md @@ -23,6 +23,46 @@ This page demonstrates common cross-language serialization patterns. Data serial supported language can be deserialized in any other supported language when peers use matching type identity, field schema, and compatibility settings. +## Remote Schema Metadata Limits + +Compatible mode may receive remote struct metadata (`TypeDef` or `TypeMeta`) when a reader does not +already know the metadata header. Fory limits how many distinct remote schema versions can be +accepted before building schema-specific read state, and also limits the size of each received +metadata body: + +- `maxSchemaVersionsPerType`: maximum accepted remote schema versions for one logical struct type. + The default is `10`. +- `maxAverageSchemaVersionsPerType`: average accepted remote schema versions across all accepted + remote struct types. The default is `3`; the effective global floor is `8192` schemas. +- `maxTypeFields`: maximum fields declared by one received struct metadata body. The default is + `512`. +- `maxTypeMetaBytes`: maximum encoded metadata body bytes for one received TypeDef or TypeMeta body, + excluding the 8-byte header and any extended-size varint. The default is `4096`. + +These limits are checked only on the cold metadata parse and cache-miss paths. They do not change +wire format, class registration, dynamic loading, unknown-type, or schema-evolution semantics. Cache +hits and generated field readers remain hot paths and do not add extra validation for these limits. +Failed or incompatible metadata is not counted against schema-version limits: a schema version is +accepted only after the schema-specific read state is successfully built and its owning metadata +cache can publish it. A remote schema that exactly matches the local registered schema is accepted +without consuming the remote-schema limit, so normal local traffic can still be read even after +other remote schemas have filled the limit. + +Raise these values only when the data is not malicious and a trusted peer sends larger metadata or +many schema versions. + +| Language | Field-count option | Metadata-bytes option | Per-type option | Average option | +| --------------------- | ------------------- | ---------------------- | ------------------------------ | -------------------------------------- | +| Java | `withMaxTypeFields` | `withMaxTypeMetaBytes` | `withMaxSchemaVersionsPerType` | `withMaxAverageSchemaVersionsPerType` | +| Python | `max_type_fields` | `max_type_meta_bytes` | `max_schema_versions_per_type` | `max_average_schema_versions_per_type` | +| JavaScript/TypeScript | `maxTypeFields` | `maxTypeMetaBytes` | `maxSchemaVersionsPerType` | `maxAverageSchemaVersionsPerType` | +| C++ | `max_type_fields` | `max_type_meta_bytes` | `max_schema_versions_per_type` | `max_average_schema_versions_per_type` | +| Go | `WithMaxTypeFields` | `WithMaxTypeMetaBytes` | `WithMaxSchemaVersionsPerType` | `WithMaxAverageSchemaVersionsPerType` | +| Rust | `max_type_fields` | `max_type_meta_bytes` | `max_schema_versions_per_type` | `max_average_schema_versions_per_type` | +| C# | `MaxTypeFields` | `MaxTypeMetaBytes` | `MaxSchemaVersionsPerType` | `MaxAverageSchemaVersionsPerType` | +| Swift | `maxTypeFields` | `maxTypeMetaBytes` | `maxSchemaVersionsPerType` | `maxAverageSchemaVersionsPerType` | +| Dart | `maxTypeFields` | `maxTypeMetaBytes` | `maxSchemaVersionsPerType` | `maxAverageSchemaVersionsPerType` | + ## Serialize Built-in Types Common types can be serialized automatically without registration: primitive numeric types, string, binary, array, list, map, and more. diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md index 390b1855b4..549c53b46d 100644 --- a/docs/security/deserialization.md +++ b/docs/security/deserialization.md @@ -197,6 +197,34 @@ Metadata readers should: policy decisions. - Reset or release metadata state at the correct root-operation boundary. +Remote struct metadata that can create schema-specific read state must be +limited on the cold metadata cache-miss path. The check is resource control +only: it must not change wire compatibility, type registration, dynamic class +loading, unknown-type handling, deserialization policy, or schema-evolution +semantics. Metadata cache hits and generated field readers remain hot paths and +must not add validation, hashing, allocation, or policy work for this limit. +Failed or incompatible metadata must not consume the limit. Count a remote +schema version only after schema-specific read state has been successfully built +and the owning metadata cache can publish it. +If a remote metadata body exactly matches a local registered schema after the +metadata body and hash have been validated, the reader may use the local schema +without consuming the remote-schema limit. + +Remote struct metadata bodies and field lists must also be bounded on the cold +metadata parse path. `maxTypeMetaBytes` limits the encoded metadata body bytes +for one received TypeDef or TypeMeta body, excluding the 8-byte header and any +extended-size varint. `maxTypeFields` limits the number of fields declared by +one received struct metadata body. For Java native TypeDef class layers, the +field limit applies to the total field count across the class layers in that +one TypeDef. These limits are checked before copying, decompressing, reserving, +or allocating from attacker-declared metadata sizes or field counts. + +The default limits are `maxTypeFields = 512` and `maxTypeMetaBytes = 4096`. +Runtimes should report limit failures as possible malicious data and tell users +to increase the exact option only when the data is not malicious. These limits +must not introduce validation on metadata cache-hit, generated serializer, or +already-resolved type-id hot paths. + Metadata byte-form strictness alone is not a security requirement. Rejecting a metadata shape is useful only when the owner wants that strictness or when the shape changes type identity, retained state, resource use, or policy behavior. diff --git a/docs/specification/java_serialization_spec.md b/docs/specification/java_serialization_spec.md index 609bae7001..607fffe772 100644 --- a/docs/specification/java_serialization_spec.md +++ b/docs/specification/java_serialization_spec.md @@ -410,6 +410,12 @@ Root kind codes: Class layers are encoded from parent to leaf. Field lists inside each layer use the field order defined above. +Readers may reject a received TypeDef that exceeds runtime resource limits such +as maximum metadata body bytes or maximum fields in one TypeDef. These limits +are receive-side resource controls and do not change the TypeDef wire encoding, +type identity, dynamic class loading, unknown-class handling, registration +policy, or schema-evolution semantics. + ### Field Info Each field is encoded as: diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index da68fd5859..5c9e501bf2 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -403,6 +403,17 @@ are readable through the byte owner. Field-list allocation should happen after that body readability check and should not use a separate small initial-capacity cap as a security rule. +Implementations should also bound received struct metadata on the cold metadata +parse path. `maxTypeMetaBytes` limits one encoded TypeDef or TypeMeta body, +excluding the 8-byte header and any extended-size varint, and is checked before +copying or decompressing that body. `maxTypeFields` limits the number of fields +declared by one received struct metadata body and is checked before reserving or +allocating the field list. These limits are runtime resource controls; they do +not change wire encoding, type identity, dynamic loading, unknown-type behavior, +deserialization policy, or schema-evolution semantics. Metadata cache hits and +generated field readers remain hot paths and must not add work for these +limits. + Skip paths do not need to materialize skipped values. Existing byte-skip operations should consume any available buffered prefix first, then skip or drop remaining stream bytes in bounded steps. diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index cf46904344..196e069731 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -743,6 +743,12 @@ Meta header byte for non-struct TypeDefs: - Bits 4-6: reserved (must be zero). - Bits 0-3: kind code. +Readers may reject a received TypeDef that exceeds runtime resource limits such +as maximum metadata body bytes or maximum fields in one struct TypeDef. These +limits are receive-side resource controls and do not change TypeDef wire +encoding, type identity, dynamic loading, unknown-type handling, registration +policy, or schema-evolution semantics. + Non-struct kind codes: - `0`: `ENUM` diff --git a/go/fory/fory.go b/go/fory/fory.go index 964d72c85d..1a45f42e3e 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -65,20 +65,26 @@ const ( // Config holds configuration options for Fory instances type Config struct { - TrackRef bool - MaxDepth int - IsXlang bool - Compatible bool // Schema evolution compatibility mode - MaxTypeFields int + TrackRef bool + MaxDepth int + IsXlang bool + Compatible bool // Schema evolution compatibility mode + MaxTypeFields int + MaxTypeMetaBytes int + MaxSchemaVersionsPerType int + MaxAverageSchemaVersionsPerType int } // defaultConfig returns the default configuration func defaultConfig() Config { return Config{ - TrackRef: false, // Match Java's default: reference tracking disabled - MaxDepth: 20, - IsXlang: true, - MaxTypeFields: 10000, + TrackRef: false, // Match Java's default: reference tracking disabled + MaxDepth: 20, + IsXlang: true, + MaxTypeFields: 512, + MaxTypeMetaBytes: 4096, + MaxSchemaVersionsPerType: 10, + MaxAverageSchemaVersionsPerType: 3, } } @@ -121,11 +127,44 @@ func WithCompatible(enabled bool) Option { // WithMaxTypeFields sets the maximum field count limit for schema definition deserialization func WithMaxTypeFields(size int) Option { + if size <= 0 { + panic("MaxTypeFields must be positive") + } return func(f *Fory) { f.config.MaxTypeFields = size } } +// WithMaxTypeMetaBytes sets the maximum body size for received type metadata. +func WithMaxTypeMetaBytes(size int) Option { + if size <= 0 { + panic("MaxTypeMetaBytes must be positive") + } + return func(f *Fory) { + f.config.MaxTypeMetaBytes = size + } +} + +// WithMaxSchemaVersionsPerType sets the maximum accepted remote schema versions for one struct type. +func WithMaxSchemaVersionsPerType(size int) Option { + if size <= 0 { + panic("MaxSchemaVersionsPerType must be positive") + } + return func(f *Fory) { + f.config.MaxSchemaVersionsPerType = size + } +} + +// WithMaxAverageSchemaVersionsPerType sets the average remote schema version limit across accepted struct types. +func WithMaxAverageSchemaVersionsPerType(size int) Option { + if size <= 0 { + panic("MaxAverageSchemaVersionsPerType must be positive") + } + return func(f *Fory) { + f.config.MaxAverageSchemaVersionsPerType = size + } +} + // ============================================================================ // Fory - Main serialization instance // ============================================================================ diff --git a/go/fory/type_def.go b/go/fory/type_def.go index f654e7b8f4..50e7b24f9b 100644 --- a/go/fory/type_def.go +++ b/go/fory/type_def.go @@ -396,6 +396,7 @@ func buildTypeDef(fory *Fory, value reflect.Value) (*TypeDef, error) { } registerByName := IsNamespacedType(TypeId(typeId)) typeDef := NewTypeDef(typeId, infoPtr.UserTypeID, infoPtr.PkgPathBytes, infoPtr.NameBytes, registerByName, false, fieldDefs) + typeDef.type_ = value.Type() // encoding the typeDef, and save the encoded bytes encoded, err := encodingTypeDef(fory.typeResolver, typeDef) @@ -1015,6 +1016,11 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro extraMetaSize = int(extra) metaSize += extraMetaSize } + if metaSize > fory.config.MaxTypeMetaBytes { + return nil, fmt.Errorf( + "type metadata body size %d exceeds MaxTypeMetaBytes %d. The data may be malicious. If the data is not malicious, please increase MaxTypeMetaBytes", + metaSize, fory.config.MaxTypeMetaBytes) + } // Store the encoded bytes for the TypeDef (including meta header and metadata) encodedMeta := buffer.ReadBinary(metaSize, &bufErr) if bufErr.HasError() { @@ -1043,8 +1049,13 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro if metaErr.HasError() { return nil, metaErr.TakeError() } - if fieldCount > fory.config.MaxTypeFields || fieldCount > metaBuffer.remaining() { - return nil, fmt.Errorf("field count exceeds maximum allowed limit or available buffer size") + if fieldCount > fory.config.MaxTypeFields { + return nil, fmt.Errorf( + "type metadata field count %d exceeds MaxTypeFields %d. The data may be malicious. If the data is not malicious, please increase MaxTypeFields", + fieldCount, fory.config.MaxTypeFields) + } + if fieldCount > metaBuffer.remaining() { + return nil, fmt.Errorf("type metadata field count exceeds available buffer size") } if registeredByName { if (metaHeaderByte & CompatibleTypeDefFlag) != 0 { @@ -1172,9 +1183,7 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro if fallbackInfo, fallbackExists := fory.typeResolver.namedTypeToTypeInfo[nameKey]; fallbackExists { info = fallbackInfo exists = true - if len(fory.typeResolver.nsTypeToTypeInfo) < maxCachedNamedTypeInfos { - fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{nsBytes.Hashcode, nameBytes.Hashcode}] = info - } + fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{nsBytes.Hashcode, nameBytes.Hashcode}] = info } } if exists { @@ -1216,12 +1225,12 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro } fieldInfos[i] = fieldInfo } - if !isStruct && len(fieldInfos) != 0 { - return nil, fmt.Errorf("non-struct TypeDef cannot carry field metadata") - } if metaErr.HasError() { return nil, metaErr.TakeError() } + if !isStruct && len(fieldInfos) != 0 { + return nil, fmt.Errorf("non-struct TypeDef cannot carry field metadata") + } if remaining := metaBuffer.remaining(); remaining != 0 { return nil, fmt.Errorf("TypeDef metadata body has %d trailing bytes", remaining) } diff --git a/go/fory/type_def_test.go b/go/fory/type_def_test.go index a3f1bb0cc3..9856f7a6b3 100644 --- a/go/fory/type_def_test.go +++ b/go/fory/type_def_test.go @@ -38,6 +38,14 @@ type SliceStruct struct { Items []string } +type SchemaLimitBad struct { + Value string +} + +type SchemaLimitExtra struct { + Extra int32 +} + type NestedSliceStruct struct { ID int32 Matrix [][]int @@ -345,6 +353,38 @@ func TestTypeDefFieldCountOOMPanic(t *testing.T) { } } +func TestTypeDefRejectsMaxTypeFields(t *testing.T) { + writer := NewFory(WithXlang(false), WithCompatible(true)) + require.NoError(t, writer.RegisterStructByName(SimpleStruct{}, "example.SimpleStruct")) + typeDef, err := buildTypeDef(writer, reflect.ValueOf(SimpleStruct{})) + require.NoError(t, err) + buffer := NewByteBuffer(typeDef.encoded) + readErr := &Error{} + header := buffer.ReadInt64(readErr) + require.NoError(t, readErr.CheckError()) + + reader := NewFory(WithXlang(false), WithCompatible(true), WithMaxTypeFields(1)) + _, err = decodeTypeDef(reader, buffer, header) + require.Error(t, err) + require.Contains(t, err.Error(), "MaxTypeFields") +} + +func TestTypeDefRejectsMaxTypeMetaBytes(t *testing.T) { + writer := NewFory(WithXlang(false), WithCompatible(true)) + require.NoError(t, writer.RegisterStructByName(SimpleStruct{}, "example.SimpleStruct")) + typeDef, err := buildTypeDef(writer, reflect.ValueOf(SimpleStruct{})) + require.NoError(t, err) + buffer := NewByteBuffer(typeDef.encoded) + readErr := &Error{} + header := buffer.ReadInt64(readErr) + require.NoError(t, readErr.CheckError()) + + reader := NewFory(WithXlang(false), WithCompatible(true), WithMaxTypeMetaBytes(1)) + _, err = decodeTypeDef(reader, buffer, header) + require.Error(t, err) + require.Contains(t, err.Error(), "MaxTypeMetaBytes") +} + func TestTypeDefRejectsReservedGlobalHeaderBits(t *testing.T) { fory := NewFory(WithXlang(false), WithCompatible(false)) buffer := NewByteBuffer(nil) @@ -448,16 +488,13 @@ func TestTypeDefRejectsCompressedMetadata(t *testing.T) { require.Contains(t, err.Error(), "compressed xlang TypeDef") } -func TestReadSharedTypeMetaCapsParsedTypeDefCache(t *testing.T) { +func TestReadSharedTypeMetaExactLocalSkipsRemoteCache(t *testing.T) { fory := NewFory(WithXlang(false), WithCompatible(true)) require.NoError(t, fory.RegisterStructByName(SimpleStruct{}, "example.SimpleStruct")) typeDef, err := buildTypeDef(fory, reflect.ValueOf(SimpleStruct{})) require.NoError(t, err) require.NotEmpty(t, typeDef.encoded) - for i := 0; i < maxCachedTypeDefs; i++ { - fory.typeResolver.defIdToTypeDef[int64(i)] = typeDef - } headerErr := &Error{} header := NewByteBuffer(typeDef.encoded).ReadInt64(headerErr) require.NoError(t, headerErr.CheckError()) @@ -470,11 +507,69 @@ func TestReadSharedTypeMetaCapsParsedTypeDefCache(t *testing.T) { typeInfo := fory.typeResolver.readSharedTypeMeta(buffer, readErr) require.NoError(t, readErr.CheckError()) require.NotNil(t, typeInfo) - require.Len(t, fory.typeResolver.defIdToTypeDef, maxCachedTypeDefs) require.NotContains(t, fory.typeResolver.defIdToTypeDef, header) } -func TestDecodeTypeDefFallbackNamedTypeCacheRespectsCap(t *testing.T) { +func TestRemoteSchemaLimitRejectsExtraVersions(t *testing.T) { + fory := NewFory(WithXlang(false), WithCompatible(true), WithMaxSchemaVersionsPerType(1)) + first := remoteSchemaLimitTypeDef(t, SimpleStruct{}, "example.Shared") + second := remoteSchemaLimitTypeDef(t, SliceStruct{}, "example.Shared") + + require.NoError(t, readRemoteTypeDef(t, fory, first)) + err := readRemoteTypeDef(t, fory, second) + + require.Error(t, err) + require.Contains(t, err.Error(), "MaxSchemaVersionsPerType") +} + +func TestRemoteSchemaLimitKeepsUnknownTypesSeparate(t *testing.T) { + fory := NewFory(WithXlang(false), WithCompatible(true), WithMaxSchemaVersionsPerType(1)) + first := remoteSchemaLimitTypeDef(t, SimpleStruct{}, "example.UnknownA") + second := remoteSchemaLimitTypeDef(t, SliceStruct{}, "example.UnknownB") + + require.NoError(t, readRemoteTypeDef(t, fory, first)) + require.NoError(t, readRemoteTypeDef(t, fory, second)) +} + +func TestRemoteSchemaCheckDoesNotConsumeLimit(t *testing.T) { + fory := NewFory(WithXlang(false), WithCompatible(true), WithMaxSchemaVersionsPerType(1)) + checked := remoteSchemaLimitTypeDef(t, SchemaLimitBad{}, "example.Accepted") + valid := remoteSchemaLimitTypeDef(t, SchemaLimitExtra{}, "example.Accepted") + + typeKey, isStruct, err := fory.typeResolver.checkRemoteStructSchemaLimit(checked) + require.NoError(t, err) + require.True(t, isStruct) + require.NotNil(t, typeKey) + require.NoError(t, readRemoteTypeDef(t, fory, valid)) +} + +func remoteSchemaLimitTypeDef(t *testing.T, value any, name string) *TypeDef { + t.Helper() + sender := NewFory(WithXlang(false), WithCompatible(true)) + require.NoError(t, sender.RegisterStructByName(value, name)) + typeDef, err := buildTypeDef(sender, reflect.ValueOf(value)) + require.NoError(t, err) + return typeDef +} + +func readRemoteTypeDef(t *testing.T, fory *Fory, typeDef *TypeDef) error { + t.Helper() + buffer := NewByteBuffer(nil) + buffer.WriteVarUint32(0) + writeErr := &Error{} + typeDef.writeTypeDef(buffer, writeErr) + require.NoError(t, writeErr.CheckError()) + + readErr := &Error{} + typeInfo := fory.typeResolver.readSharedTypeMeta(buffer, readErr) + if err := readErr.CheckError(); err != nil { + return err + } + require.NotNil(t, typeInfo) + return nil +} + +func TestDecodeTypeDefFallbackNamedTypeCachesLookup(t *testing.T) { fory := NewFory(WithXlang(false), WithCompatible(true)) require.NoError(t, fory.RegisterStructByName(SimpleStruct{}, "example.SimpleStruct")) typeDef, err := buildTypeDef(fory, reflect.ValueOf(SimpleStruct{})) @@ -486,9 +581,6 @@ func TestDecodeTypeDefFallbackNamedTypeCacheRespectsCap(t *testing.T) { delete(fory.typeResolver.nsTypeToTypeInfo, nameKey) info := fory.typeResolver.namedTypeToTypeInfo[[2]string{"example", "SimpleStruct"}] require.NotNil(t, info) - for i := 0; len(fory.typeResolver.nsTypeToTypeInfo) < maxCachedNamedTypeInfos; i++ { - fory.typeResolver.nsTypeToTypeInfo[nsTypeKey{int64(i + 1), int64(i + 2)}] = info - } require.NotContains(t, fory.typeResolver.nsTypeToTypeInfo, nameKey) buffer := NewByteBuffer(nil) @@ -500,8 +592,7 @@ func TestDecodeTypeDefFallbackNamedTypeCacheRespectsCap(t *testing.T) { decoded := readTypeDef(fory, buffer, header, readErr) require.NoError(t, readErr.CheckError()) require.NotNil(t, decoded) - require.Len(t, fory.typeResolver.nsTypeToTypeInfo, maxCachedNamedTypeInfos) - require.NotContains(t, fory.typeResolver.nsTypeToTypeInfo, nameKey) + require.Contains(t, fory.typeResolver.nsTypeToTypeInfo, nameKey) } func TestTypeDefRejectsNamespaceLengthBeyondMetadata(t *testing.T) { diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index 6d908b22c7..945546032d 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -18,6 +18,7 @@ package fory import ( + "bytes" "errors" "fmt" "hash/fnv" @@ -52,11 +53,10 @@ const ( useStringId = 1 SMALL_STRING_THRESHOLD = 16 // 0xffffffff is reserved for "unset". - maxUserTypeID uint32 = 0xfffffffe - invalidUserTypeID uint32 = 0xffffffff - internalTypeIDLimit = 0xFF - maxCachedTypeDefs = 8192 - maxCachedNamedTypeInfos = 8192 + maxUserTypeID uint32 = 0xfffffffe + invalidUserTypeID uint32 = 0xffffffff + internalTypeIDLimit = 0xFF + minRemoteStructSchemaLimit = 8192 ) var ( @@ -208,8 +208,10 @@ type TypeResolver struct { typeNameDecoder *meta.Decoder // meta share related - typeToTypeDef map[reflect.Type]*TypeDef - defIdToTypeDef map[int64]*TypeDef + typeToTypeDef map[reflect.Type]*TypeDef + defIdToTypeDef map[int64]*TypeDef + remoteSchemaVersionsByType map[any]int + totalAcceptedSchemaVersions int // Fast type cache for O(1) lookup using type pointer typePointerCache map[uintptr]*TypeInfo @@ -252,10 +254,11 @@ func newTypeResolver(fory *Fory) *TypeResolver { typeNameEncoder: meta.NewEncoder('$', '_'), typeNameDecoder: meta.NewDecoder('$', '_'), - typeToTypeDef: make(map[reflect.Type]*TypeDef), - defIdToTypeDef: make(map[int64]*TypeDef), - typePointerCache: make(map[uintptr]*TypeInfo), - unionTypeCache: make(map[reflect.Type]bool), + typeToTypeDef: make(map[reflect.Type]*TypeDef), + defIdToTypeDef: make(map[int64]*TypeDef), + remoteSchemaVersionsByType: make(map[any]int), + typePointerCache: make(map[uintptr]*TypeInfo), + unionTypeCache: make(map[reflect.Type]bool), } // base type info for encode/decode types. // composite types info will be constructed dynamically. @@ -1519,6 +1522,61 @@ func (r *TypeResolver) getTypeDef(typ reflect.Type, create bool) (*TypeDef, erro return typeDef, nil } +//go:noinline +func (r *TypeResolver) checkRemoteStructSchemaLimit(td *TypeDef) (any, bool, error) { + switch TypeId(td.typeId) { + case STRUCT, COMPATIBLE_STRUCT, NAMED_STRUCT, NAMED_COMPATIBLE_STRUCT: + default: + return nil, false, nil + } + var typeKey any + if td.registerByName { + if td.nsName == nil || td.typeName == nil { + return nil, false, fmt.Errorf("named remote struct schema is missing namespace or type name") + } + namespace, err := r.namespaceDecoder.Decode(td.nsName.Data, td.nsName.Encoding) + if err != nil { + return nil, false, err + } + typeName, err := r.typeNameDecoder.Decode(td.typeName.Data, td.typeName.Encoding) + if err != nil { + return nil, false, err + } + typeKey = namespace + "\x00" + typeName + } else { + typeKey = td.userTypeId + } + versionsForType := r.remoteSchemaVersionsByType[typeKey] + if versionsForType >= r.fory.config.MaxSchemaVersionsPerType { + return nil, false, fmt.Errorf( + "remote schema version limit exceeded for type %v: %d >= %d. Increase MaxSchemaVersionsPerType if this peer legitimately sends many schema versions for one type", + typeKey, versionsForType, r.fory.config.MaxSchemaVersionsPerType) + } + acceptedStructTypeCount := len(r.remoteSchemaVersionsByType) + if versionsForType == 0 { + acceptedStructTypeCount++ + } + globalLimit := acceptedStructTypeCount * r.fory.config.MaxAverageSchemaVersionsPerType + if globalLimit < minRemoteStructSchemaLimit { + globalLimit = minRemoteStructSchemaLimit + } + if r.totalAcceptedSchemaVersions >= globalLimit { + return nil, false, fmt.Errorf( + "remote schema version limit exceeded: %d schemas for %d accepted struct types exceeds the average limit %d. Increase MaxAverageSchemaVersionsPerType if this peer legitimately sends many schema versions across many types", + r.totalAcceptedSchemaVersions, acceptedStructTypeCount, r.fory.config.MaxAverageSchemaVersionsPerType) + } + return typeKey, true, nil +} + +func (r *TypeResolver) recordRemoteStructSchema(typeKey any, isStruct bool) { + if !isStruct { + return + } + versionsForType := r.remoteSchemaVersionsByType[typeKey] + r.remoteSchemaVersionsByType[typeKey] = versionsForType + 1 + r.totalAcceptedSchemaVersions++ +} + func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeInfo { context := r.fory.MetaContext() if context == nil { @@ -1559,6 +1617,7 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI } var td *TypeDef + newTypeDef := false if existingTd, exists := r.defIdToTypeDef[id]; exists { // Header-cache hits intentionally skip without rehashing. Entries reach this cache only // after a successful TypeDef parse and 52-bit metadata-hash validation. @@ -1570,6 +1629,39 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI return nil } td = newTd + newTypeDef = true + if td.type_ != nil { + localType := td.type_ + localTd, localErr := r.getTypeDef(localType, true) + if localErr == nil && bytes.Equal(localTd.encoded, td.encoded) { + td = localTd + newTypeDef = false + if typeInfo := r.getTypeInfoByType(localType); typeInfo != nil { + context.readTypeInfos = append(context.readTypeInfos, typeInfo) + return typeInfo + } + if typeInfo, typeInfoErr := r.getTypeInfo(reflect.Zero(localType), true); typeInfoErr == nil { + context.readTypeInfos = append(context.readTypeInfos, typeInfo) + return typeInfo + } + } + } + if newTypeDef { + typeKey, isStruct, limitErr := r.checkRemoteStructSchemaLimit(td) + if limitErr != nil { + err.SetError(limitErr) + return nil + } + typeInfo, typeInfoErr := td.getOrBuildTypeInfo(r) + if typeInfoErr != nil { + err.SetError(typeInfoErr) + return nil + } + r.defIdToTypeDef[id] = td + r.recordRemoteStructSchema(typeKey, isStruct) + context.readTypeInfos = append(context.readTypeInfos, typeInfo) + return typeInfo + } } typeInfo, typeInfoErr := td.getOrBuildTypeInfo(r) @@ -1577,10 +1669,6 @@ func (r *TypeResolver) readSharedTypeMeta(buffer *ByteBuffer, err *Error) *TypeI err.SetError(typeInfoErr) return nil } - if _, exists := r.defIdToTypeDef[id]; !exists && len(r.defIdToTypeDef) < maxCachedTypeDefs { - r.defIdToTypeDef[id] = td - } - context.readTypeInfos = append(context.readTypeInfos, typeInfo) return typeInfo } @@ -2081,9 +2169,7 @@ func (r *TypeResolver) resolveTypeInfoByMetaBytes(nsBytes, typeBytes *MetaString nameKey := [2]string{ns, typeName} if typeInfo, exists := r.namedTypeToTypeInfo[nameKey]; exists { - if len(r.nsTypeToTypeInfo) < maxCachedNamedTypeInfos { - r.nsTypeToTypeInfo[compositeKey] = typeInfo - } + r.nsTypeToTypeInfo[compositeKey] = typeInfo return typeInfo } diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index b3629cee93..e2e736fb1e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -123,8 +123,11 @@ public Fory(ForyBuilder builder, ClassLoader classLoader) { public Fory(ForyBuilder builder, ClassLoader classLoader, SharedRegistry sharedRegistry) { // Prefer the explicit constructor argument over retaining loader state on the builder used to // create thread-safe factories. + config = new Config(builder); if (sharedRegistry == null) { - sharedRegistry = new SharedRegistry(); + sharedRegistry = new SharedRegistry(config); + } else { + sharedRegistry.checkConfig(config); } if (classLoader == null) { classLoader = Thread.currentThread().getContextClassLoader(); @@ -134,7 +137,6 @@ public Fory(ForyBuilder builder, ClassLoader classLoader, SharedRegistry sharedR } this.sharedRegistry = sharedRegistry; this.classLoader = classLoader; - config = new Config(builder); headerBitmap = config.isXlang() ? isCrossLanguageFlag : 0; RefWriter refWriter; RefReader refReader; diff --git a/java/fory-core/src/main/java/org/apache/fory/ThreadLocalFory.java b/java/fory-core/src/main/java/org/apache/fory/ThreadLocalFory.java index 57f0d947dc..58b059e818 100644 --- a/java/fory-core/src/main/java/org/apache/fory/ThreadLocalFory.java +++ b/java/fory-core/src/main/java/org/apache/fory/ThreadLocalFory.java @@ -50,8 +50,19 @@ public class ThreadLocalFory extends AbstractThreadSafeFory { private final Object callbackLock = new Object(); public ThreadLocalFory(Function factory) { - SharedRegistry sharedRegistry = new SharedRegistry(); - foryFactory = () -> factory.apply(Fory.builder().withSharedRegistry(sharedRegistry)); + SharedRegistry[] sharedRegistry = new SharedRegistry[1]; + foryFactory = + () -> { + ForyBuilder builder = Fory.builder(); + if (sharedRegistry[0] != null) { + builder.withSharedRegistry(sharedRegistry[0]); + } + Fory fory = factory.apply(builder); + if (sharedRegistry[0] == null) { + sharedRegistry[0] = fory.getSharedRegistry(); + } + return fory; + }; factoryCallback = f -> {}; allFory = Collections.synchronizedMap(new WeakHashMap<>()); foryThreadLocal = ThreadLocal.withInitial(this::newFory); diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index 8c9f18d821..d244602e02 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -64,6 +64,10 @@ public class Config implements Serializable { private final boolean serializeEnumByName; private final int bufferSizeLimitBytes; private final int maxDepth; + private final int maxTypeFields; + private final int maxTypeMetaBytes; + private final int maxSchemaVersionsPerType; + private final int maxAverageSchemaVersionsPerType; private final float mapRefLoadFactor; private final boolean forVirtualThread; @@ -106,6 +110,10 @@ public Config(ForyBuilder builder) { serializeEnumByName = builder.serializeEnumByName; bufferSizeLimitBytes = builder.bufferSizeLimitBytes; maxDepth = builder.maxDepth; + maxTypeFields = builder.maxTypeFields; + maxTypeMetaBytes = builder.maxTypeMetaBytes; + maxSchemaVersionsPerType = builder.maxSchemaVersionsPerType; + maxAverageSchemaVersionsPerType = builder.maxAverageSchemaVersionsPerType; mapRefLoadFactor = builder.mapRefLoadFactor; forVirtualThread = builder.forVirtualThread; } @@ -292,6 +300,26 @@ public int maxDepth() { return maxDepth; } + /** Returns the maximum number of fields accepted in one received struct TypeDef. */ + public int maxTypeFields() { + return maxTypeFields; + } + + /** Returns the maximum body size accepted for one received TypeDef. */ + public int maxTypeMetaBytes() { + return maxTypeMetaBytes; + } + + /** Returns the maximum accepted remote schema versions for one struct type. */ + public int maxSchemaVersionsPerType() { + return maxSchemaVersionsPerType; + } + + /** Returns the maximum average accepted remote schema versions across struct types. */ + public int maxAverageSchemaVersionsPerType() { + return maxAverageSchemaVersionsPerType; + } + /** Returns loadFactor of MacRef's writtenObjects. */ public float mapRefLoadFactor() { return mapRefLoadFactor; @@ -336,6 +364,10 @@ public boolean equals(Object o) { && deserializeUnknownClass == config.deserializeUnknownClass && xlang == config.xlang && compatible == config.compatible + && maxTypeFields == config.maxTypeFields + && maxTypeMetaBytes == config.maxTypeMetaBytes + && maxSchemaVersionsPerType == config.maxSchemaVersionsPerType + && maxAverageSchemaVersionsPerType == config.maxAverageSchemaVersionsPerType && Objects.equals(defaultJDKStreamSerializerType, config.defaultJDKStreamSerializerType) && longEncoding == config.longEncoding && forVirtualThread == config.forVirtualThread; @@ -367,6 +399,10 @@ public int hashCode() { requireClassRegistration, suppressClassRegistrationWarnings, registerGuavaTypes, + maxTypeFields, + maxTypeMetaBytes, + maxSchemaVersionsPerType, + maxAverageSchemaVersionsPerType, metaShareEnabled, scopedMetaShareEnabled, metaCompressor, diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index 93bd88aabb..42581ad818 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -99,6 +99,10 @@ public final class ForyBuilder { Integer bufferSizeLimitBytes = -1; MetaCompressor metaCompressor = new DeflaterMetaCompressor(); int maxDepth = 50; + int maxTypeFields = 512; + int maxTypeMetaBytes = 4096; + int maxSchemaVersionsPerType = 10; + int maxAverageSchemaVersionsPerType = 3; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -511,6 +515,62 @@ public ForyBuilder withMaxDepth(int maxDepth) { return this; } + /** + * Sets the maximum number of fields accepted in one received struct TypeDef. + * + *

This limit applies only to cold remote metadata parse paths. + */ + public ForyBuilder withMaxTypeFields(int maxTypeFields) { + Preconditions.checkArgument( + maxTypeFields > 0, "maxTypeFields must be positive but got %s", maxTypeFields); + this.maxTypeFields = maxTypeFields; + recordAction(b -> b.withMaxTypeFields(maxTypeFields)); + return this; + } + + /** + * Sets the maximum body size accepted for one received TypeDef. + * + *

This limit excludes the 8-byte metadata header and any extended-size varint bytes. + */ + public ForyBuilder withMaxTypeMetaBytes(int maxTypeMetaBytes) { + Preconditions.checkArgument( + maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive but got %s", maxTypeMetaBytes); + this.maxTypeMetaBytes = maxTypeMetaBytes; + recordAction(b -> b.withMaxTypeMetaBytes(maxTypeMetaBytes)); + return this; + } + + /** + * Sets the maximum number of accepted remote schema versions for one struct type. + * + *

This limit applies only to cold remote metadata miss paths. + */ + public ForyBuilder withMaxSchemaVersionsPerType(int maxSchemaVersionsPerType) { + Preconditions.checkArgument( + maxSchemaVersionsPerType > 0, + "maxSchemaVersionsPerType must be positive but got %s", + maxSchemaVersionsPerType); + this.maxSchemaVersionsPerType = maxSchemaVersionsPerType; + recordAction(b -> b.withMaxSchemaVersionsPerType(maxSchemaVersionsPerType)); + return this; + } + + /** + * Sets the maximum average number of accepted remote schema versions across struct types. + * + *

The global limit has an internal floor so small type universes can still evolve normally. + */ + public ForyBuilder withMaxAverageSchemaVersionsPerType(int maxAverageSchemaVersionsPerType) { + Preconditions.checkArgument( + maxAverageSchemaVersionsPerType > 0, + "maxAverageSchemaVersionsPerType must be positive but got %s", + maxAverageSchemaVersionsPerType); + this.maxAverageSchemaVersionsPerType = maxAverageSchemaVersionsPerType; + recordAction(b -> b.withMaxAverageSchemaVersionsPerType(maxAverageSchemaVersionsPerType)); + return this; + } + /** Set loadFactor of MapRefResolver writtenObjects. Default value is 0.51 */ public ForyBuilder withMapRefLoadFactor(float loadFactor) { Preconditions.checkArgument( diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java index 7115414abe..db53fd1615 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/NativeTypeDefDecoder.java @@ -46,8 +46,6 @@ * href="https://fory.apache.org/docs/specification/fory_java_serialization_spec">... */ class NativeTypeDefDecoder { - private static final int MAX_TYPE_DEF_SIZE_BYTES = 16 * 1024 * 1024; - static Tuple2 decodeTypeDefBuf( MemoryBuffer inputBuffer, TypeResolver resolver, long id) { if ((id & TypeDef.RESERVED_META_FLAGS) != 0) { @@ -58,12 +56,13 @@ static Tuple2 decodeTypeDefBuf( int size = (int) (id & META_SIZE_MASKS); if (size == META_SIZE_MASKS) { int moreSize = inputBuffer.readVarUInt32Small14(); + if (moreSize < 0 || moreSize > Integer.MAX_VALUE - size) { + throw new DeserializationException("Invalid TypeDef metadata size " + moreSize); + } encoded.writeVarUInt32(moreSize); size += moreSize; } - if (size > MAX_TYPE_DEF_SIZE_BYTES) { - throw new DeserializationException("TypeDef metadata size exceeds the maximum size"); - } + checkTypeMetaBodySize(size, resolver.getConfig().maxTypeMetaBytes()); byte[] encodedTypeDef = inputBuffer.readBytes(size); encoded.writeBytes(encodedTypeDef); if ((id & COMPRESS_META_FLAG) != 0) { @@ -71,11 +70,35 @@ static Tuple2 decodeTypeDefBuf( resolver .getConfig() .getMetaCompressor() - .decompress(encodedTypeDef, 0, size, MAX_TYPE_DEF_SIZE_BYTES); + .decompress(encodedTypeDef, 0, size, resolver.getConfig().maxTypeMetaBytes()); } return Tuple2.of(encodedTypeDef, encoded.getBytes(0, encoded.writerIndex())); } + static void checkTypeMetaBodySize(int size, int maxTypeMetaBytes) { + if (size > maxTypeMetaBytes) { + throw new DeserializationException( + "Type metadata body size " + + size + + " exceeds maxTypeMetaBytes " + + maxTypeMetaBytes + + ". The data may be malicious. If the data is not malicious, please increase " + + "maxTypeMetaBytes."); + } + } + + static void checkTypeMetaFieldCount(long fieldCount, int maxTypeFields) { + if (fieldCount > maxTypeFields) { + throw new DeserializationException( + "Type metadata field count " + + fieldCount + + " exceeds maxTypeFields " + + maxTypeFields + + ". The data may be malicious. If the data is not malicious, please increase " + + "maxTypeFields."); + } + } + public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, long id) { Tuple2 decoded = decodeTypeDefBuf(buffer, resolver, id); MemoryBuffer typeDefBuf = MemoryBuffer.fromByteArray(decoded.f0); @@ -95,6 +118,8 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, ClassSpec classSpec = null; Class rootClass = null; boolean rootClassLayerRegistered = false; + int maxTypeFields = resolver.getConfig().maxTypeFields(); + long totalFields = 0; for (int i = 0; i < numClasses; i++) { // | num fields + register flag | header + package name | header + class name // | header + type id + field name | next field info | ... | @@ -104,6 +129,8 @@ public static TypeDef decodeTypeDef(ClassResolver resolver, MemoryBuffer buffer, } boolean isRegistered = (currentClassHeader & 0b1) != 0; int numFields = currentClassHeader >>> 1; + totalFields += numFields; + checkTypeMetaFieldCount(totalFields, maxTypeFields); Class currentClass = null; if (isRegistered) { int typeId = typeDefBuf.readUInt8(); diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java index 26865cb32e..949ded1e96 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDef.java @@ -321,6 +321,11 @@ public static TypeDef readTypeDef(TypeResolver resolver, MemoryBuffer buffer) { return NativeTypeDefDecoder.decodeTypeDef((ClassResolver) resolver, buffer, buffer.readInt64()); } + /** Decode an encoded class definition. */ + public static TypeDef readTypeDef(TypeResolver resolver, byte[] encoded) { + return readTypeDef(resolver, MemoryBuffer.fromByteArray(encoded)); + } + /** Read class definition from buffer. */ public static TypeDef readTypeDef(TypeResolver resolver, MemoryBuffer buffer, long header) { if (resolver.isCrossLanguage()) { @@ -329,6 +334,17 @@ public static TypeDef readTypeDef(TypeResolver resolver, MemoryBuffer buffer, lo return NativeTypeDefDecoder.decodeTypeDef((ClassResolver) resolver, buffer, header); } + /** Read encoded class definition bytes from buffer. */ + public static byte[] readTypeDefBytes(TypeResolver resolver, MemoryBuffer buffer, long header) { + if (resolver.isCrossLanguage()) { + if ((header & COMPRESS_META_FLAG) != 0) { + throw new DeserializationException("Compressed xlang TypeDef is not supported"); + } + return NativeTypeDefDecoder.decodeTypeDefBuf(buffer, resolver, header).f1; + } + return NativeTypeDefDecoder.decodeTypeDefBuf(buffer, resolver, header).f1; + } + /** * Consolidate fields of typeDef with cls. If some field exists in * cls but not in typeDef, it won't be returned in final collection. If diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java index 83ca2287bc..587e369193 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/TypeDefDecoder.java @@ -20,6 +20,7 @@ package org.apache.fory.meta; import static org.apache.fory.meta.Encoders.fieldNameEncodings; +import static org.apache.fory.meta.NativeTypeDefDecoder.checkTypeMetaFieldCount; import static org.apache.fory.meta.NativeTypeDefDecoder.decodeTypeDefBuf; import static org.apache.fory.meta.NativeTypeDefDecoder.readPkgName; import static org.apache.fory.meta.NativeTypeDefDecoder.readTypeName; @@ -82,6 +83,7 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu } numFields += extraFields; } + checkTypeMetaFieldCount(numFields, resolver.getConfig().maxTypeFields()); if (named) { String namespace = readPkgName(buffer); String typeName = readTypeName(buffer); @@ -90,7 +92,7 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu } TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName); if (userTypeInfo == null) { - classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, -1); + classSpec = unknownNamedClassSpec(namespace, typeName, typeId); } else { validateRegisteredTypeDefKind(userTypeInfo, typeId); classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeInfo.getUserTypeId()); @@ -116,7 +118,7 @@ public static TypeDef decodeTypeDef(XtypeResolver resolver, MemoryBuffer inputBu String typeName = readTypeName(buffer); TypeInfo userTypeInfo = resolver.getUserTypeInfo(namespace, typeName); if (userTypeInfo == null) { - classSpec = new ClassSpec(UnknownClass.UnknownStruct.class, typeId, -1); + classSpec = unknownNamedClassSpec(namespace, typeName, typeId); } else { validateRegisteredTypeDefKind(userTypeInfo, typeId); classSpec = new ClassSpec(userTypeInfo.getType(), typeId, userTypeInfo.getUserTypeId()); @@ -169,6 +171,13 @@ private static void validateRegisteredTypeDefKind(TypeInfo userTypeInfo, int typ } } + private static ClassSpec unknownNamedClassSpec(String namespace, String typeName, int typeId) { + String className = namespace.isEmpty() ? typeName : namespace + "." + typeName; + ClassSpec classSpec = new ClassSpec(className, Types.isEnumType(typeId), false, 0, typeId, -1); + classSpec.type = UnknownClass.UnknownStruct.class; + return classSpec; + } + private static boolean isStructCompatibilityVariant(int registeredTypeId, int typeId) { boolean registeredIdStruct = registeredTypeId == Types.STRUCT || registeredTypeId == Types.COMPATIBLE_STRUCT; diff --git a/java/fory-core/src/main/java/org/apache/fory/pool/ThreadPoolFory.java b/java/fory-core/src/main/java/org/apache/fory/pool/ThreadPoolFory.java index e3544aa106..c3aede7bf3 100644 --- a/java/fory-core/src/main/java/org/apache/fory/pool/ThreadPoolFory.java +++ b/java/fory-core/src/main/java/org/apache/fory/pool/ThreadPoolFory.java @@ -59,9 +59,19 @@ public ThreadPoolFory(Function foryFactory, int poolSize) { throw new IllegalArgumentException( String.format("thread safe fory pool size error, please check it, size:[%s]", poolSize)); } - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry[] sharedRegistry = new SharedRegistry[1]; Supplier factory = - () -> foryFactory.apply(Fory.builder().withSharedRegistry(sharedRegistry)); + () -> { + ForyBuilder builder = Fory.builder(); + if (sharedRegistry[0] != null) { + builder.withSharedRegistry(sharedRegistry[0]); + } + Fory fory = foryFactory.apply(builder); + if (sharedRegistry[0] == null) { + sharedRegistry[0] = fory.getTypeResolver().getSharedRegistry(); + } + return fory; + }; this.poolSize = poolSize; slots = new AtomicReferenceArray<>(poolSize); pooledFory = new Fory[poolSize]; diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java b/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java index a3f573afdd..f6e7170aab 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/SharedRegistry.java @@ -22,17 +22,19 @@ import java.lang.reflect.Member; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.IdentityHashMap; import java.util.List; import java.util.Objects; import java.util.SortedMap; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicInteger; import org.apache.fory.annotation.Internal; import org.apache.fory.codegen.CodeGenerator; import org.apache.fory.collection.BiMap; import org.apache.fory.collection.ConcurrentIdentityMap; import org.apache.fory.collection.Tuple2; +import org.apache.fory.config.Config; +import org.apache.fory.exception.ForyException; import org.apache.fory.meta.EncodedMetaString; import org.apache.fory.meta.Encoders; import org.apache.fory.meta.MetaString; @@ -57,7 +59,7 @@ public final class SharedRegistry { private static final int MAX_CACHED_ENCODED_META_STRINGS = 32768; private static final int MAX_CACHED_ENCODED_META_STRING_LENGTH = 2048; - private static final int MAX_CACHED_TYPE_DEFS = 8192; + private static final int MIN_REMOTE_STRUCT_SCHEMA_LIMIT = 8192; final ConcurrentIdentityMap, TypeDef> typeDefMap = new ConcurrentIdentityMap<>(); final ConcurrentIdentityMap, TypeDef> currentLayerTypeDef = @@ -90,10 +92,22 @@ public final class SharedRegistry { final StaticGeneratedSerializerRegistry staticGeneratedSerializerRegistry = new StaticGeneratedSerializerRegistry(); private final Object metaStringCacheLock = new Object(); - private final AtomicInteger cachedTypeDefCount = new AtomicInteger(); + private final Config config; + private final HashMap remoteSchemaVersionsByType = new HashMap<>(); + private int totalAcceptedSchemaVersions; volatile IdentityHashMap, Integer> registeredClassIdMap; volatile BiMap> registeredClasses; + public SharedRegistry(Config config) { + this.config = Objects.requireNonNull(config); + } + + public void checkConfig(Config config) { + if (!this.config.equals(config)) { + throw new IllegalArgumentException("SharedRegistry cannot be reused with different config"); + } + } + synchronized void setRegistrationIfAbsent( IdentityHashMap, Integer> candidateRegisteredClassIdMap, BiMap> candidateRegisteredClasses) { @@ -157,35 +171,67 @@ TypeDef getOrCreateTypeDef(TypeDef typeDef) { if (existing != null) { return existing; } - if (!reserveTypeDefCacheSlot()) { - return typeDef; + existing = typeDefById.putIfAbsent(id, typeDef); + return existing == null ? typeDef : existing; + } + + synchronized TypeDef getOrCreateRemoteTypeDef(TypeDef typeDef, Object structTypeKey) { + long id = typeDef.getId(); + TypeDef existing = typeDefById.get(id); + if (existing != null) { + return existing; } + int versionsForType = checkRemoteSchemaLimit(structTypeKey); existing = typeDefById.putIfAbsent(id, typeDef); if (existing != null) { - cachedTypeDefCount.decrementAndGet(); return existing; } + remoteSchemaVersionsByType.put(structTypeKey, versionsForType + 1); + totalAcceptedSchemaVersions++; return typeDef; } - private boolean reserveTypeDefCacheSlot() { - while (true) { - int count = cachedTypeDefCount.get(); - int mapSize = typeDefById.size(); - if (mapSize > count) { - if (cachedTypeDefCount.compareAndSet(count, mapSize)) { - count = mapSize; - } else { - continue; - } - } - if (count >= MAX_CACHED_TYPE_DEFS) { - return false; - } - if (cachedTypeDefCount.compareAndSet(count, count + 1)) { - return true; - } + synchronized void checkRemoteTypeDefLimit(TypeDef typeDef, Object structTypeKey) { + if (typeDefById.containsKey(typeDef.getId())) { + return; } + checkRemoteSchemaLimit(structTypeKey); + } + + private int checkRemoteSchemaLimit(Object structTypeKey) { + int versionsForType = remoteSchemaVersionsByType.getOrDefault(structTypeKey, 0); + int maxSchemaVersionsPerType = config.maxSchemaVersionsPerType(); + if (versionsForType >= maxSchemaVersionsPerType) { + throw new ForyException( + "Remote schema version limit exceeded for type " + + structTypeKey + + ": " + + versionsForType + + " >= " + + maxSchemaVersionsPerType + + ". Increase maxSchemaVersionsPerType if this peer legitimately sends many " + + "schema versions for one type."); + } + int acceptedStructTypeCount = + versionsForType == 0 + ? remoteSchemaVersionsByType.size() + 1 + : remoteSchemaVersionsByType.size(); + long globalLimit = + Math.max( + (long) MIN_REMOTE_STRUCT_SCHEMA_LIMIT, + (long) acceptedStructTypeCount * config.maxAverageSchemaVersionsPerType()); + if (totalAcceptedSchemaVersions >= globalLimit) { + throw new ForyException( + "Remote schema version limit exceeded: " + + totalAcceptedSchemaVersions + + " schemas for " + + acceptedStructTypeCount + + " accepted struct types exceeds the average limit " + + config.maxAverageSchemaVersionsPerType() + + ". Increase maxAverageSchemaVersionsPerType if this peer legitimately sends many " + + "schema versions across many types."); + } + return versionsForType; } EncodedMetaString getPackageEncodedMetaString(String string) { diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index 80c9f59982..e4b1ba8972 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -123,7 +123,6 @@ public abstract class TypeResolver { static final int INTERNAL_NATIVE_ID_LIMIT = 250; private static final GenericType OBJECT_GENERIC_TYPE = GenericType.build(Object.class); private static final float TYPE_ID_MAP_LOAD_FACTOR = 0.5f; - private static final int MAX_CACHED_TYPE_DEFS = 8192; static final long MAX_USER_TYPE_ID = 0xffff_fffEL; private static final class TransformedTypeInfo { @@ -162,6 +161,7 @@ protected TypeResolver( JITContext jitContext) { this.config = config; this.sharedRegistry = sharedRegistry; + sharedRegistry.checkConfig(config); this.jitContext = jitContext; metaContextShareEnabled = config.isMetaShareEnabled(); extRegistry = new ExtRegistry(classLoader, sharedRegistry); @@ -854,6 +854,11 @@ protected final TypeInfo readTypeInfoFromBytes( * ClassResolver and XtypeResolver. */ protected final TypeInfo readSharedClassMeta(ReadContext readContext) { + return readSharedClassMeta(readContext, null, false); + } + + private TypeInfo readSharedClassMeta( + ReadContext readContext, Class targetClass, boolean hasTargetClass) { MemoryBuffer buffer = readContext.getBuffer(); MetaReadContext metaReadContext = readContext.getMetaReadContext(); assert metaReadContext != null : SET_META_READ_CONTEXT_MSG; @@ -876,10 +881,22 @@ protected final TypeInfo readSharedClassMeta(ReadContext readContext) { TypeDef typeDef = sharedRegistry.typeDefById.get(id); if (typeDef != null) { TypeDef.skipTypeDef(buffer, id); + typeInfo = buildMetaSharedTypeInfo(typeDef); + } else if (hasTargetClass) { + byte[] encoded = TypeDef.readTypeDefBytes(this, buffer, id); + TypeDef localTypeDef = getTypeDef(targetClass, true); + if (Arrays.equals(encoded, localTypeDef.getEncoded())) { + // Exact local bytes only avoid remote schema counting/parsing. They still select the + // target class serializer, so the class must pass the active deserialization policy. + checkClassForDeserialization(targetClass); + typeInfo = getTypeInfo(targetClass); + } else { + typeInfo = buildCheckedMetaSharedTypeInfo(TypeDef.readTypeDef(this, encoded)); + } } else { - typeDef = readTypeDef(buffer, id); + typeDef = TypeDef.readTypeDef(this, buffer, id); + typeInfo = buildCheckedMetaSharedTypeInfo(typeDef); } - typeInfo = buildMetaSharedTypeInfo(typeDef); } // index == readTypeInfos.size() since types are written sequentially metaReadContext.readTypeInfos.add(typeInfo); @@ -888,7 +905,7 @@ protected final TypeInfo readSharedClassMeta(ReadContext readContext) { } public final TypeInfo readSharedClassMeta(ReadContext readContext, Class targetClass) { - TypeInfo typeInfo = readSharedClassMeta(readContext); + TypeInfo typeInfo = readSharedClassMeta(readContext, targetClass, true); Class readClass = typeInfo.getType(); // replace target class if needed if (targetClass != readClass) { @@ -941,9 +958,6 @@ private TypeInfo transformTypeInfo(TypeInfo typeInfo, Class targetClass, long } TransformedTypeInfo[] infos = extRegistry.transformedTypeInfo.get(targetClass); int size = infos == null ? 0 : infos.length; - if (size >= MAX_CACHED_TYPE_DEFS) { - return newTypeInfo; - } TransformedTypeInfo[] newInfos = new TransformedTypeInfo[size + 1]; if (size > 0) { System.arraycopy(infos, 0, newInfos, 0, size); @@ -1030,8 +1044,18 @@ final TypeInfo buildMetaSharedTypeInfo(TypeDef typeDef) { return typeInfo; } Class cls = loadClass(typeDef.getClassSpec()); - // A wire TypeDef may create a compatible serializer; admit the class before caching it by id. checkClassForDeserialization(cls); + return buildMetaSharedTypeInfo(typeDef, cls); + } + + private TypeInfo buildMetaSharedTypeInfo(TypeDef typeDef, Class cls) { + TypeInfo typeInfo = createMetaSharedTypeInfo(typeDef, cls); + extRegistry.typeInfoByTypeDefId.put(typeDef.getId(), typeInfo); + return typeInfo; + } + + private TypeInfo createMetaSharedTypeInfo(TypeDef typeDef, Class cls) { + TypeInfo typeInfo; if (!typeDef.isStructSchemaKind() && !UnknownClass.class.isAssignableFrom(TypeUtils.getComponentIfArray(cls))) { typeInfo = getTypeInfo(cls); @@ -1042,12 +1066,39 @@ final TypeInfo buildMetaSharedTypeInfo(TypeDef typeDef) { } else { typeInfo = getMetaSharedTypeInfo(typeDef, cls); } - if (extRegistry.typeInfoByTypeDefId.size < MAX_CACHED_TYPE_DEFS) { - extRegistry.typeInfoByTypeDefId.put(typeDef.getId(), typeInfo); - } return typeInfo; } + private TypeInfo buildCheckedMetaSharedTypeInfo(TypeDef typeDef) { + Class cls = loadClass(typeDef.getClassSpec()); + // A wire TypeDef may create a compatible serializer; check the class before counting it. + checkClassForDeserialization(cls); + TypeDef localTypeDef = exactLocalTypeDef(typeDef, cls); + if (localTypeDef != null) { + return buildMetaSharedTypeInfo(localTypeDef, cls); + } + if (!typeDef.isStructSchemaKind()) { + TypeDef cachedTypeDef = cacheTypeDef(typeDef); + return buildMetaSharedTypeInfo(cachedTypeDef, cls); + } + Object structTypeKey = remoteStructKey(typeDef); + sharedRegistry.checkRemoteTypeDefLimit(typeDef, structTypeKey); + TypeInfo typeInfo = createMetaSharedTypeInfo(typeDef, cls); + TypeDef cachedTypeDef = sharedRegistry.getOrCreateRemoteTypeDef(typeDef, structTypeKey); + if (cachedTypeDef != typeDef) { + return buildMetaSharedTypeInfo(cachedTypeDef, cls); + } + extRegistry.typeInfoByTypeDefId.put(typeDef.getId(), typeInfo); + return typeInfo; + } + + private TypeDef exactLocalTypeDef(TypeDef remoteTypeDef, Class cls) { + TypeDef localTypeDef = getTypeDef(cls, true); + return Arrays.equals(remoteTypeDef.getEncoded(), localTypeDef.getEncoded()) + ? localTypeDef + : null; + } + // TODO(chaokunyang) if TypeDef is consistent with class in this process, // use existing serializer instead. private TypeInfo getMetaSharedTypeInfo(TypeDef typeDef, Class clz) { @@ -1470,6 +1521,21 @@ public final TypeDef cacheTypeDef(TypeDef typeDef) { return sharedRegistry.getOrCreateTypeDef(typeDef); } + @Internal + public final TypeDef cacheRemoteTypeDef(TypeDef typeDef) { + if (!typeDef.isStructSchemaKind()) { + return cacheTypeDef(typeDef); + } + return sharedRegistry.getOrCreateRemoteTypeDef(typeDef, remoteStructKey(typeDef)); + } + + private static Object remoteStructKey(TypeDef typeDef) { + if (typeDef.isNamed()) { + return typeDef.getClassSpec().entireClassName; + } + return typeDef.getClassSpec().userTypeId; + } + public final boolean isSerializable(Class cls) { // Enums are always serializable, even if abstract (enums with abstract methods) if (cls.isEnum()) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java index c0444c3b6f..9550a1413f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ObjectStreamSerializer.java @@ -93,7 +93,6 @@ @SuppressWarnings({"unchecked", "rawtypes"}) public class ObjectStreamSerializer extends AbstractObjectSerializer { private static final Logger LOG = LoggerFactory.getLogger(ObjectStreamSerializer.class); - private static final int MAX_CACHED_TYPE_DEFS = 8192; private final SlotInfo[] slotsInfos; // Instance-level cache: TypeDef ID -> TypeInfo (shared across all slots). @@ -279,7 +278,7 @@ public Object read(ReadContext readContext) { ClassResolver classResolver = (ClassResolver) typeResolver; TreeMap callbacks = new TreeMap<>(Collections.reverseOrder()); for (int i = 0; i < numClasses; i++) { - // Matching layers are admitted by the registered root object type; requiring each + // Matching layers are accepted by the registered root object type; requiring each // serializable superclass to be registered would make normal ObjectStream hierarchy reads // unusable. Sender-only layers are checked below before their data is skipped. Class currentClass = classResolver.readClassInternalUnchecked(readContext); @@ -553,12 +552,20 @@ private TypeInfo readLayerTypeInfo( TypeDef.skipTypeDef(buffer, typeDefId); return typeInfo; } - TypeDef typeDef = - typeResolver.cacheTypeDef(TypeDef.readTypeDef(typeResolver, buffer, typeDefId)); - typeInfo = new TypeInfo(cls, typeDef); - if (typeDefIdToTypeInfo.size < MAX_CACHED_TYPE_DEFS) { - typeDefIdToTypeInfo.put(typeDefId, typeInfo); + byte[] encoded = TypeDef.readTypeDefBytes(typeResolver, buffer, typeDefId); + TypeDef localTypeDef = typeResolver.getTypeDef(cls, false); + TypeDef typeDef; + if (Arrays.equals(encoded, localTypeDef.getEncoded())) { + // Exact local bytes only avoid remote schema counting/parsing. They still select the + // target class serializer, so the class must pass the active deserialization policy. + typeResolver.checkClassForDeserialization(cls); + typeDef = localTypeDef; + } else { + typeResolver.checkClassForDeserialization(cls); + typeDef = typeResolver.cacheRemoteTypeDef(TypeDef.readTypeDef(typeResolver, encoded)); } + typeInfo = new TypeInfo(cls, typeDef); + typeDefIdToTypeInfo.put(typeDefId, typeInfo); return typeInfo; } diff --git a/java/fory-core/src/test/java/org/apache/fory/ForyInitPerf.java b/java/fory-core/src/test/java/org/apache/fory/ForyInitPerf.java index 9442d4122c..e4a2168816 100644 --- a/java/fory-core/src/test/java/org/apache/fory/ForyInitPerf.java +++ b/java/fory-core/src/test/java/org/apache/fory/ForyInitPerf.java @@ -23,6 +23,7 @@ import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; +import org.apache.fory.config.Config; import org.apache.fory.config.ForyBuilder; import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.MetaWriteContext; @@ -144,12 +145,15 @@ private static BenchmarkContext newBenchmarkContext() throws Exception { .withAsyncCompilation(false) .withCodegen(true); finishBuilder(builder); - return new BenchmarkContext(builder, benchmarkClassLoader(), new SharedRegistry()); + return new BenchmarkContext( + builder, benchmarkClassLoader(), new SharedRegistry(new Config(builder))); } private static Fory newFory(BenchmarkContext context, boolean useSharedRegistry) { SharedRegistry sharedRegistry = - useSharedRegistry ? context.sharedRegistry : new SharedRegistry(); + useSharedRegistry + ? context.sharedRegistry + : new SharedRegistry(new Config(context.builder)); return new Fory(context.builder, context.classLoader, sharedRegistry); } diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java index 2d6bbce13e..fc6875c9c5 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java @@ -227,6 +227,49 @@ public void testPrependHeader() { Assert.assertEquals(header & TypeDef.COMPRESS_META_FLAG, TypeDef.COMPRESS_META_FLAG); } + @Test + public void testDecodeRejectsTooManyFields() { + Fory writer = Fory.builder().withXlang(false).withMetaShare(true).withCompatible(false).build(); + TypeDef typeDef = + TypeDef.buildTypeDef(writer.getTypeResolver(), TypeDefEncoderTest.ManyFields.class); + Fory reader = + Fory.builder() + .withXlang(false) + .withMetaShare(true) + .withCompatible(false) + .withMaxTypeFields(31) + .build(); + + DeserializationException exception = + Assert.expectThrows( + DeserializationException.class, + () -> + TypeDef.readTypeDef( + reader.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded()))); + Assert.assertTrue(exception.getMessage().contains("maxTypeFields")); + } + + @Test + public void testDecodeRejectsTypeMetaBodySize() { + Fory writer = Fory.builder().withXlang(false).withMetaShare(true).withCompatible(false).build(); + TypeDef typeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), Foo1.class); + Fory reader = + Fory.builder() + .withXlang(false) + .withMetaShare(true) + .withCompatible(false) + .withMaxTypeMetaBytes(1) + .build(); + + DeserializationException exception = + Assert.expectThrows( + DeserializationException.class, + () -> + TypeDef.readTypeDef( + reader.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded()))); + Assert.assertTrue(exception.getMessage().contains("maxTypeMetaBytes")); + } + @Test public void testDecodeRejectsReservedGlobalBits() { Fory fory = Fory.builder().withXlang(false).withMetaShare(true).withCompatible(false).build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java index 613d50342f..3221d176a1 100644 --- a/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/TypeDefEncoderTest.java @@ -581,6 +581,50 @@ public void testExtendedFieldCountHeaderDoesNotSetRegisterByName() { Assert.assertEquals(decoded.getFieldsInfo().size(), 32); } + @Test + public void testDecodeRejectsTooManyFields() { + Fory writer = Fory.builder().withXlang(true).withCompatible(false).withMetaShare(true).build(); + writer.register(ManyFields.class, 6002); + TypeDef typeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), ManyFields.class); + Fory reader = + Fory.builder() + .withXlang(true) + .withCompatible(false) + .withMetaShare(true) + .withMaxTypeFields(31) + .build(); + + DeserializationException exception = + Assert.expectThrows( + DeserializationException.class, + () -> + TypeDef.readTypeDef( + reader.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded()))); + Assert.assertTrue(exception.getMessage().contains("maxTypeFields")); + } + + @Test + public void testDecodeRejectsTypeMetaBodySize() { + Fory writer = Fory.builder().withXlang(true).withCompatible(false).withMetaShare(true).build(); + writer.register(ClassWithNoAnnotations.class, 6003); + TypeDef typeDef = TypeDef.buildTypeDef(writer.getTypeResolver(), ClassWithNoAnnotations.class); + Fory reader = + Fory.builder() + .withXlang(true) + .withCompatible(false) + .withMetaShare(true) + .withMaxTypeMetaBytes(1) + .build(); + + DeserializationException exception = + Assert.expectThrows( + DeserializationException.class, + () -> + TypeDef.readTypeDef( + reader.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded()))); + Assert.assertTrue(exception.getMessage().contains("maxTypeMetaBytes")); + } + @Test public void testDecodeRejectsCompressedXlangTypeDef() { Fory fory = Fory.builder().withXlang(true).withCompatible(false).withMetaShare(true).build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java index 1dd0cbdde6..861fb373b3 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/ClassResolverTest.java @@ -50,9 +50,12 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.builder.Generated; +import org.apache.fory.config.Config; import org.apache.fory.config.ForyBuilder; +import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; +import org.apache.fory.exception.ForyException; import org.apache.fory.exception.InsecureException; import org.apache.fory.logging.Logger; import org.apache.fory.logging.LoggerFactory; @@ -70,6 +73,7 @@ import org.apache.fory.serializer.collection.CollectionSerializer; import org.apache.fory.serializer.collection.CollectionSerializers; import org.apache.fory.serializer.collection.MapSerializers; +import org.apache.fory.test.bean.BeanA; import org.apache.fory.test.bean.BeanB; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DescriptorGrouper; @@ -289,7 +293,7 @@ public void testSharedRegistrySharesTypeDefCachesAcrossForyInstances() { ForyBuilder builder = Fory.builder().withXlang(false).requireClassRegistration(false).withCompatible(false); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); @@ -317,7 +321,7 @@ public void testReadTypeDefPublishesValidatedTypeDefById() { .withCompatible(false) .withMetaShare(true); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); @@ -341,31 +345,93 @@ public void testReadTypeDefPublishesValidatedTypeDefById() { } @Test - public void testTypeDefHeaderCacheStopsAtMaxEntries() { + public void testExactLocalTypeDefChecksTypeChecker() { + Fory reader = + Fory.builder() + .withXlang(false) + .requireClassRegistration(false) + .withCompatible(false) + .withMetaShare(true) + .build(); + reader + .getTypeResolver() + .setTypeChecker((resolver, className) -> !className.equals(BeanB.class.getName())); + ClassResolver resolver = (ClassResolver) reader.getTypeResolver(); + TypeDef typeDef = resolver.getTypeDef(BeanB.class, true); + ReadContext readContext = reader.getReadContext(); + readContext.setMetaReadContext(new MetaReadContext()); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); + readContext.prepare(buffer, null, false); + buffer.writeVarUInt32(0); + typeDef.writeTypeDef(buffer); + buffer.readerIndex(0); + + Assert.assertThrows( + InsecureException.class, () -> resolver.readSharedClassMeta(readContext, BeanB.class)); + } + + @Test + public void testRemoteSchemaVersionLimitByType() { ForyBuilder builder = Fory.builder() .withXlang(false) .requireClassRegistration(false) .withCompatible(false) - .withMetaShare(true); + .withMetaShare(true) + .withMaxSchemaVersionsPerType(1); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); ClassResolver resolver = (ClassResolver) fory.getTypeResolver(); - TypeDef typeDef = TypeDef.buildTypeDef(resolver, BeanB.class); - int maxCachedTypeDefs = 8192; - for (long i = 0; i < maxCachedTypeDefs; i++) { - sharedRegistry.typeDefById.put(i, typeDef); - } + TypeDef first = TypeDef.buildTypeDef(resolver, BeanB.class); + TypeDef second = TypeDef.buildTypeDef(resolver, BeanA.class); - MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(256); - typeDef.writeTypeDef(buffer); - buffer.readerIndex(0); - TypeDef readTypeDef = resolver.readTypeDef(buffer, buffer.readInt64()); + assertSame(first, sharedRegistry.getOrCreateRemoteTypeDef(first, "remote.Type")); + assertSame(first, sharedRegistry.getOrCreateRemoteTypeDef(first, "remote.Type")); + Assert.assertThrows( + ForyException.class, () -> sharedRegistry.getOrCreateRemoteTypeDef(second, "remote.Type")); + } + + @Test + public void testRemoteSchemaVersionsUseRemoteTypeKey() { + ForyBuilder builder = + Fory.builder() + .withXlang(false) + .requireClassRegistration(false) + .withCompatible(false) + .withMetaShare(true) + .withMaxSchemaVersionsPerType(1); + finishBuilder(builder); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); + Fory fory = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); + ClassResolver resolver = (ClassResolver) fory.getTypeResolver(); + + TypeDef first = TypeDef.buildTypeDef(resolver, BeanB.class); + TypeDef second = TypeDef.buildTypeDef(resolver, BeanA.class); + + assertSame(first, sharedRegistry.getOrCreateRemoteTypeDef(first, "remote.UnknownA")); + assertSame(second, sharedRegistry.getOrCreateRemoteTypeDef(second, "remote.UnknownB")); + } + + @Test + public void testRemoteTypeDefCheckOnly() { + ForyBuilder builder = + Fory.builder() + .withXlang(false) + .requireClassRegistration(false) + .withCompatible(false) + .withMetaShare(true) + .withMaxSchemaVersionsPerType(1); + finishBuilder(builder); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); + Fory fory = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); + ClassResolver resolver = (ClassResolver) fory.getTypeResolver(); + + TypeDef checked = TypeDef.buildTypeDef(resolver, BeanB.class); + TypeDef accepted = TypeDef.buildTypeDef(resolver, BeanA.class); - assertNotNull(readTypeDef); - assertNull(sharedRegistry.typeDefById.get(typeDef.getId())); - assertEquals(sharedRegistry.typeDefById.size(), maxCachedTypeDefs); + sharedRegistry.checkRemoteTypeDefLimit(checked, "remote.Type"); + assertSame(accepted, sharedRegistry.getOrCreateRemoteTypeDef(accepted, "remote.Type")); } @Test @@ -373,7 +439,7 @@ public void testSharedRegistryCachesFieldDescriptorsAndDescriptorGrouper() { ForyBuilder builder = Fory.builder().withXlang(false).requireClassRegistration(false).withCompatible(false); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); @@ -409,7 +475,7 @@ public void testSharedRegistryCachesTypeDefDescriptorsAndDescriptorGrouperBySema .requireClassRegistration(false) .withCompatible(false); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); @@ -444,7 +510,7 @@ public void testRegisterNamedClassCachesOnlyNamespaceAndTypeName() { ForyBuilder builder = Fory.builder().withXlang(false).requireClassRegistration(true).withCompatible(false); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); ClassResolver classResolver = (ClassResolver) fory.getTypeResolver(); @@ -459,7 +525,7 @@ public void testFinishRegisterPublishesAndAdoptsSharedRegistration() { ForyBuilder builder = Fory.builder().withXlang(false).requireClassRegistration(true).withCompatible(false); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); @@ -804,7 +870,7 @@ public void testShareableSerializerSharedAcrossRuntimes() { ForyBuilder builder = Fory.builder().withXlang(false).requireClassRegistration(true).withCompatible(false); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, ClassResolverTest.class.getClassLoader(), sharedRegistry); diff --git a/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java b/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java index ea1bb16faa..54299d0add 100644 --- a/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/resolver/MetaStringIOTest.java @@ -27,6 +27,7 @@ import static org.testng.Assert.expectThrows; import java.nio.ByteBuffer; +import org.apache.fory.Fory; import org.apache.fory.TestUtils; import org.apache.fory.collection.LongLongByteMap; import org.apache.fory.context.MetaStringReader; @@ -40,9 +41,13 @@ import org.testng.annotations.Test; public class MetaStringIOTest { + private static SharedRegistry newSharedRegistry() { + return new SharedRegistry(Fory.builder().build().getConfig()); + } + @Test public void testWriteMetaString() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringWriter writer = new MetaStringWriter(); MetaStringReader reader = new MetaStringReader(sharedRegistry); MemoryBuffer buffer = MemoryUtils.buffer(32); @@ -67,7 +72,7 @@ public void testWriteSmallMetaString() { }) { for (int i = 0; i < 32; i++) { String str = StringUtils.random(i, 0); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringWriter writer = new MetaStringWriter(); MetaStringReader reader = new MetaStringReader(sharedRegistry); writer.writeMetaString(buffer, newGenericMetaString(str)); @@ -82,7 +87,7 @@ public void testWriteSmallMetaString() { @Test public void testMetaStringWriterResetClearsDynamicWriteState() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringWriter writer = new MetaStringWriter(); MetaStringReader reader = new MetaStringReader(sharedRegistry); EncodedMetaString metaString = newGenericMetaString("thread_safe_fory"); @@ -101,7 +106,7 @@ public void testMetaStringWriterResetClearsDynamicWriteState() { @Test public void testMetaStringReaderUsesSharedRegistryInstances() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringWriter writer = new MetaStringWriter(); MetaStringReader reader = new MetaStringReader(sharedRegistry); EncodedMetaString encodedMetaString = newGenericMetaString("shared_meta_string"); @@ -119,7 +124,7 @@ public void testMetaStringReaderUsesSharedRegistryInstances() { @Test public void testSharedRegistrySkipsLongEncodedMetaStrings() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); String str = StringUtils.random(2050, 0); EncodedMetaString first = newGenericMetaString(sharedRegistry, str); @@ -130,7 +135,7 @@ public void testSharedRegistrySkipsLongEncodedMetaStrings() { @Test public void testSharedRegistryCapsEncodedMetaStringCount() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); EncodedMetaString first = null; for (int i = 0; i < 32768; i++) { EncodedMetaString encodedMetaString = newGenericMetaString(sharedRegistry, "meta_" + i); @@ -148,7 +153,7 @@ public void testSharedRegistryCapsEncodedMetaStringCount() { @Test public void testReadBigMetaStringRejectsNonCanonicalHash() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringReader reader = new MetaStringReader(sharedRegistry); EncodedMetaString encodedMetaString = newGenericMetaString(StringUtils.random(32, 0)); MemoryBuffer buffer = MemoryUtils.buffer(64); @@ -162,7 +167,7 @@ public void testReadBigMetaStringRejectsNonCanonicalHash() { @Test public void testCachedBigMetaStringReusesHeaderCache() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringReader reader = new MetaStringReader(sharedRegistry); EncodedMetaString encodedMetaString = newGenericMetaString(StringUtils.random(32, 0)); MemoryBuffer buffer = MemoryUtils.buffer(128); @@ -179,7 +184,7 @@ public void testCachedBigMetaStringReusesHeaderCache() { @Test public void testReadSmallMetaStringKeyIncludesLengthAndEncoding() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringReader reader = new MetaStringReader(sharedRegistry); MemoryBuffer buffer = MemoryUtils.buffer(32); @@ -201,7 +206,7 @@ public void testReadSmallMetaStringKeyIncludesLengthAndEncoding() { @Test public void testMetaStringReaderResetClearsDynamicIdsOnly() { - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = newSharedRegistry(); MetaStringReader reader = new MetaStringReader(sharedRegistry); MemoryBuffer buffer = MemoryUtils.buffer(32); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java index 933bd07c51..48f7f53145 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ObjectStreamSerializerTest.java @@ -45,6 +45,7 @@ import lombok.EqualsAndHashCode; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.config.Config; import org.apache.fory.config.ForyBuilder; import org.apache.fory.context.MetaReadContext; import org.apache.fory.context.MetaWriteContext; @@ -1347,7 +1348,7 @@ public void testObjectStreamReadersReuseValidatedTypeDefCache(boolean compatible .withCompatible(compatible) .withMetaShare(true); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory writerFory = new Fory(builder, ObjectStreamSerializerTest.class.getClassLoader(), sharedRegistry); Fory readerFory1 = @@ -1378,6 +1379,39 @@ public void testObjectStreamReadersReuseValidatedTypeDefCache(boolean compatible assertEquals(result2.value, 7); } + @Test(dataProvider = "compatibleModeProvider") + public void testObjectStreamExactLocalTypeDefChecksTypeChecker(boolean compatible) { + ForyBuilder builder = + Fory.builder() + .withXlang(false) + .requireClassRegistration(false) + .withRefTracking(true) + .withCompatible(compatible) + .withMetaShare(true); + finishBuilder(builder); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); + Fory writerFory = + new Fory(builder, ObjectStreamSerializerTest.class.getClassLoader(), sharedRegistry); + Fory readerFory = + new Fory(builder, ObjectStreamSerializerTest.class.getClassLoader(), sharedRegistry); + writerFory.registerSerializer( + MixedSerializationClass.class, + new ObjectStreamSerializer(writerFory.getTypeResolver(), MixedSerializationClass.class)); + readerFory.registerSerializer( + MixedSerializationClass.class, + new ObjectStreamSerializer(readerFory.getTypeResolver(), MixedSerializationClass.class)); + + writerFory.setMetaWriteContext(new MetaWriteContext()); + byte[] bytes = writerFory.serialize(new MixedSerializationClass("blocked", 11)); + readerFory + .getTypeResolver() + .setTypeChecker( + (resolver, className) -> !className.equals(MixedSerializationClass.class.getName())); + readerFory.setMetaReadContext(new MetaReadContext()); + + Assert.assertThrows(InsecureException.class, () -> readerFory.deserialize(bytes)); + } + // ==================== Default Value Tests ==================== /** Class to test default values when fields are missing. */ diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/RegisterTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/RegisterTest.java index c27e2d20d6..8f3b057b65 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/RegisterTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/RegisterTest.java @@ -24,6 +24,7 @@ import lombok.Data; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.config.Config; import org.apache.fory.config.ForyBuilder; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; @@ -155,7 +156,7 @@ public void testShareableSerializerOverrideStaysLocal() { ForyBuilder builder = Fory.builder().withXlang(true).withCompatible(false).requireClassRegistration(true); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, RegisterTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, RegisterTest.class.getClassLoader(), sharedRegistry); @@ -176,7 +177,7 @@ public void testShareableInternalSerializerSharedAcrossRuntimes() { ForyBuilder builder = Fory.builder().withXlang(true).withCompatible(false).requireClassRegistration(true); finishBuilder(builder); - SharedRegistry sharedRegistry = new SharedRegistry(); + SharedRegistry sharedRegistry = new SharedRegistry(new Config(builder)); Fory fory1 = new Fory(builder, RegisterTest.class.getClassLoader(), sharedRegistry); Fory fory2 = new Fory(builder, RegisterTest.class.getClassLoader(), sharedRegistry); diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index 466d50ece1..8252711a92 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -41,6 +41,7 @@ type TypeResolverLike = { computeTypeId(typeInfo: TypeInfo): number; getSerializerById(id: number, userTypeId?: number): Serializer | undefined; getSerializerByName(name: string): Serializer | undefined; + getSerializerByHash?(hash: number): Serializer | undefined; getSerializerByData(value: any): Serializer | null | undefined; isCompatible(): boolean; generateReadSerializer(typeInfo: TypeInfo): Serializer; @@ -530,8 +531,7 @@ export class WriteContext { } export class ReadContext { - private static readonly MAX_CACHED_TYPE_META = 8192; - private static readonly MAX_CACHED_COMPATIBLE_READ_SERIALIZER = 8192; + private static readonly MIN_REMOTE_STRUCT_SCHEMA_LIMIT = 8192; readonly reader: BinaryReader; readonly refReader: RefReader; @@ -540,13 +540,14 @@ export class ReadContext { private typeMeta: TypeMeta[] = []; /** Persistent cross-message cache keyed by 8-byte type meta header. */ private typeMetaCache: Map> = new Map(); - private typeMetaCacheSize = 0; private lastTypeMetaHeaderLow = -1; private lastTypeMetaHeaderHigh = -1; private lastTypeMeta: TypeMeta | null = null; private recentTypeMetaHeaderLows = [-1, -1, -1, -1]; private recentTypeMetaHeaderHighs = [-1, -1, -1, -1]; private recentTypeMetas: Array = [null, null, null, null]; + private remoteSchemaVersionsByType = new Map(); + private totalAcceptedSchemaVersions = 0; private compatibleReadSerializers = new Map< number, CompatibleReadSerializerCacheEntry @@ -554,6 +555,7 @@ export class ReadContext { private _depth = 0; private _maxDepth: number; + private readonly config: Config; private static typeMetaHeaderHash(headerLow: number, headerHigh: number) { return headerHigh * 0x100000 + (headerLow >>> 12); @@ -609,6 +611,7 @@ export class ReadContext { this.refReader = new RefReader(this.reader); this.metaStringReader = new MetaStringReader(); this._maxDepth = config.maxDepth ?? 50; + this.config = config; } reset(bytes: Uint8Array) { @@ -649,11 +652,11 @@ export class ReadContext { this.refReader.reference(object); } - private readTypeMetaFromHeader( + private readCachedTypeMetaFromHeader( dynamicTypeId: number, headerLow: number, headerHigh: number, - ): TypeMeta { + ): TypeMeta | null { if ( this.lastTypeMeta !== null && this.lastTypeMetaHeaderLow === headerLow @@ -675,39 +678,160 @@ export class ReadContext { } const cached = this.typeMetaCache.get(headerHigh)?.get(headerLow); - let typeMeta: TypeMeta; if (cached) { // Header-cache hits intentionally skip without rehashing. Entries reach this cache only // after a successful TypeMeta parse and 52-bit metadata-hash validation. The current body // size still comes from the current header bytes, not from the cached TypeMeta. TypeMeta.skipBodyByHeaderLow(this.reader, headerLow); - typeMeta = cached; this.lastTypeMetaHeaderLow = headerLow; this.lastTypeMetaHeaderHigh = headerHigh; - this.lastTypeMeta = typeMeta; - this.rememberRecentTypeMeta(headerLow, headerHigh, typeMeta); - } else { - const header = (BigInt(headerHigh) << 32n) | BigInt(headerLow); - typeMeta = TypeMeta.fromBytesAfterHeader(this.reader, header); - if (this.typeMetaCacheSize < ReadContext.MAX_CACHED_TYPE_META) { - let highCache = this.typeMetaCache.get(headerHigh); - if (highCache === undefined) { - highCache = new Map(); - this.typeMetaCache.set(headerHigh, highCache); - } - highCache.set(headerLow, typeMeta); - this.typeMetaCacheSize++; - this.lastTypeMetaHeaderLow = headerLow; - this.lastTypeMetaHeaderHigh = headerHigh; - this.lastTypeMeta = typeMeta; - this.rememberRecentTypeMeta(headerLow, headerHigh, typeMeta); + this.lastTypeMeta = cached; + this.rememberRecentTypeMeta(headerLow, headerHigh, cached); + this.typeMeta[dynamicTypeId] = cached; + return cached; + } + return null; + } + + private checkRemoteStructSchemaLimit(typeMeta: TypeMeta) { + if (!TypeId.structType(typeMeta.getTypeId())) { + return undefined; + } + const typeKey = TypeId.isNamedType(typeMeta.getTypeId()) + ? `${typeMeta.getNs()}\u0000${typeMeta.getTypeName()}` + : typeMeta.getUserTypeId(); + const versionsForType = this.remoteSchemaVersionsByType.get(typeKey) ?? 0; + const maxSchemaVersionsPerType = this.config.maxSchemaVersionsPerType; + if (versionsForType >= maxSchemaVersionsPerType) { + throw new Error( + `Remote schema version limit exceeded for type ${String(typeKey)}: ` + + `${versionsForType} >= ${maxSchemaVersionsPerType}. Increase ` + + "maxSchemaVersionsPerType if this peer legitimately sends many " + + "schema versions for one type.", + ); + } + const acceptedStructTypeCount + = versionsForType === 0 + ? this.remoteSchemaVersionsByType.size + 1 + : this.remoteSchemaVersionsByType.size; + const maxAverageSchemaVersionsPerType + = this.config.maxAverageSchemaVersionsPerType; + const globalLimit = Math.max( + ReadContext.MIN_REMOTE_STRUCT_SCHEMA_LIMIT, + acceptedStructTypeCount * maxAverageSchemaVersionsPerType, + ); + if (this.totalAcceptedSchemaVersions >= globalLimit) { + throw new Error( + `Remote schema version limit exceeded: ${this.totalAcceptedSchemaVersions} ` + + `schemas for ${acceptedStructTypeCount} accepted struct types exceeds ` + + `the average limit ${maxAverageSchemaVersionsPerType}. Increase ` + + "maxAverageSchemaVersionsPerType if this peer legitimately sends many " + + "schema versions across many types.", + ); + } + return typeKey; + } + + private cacheTypeMeta( + headerLow: number, + headerHigh: number, + typeMeta: TypeMeta, + typeKey: string | number | undefined, + ) { + let highCache = this.typeMetaCache.get(headerHigh); + if (highCache === undefined) { + highCache = new Map(); + this.typeMetaCache.set(headerHigh, highCache); + } + highCache.set(headerLow, typeMeta); + this.lastTypeMetaHeaderLow = headerLow; + this.lastTypeMetaHeaderHigh = headerHigh; + this.lastTypeMeta = typeMeta; + this.rememberRecentTypeMeta(headerLow, headerHigh, typeMeta); + if (typeKey !== undefined) { + const versionsForType = this.remoteSchemaVersionsByType.get(typeKey) ?? 0; + this.remoteSchemaVersionsByType.set(typeKey, versionsForType + 1); + this.totalAcceptedSchemaVersions++; + } + } + + private exactTypeMetaForSerializer( + start: number, + end: number, + serializer: Serializer | undefined | null, + ): TypeMeta | undefined { + if (!serializer || typeof serializer.getTypeInfo !== "function") { + return undefined; + } + const localTypeMeta = TypeMeta.fromTypeInfo( + serializer.getTypeInfo(), + this.typeResolver, + ); + const localBytes = localTypeMeta.toBytes(); + if (end - start !== localBytes.length) { + return undefined; + } + const remoteBytes = this.reader.bufferRefAt(start, localBytes.length); + for (let i = 0; i < localBytes.length; i++) { + if (remoteBytes[i] !== localBytes[i]) { + return undefined; } } + return localTypeMeta; + } + + private readTypeMetaFromHeader( + dynamicTypeId: number, + headerLow: number, + headerHigh: number, + exactLocal?: Serializer, + ): TypeMeta { + const cached = this.readCachedTypeMetaFromHeader( + dynamicTypeId, + headerLow, + headerHigh, + ); + if (cached !== null) { + return cached; + } + + const header = (BigInt(headerHigh) << 32n) | BigInt(headerLow); + if ( + exactLocal !== undefined + && typeof exactLocal.getTypeInfo === "function" + ) { + const localTypeMeta = TypeMeta.fromTypeInfo( + exactLocal.getTypeInfo(), + this.typeResolver, + ); + if ( + TypeMeta.matchesEncodedAfterHeader( + this.reader, + headerLow, + headerHigh, + localTypeMeta.toBytes(), + this.config.maxTypeMetaBytes, + ) + ) { + this.cacheTypeMeta(headerLow, headerHigh, localTypeMeta, undefined); + this.typeMeta[dynamicTypeId] = localTypeMeta; + return localTypeMeta; + } + } + + const typeMeta = TypeMeta.fromBytesAfterHeader( + this.reader, + header, + this.config.maxTypeFields, + this.config.maxTypeMetaBytes, + ); + const typeKey = this.checkRemoteStructSchemaLimit(typeMeta); + this.cacheTypeMeta(headerLow, headerHigh, typeMeta, typeKey); this.typeMeta[dynamicTypeId] = typeMeta; return typeMeta; } - readTypeMeta(): TypeMeta { + readTypeMeta(exactLocal?: Serializer): TypeMeta { const idOrLen = this.reader.readVarUInt32(); if (idOrLen & 1) { const typeMeta = this.typeMeta[idOrLen >> 1]; @@ -718,7 +842,12 @@ export class ReadContext { } const headerLow = this.reader.readUint32(); const headerHigh = this.reader.readUint32(); - return this.readTypeMetaFromHeader(idOrLen >> 1, headerLow, headerHigh); + return this.readTypeMetaFromHeader( + idOrLen >> 1, + headerLow, + headerHigh, + exactLocal, + ); } readTypeMetaIfSchemaChanged( @@ -735,14 +864,26 @@ export class ReadContext { } remoteHash = typeMeta.getHash(); } else { + const dynamicTypeId = idOrLen >> 1; const headerLow = this.reader.readUint32(); const headerHigh = this.reader.readUint32(); - typeMeta = this.readTypeMetaFromHeader( - idOrLen >> 1, + remoteHash = ReadContext.typeMetaHeaderHash(headerLow, headerHigh); + const cached = this.readCachedTypeMetaFromHeader( + dynamicTypeId, headerLow, headerHigh, ); - remoteHash = ReadContext.typeMetaHeaderHash(headerLow, headerHigh); + if (cached === null) { + return this.readChangedTypeMetaMiss( + expectedHash, + original, + dynamicTypeId, + headerLow, + headerHigh, + remoteHash, + ); + } + typeMeta = cached; } if (expectedHash !== remoteHash) { const cached = this.compatibleReadSerializers.get(remoteHash); @@ -753,20 +894,245 @@ export class ReadContext { typeMeta, original, ); + this.compatibleReadSerializers.set(remoteHash, { + localHash: expectedHash, + serializer, + }); + return serializer; + } + return undefined; + } + + private readChangedTypeMetaMiss( + expectedHash: number, + original: Serializer | undefined, + dynamicTypeId: number, + headerLow: number, + headerHigh: number, + remoteHash: number, + ): Serializer | undefined { + let typeMeta: TypeMeta; + let typeKey: string | number | undefined; + let shouldCache = false; + const exactLocal = original + ?? this.typeResolver.getSerializerByHash?.(expectedHash); + if ( + exactLocal !== undefined + && typeof exactLocal.getTypeInfo === "function" + ) { + const localTypeMeta = TypeMeta.fromTypeInfo( + exactLocal.getTypeInfo(), + this.typeResolver, + ); if ( - this.compatibleReadSerializers.size - < ReadContext.MAX_CACHED_COMPATIBLE_READ_SERIALIZER + TypeMeta.matchesEncodedAfterHeader( + this.reader, + headerLow, + headerHigh, + localTypeMeta.toBytes(), + this.config.maxTypeMetaBytes, + ) ) { - this.compatibleReadSerializers.set(remoteHash, { - localHash: expectedHash, - serializer, - }); + this.cacheTypeMeta(headerLow, headerHigh, localTypeMeta, undefined); + this.typeMeta[dynamicTypeId] = localTypeMeta; + typeMeta = localTypeMeta; + } else { + const header = (BigInt(headerHigh) << 32n) | BigInt(headerLow); + typeMeta = TypeMeta.fromBytesAfterHeader( + this.reader, + header, + this.config.maxTypeFields, + this.config.maxTypeMetaBytes, + ); + typeKey = this.checkRemoteStructSchemaLimit(typeMeta); + shouldCache = true; + } + } else { + const header = (BigInt(headerHigh) << 32n) | BigInt(headerLow); + typeMeta = TypeMeta.fromBytesAfterHeader( + this.reader, + header, + this.config.maxTypeFields, + this.config.maxTypeMetaBytes, + ); + typeKey = this.checkRemoteStructSchemaLimit(typeMeta); + shouldCache = true; + } + if (expectedHash !== remoteHash) { + const cached = this.compatibleReadSerializers.get(remoteHash); + if (cached !== undefined && cached.localHash === expectedHash) { + if (shouldCache) { + this.cacheTypeMeta(headerLow, headerHigh, typeMeta, typeKey); + this.typeMeta[dynamicTypeId] = typeMeta; + } + return cached.serializer; + } + const serializer = this.genSerializerByTypeMetaRuntime( + typeMeta, + original, + ); + this.compatibleReadSerializers.set(remoteHash, { + localHash: expectedHash, + serializer, + }); + if (shouldCache) { + this.cacheTypeMeta(headerLow, headerHigh, typeMeta, typeKey); + this.typeMeta[dynamicTypeId] = typeMeta; } return serializer; } + if (shouldCache) { + this.cacheTypeMeta(headerLow, headerHigh, typeMeta, typeKey); + this.typeMeta[dynamicTypeId] = typeMeta; + } return undefined; } + private buildNamedTypeKey(ns: string, typeName: string) { + return `${ns}$${typeName}`; + } + + private serializerByTypeMeta(typeId: number, typeMeta: TypeMeta) { + if (typeId === TypeId.COMPATIBLE_STRUCT) { + return this.typeResolver.getSerializerById(typeId, typeMeta.getUserTypeId()); + } + return this.typeResolver.getSerializerByName( + this.buildNamedTypeKey(typeMeta.getNs(), typeMeta.getTypeName()), + ); + } + + private updateAnySerializer( + serializer: Serializer | undefined | null, + typeMeta: TypeMeta, + ) { + if (!serializer) { + return this.genSerializerByTypeMetaRuntime(typeMeta); + } + const hash = serializer.getHash(); + if (hash !== typeMeta.getHash()) { + return this.genSerializerByTypeMetaRuntime(typeMeta, serializer); + } + return serializer; + } + + private readAnyTypeMetaSerializer(typeId: number, updateStruct: boolean) { + const idOrLen = this.reader.readVarUInt32(); + let typeMeta: TypeMeta; + if (idOrLen & 1) { + typeMeta = this.typeMeta[idOrLen >> 1]; + if (!typeMeta) { + throw new Error(`missing TypeMeta reference ${idOrLen >> 1}`); + } + const serializer = this.serializerByTypeMeta(typeId, typeMeta); + return updateStruct + ? this.updateAnySerializer(serializer, typeMeta) + : serializer; + } + + const dynamicTypeId = idOrLen >> 1; + const typeMetaStart = this.reader.readGetCursor(); + const headerLow = this.reader.readUint32(); + const headerHigh = this.reader.readUint32(); + const cached = this.readCachedTypeMetaFromHeader( + dynamicTypeId, + headerLow, + headerHigh, + ); + if (cached !== null) { + const serializer = this.serializerByTypeMeta(typeId, cached); + return updateStruct + ? this.updateAnySerializer(serializer, cached) + : serializer; + } + + const header = (BigInt(headerHigh) << 32n) | BigInt(headerLow); + typeMeta = TypeMeta.fromBytesAfterHeader( + this.reader, + header, + this.config.maxTypeFields, + this.config.maxTypeMetaBytes, + ); + const typeMetaEnd = this.reader.readGetCursor(); + const serializer = this.serializerByTypeMeta(typeId, typeMeta); + const localTypeMeta = this.exactTypeMetaForSerializer( + typeMetaStart, + typeMetaEnd, + serializer, + ); + if (localTypeMeta !== undefined) { + this.cacheTypeMeta(headerLow, headerHigh, localTypeMeta, undefined); + this.typeMeta[dynamicTypeId] = localTypeMeta; + return serializer!; + } + + const typeKey = this.checkRemoteStructSchemaLimit(typeMeta); + const resolved = updateStruct + ? this.updateAnySerializer(serializer, typeMeta) + : serializer; + if (!resolved) { + return undefined; + } + this.cacheTypeMeta(headerLow, headerHigh, typeMeta, typeKey); + this.typeMeta[dynamicTypeId] = typeMeta; + return resolved; + } + + detectAnySerializer() { + const typeId = this.reader.readUint8(); + let userTypeId = -1; + if (TypeId.needsUserTypeId(typeId) && typeId !== TypeId.COMPATIBLE_STRUCT) { + userTypeId = this.reader.readVarUint32Small7(); + } + let serializer: Serializer | undefined | null; + switch (typeId) { + case TypeId.COMPATIBLE_STRUCT: + serializer = this.readAnyTypeMetaSerializer(typeId, true); + break; + case TypeId.NAMED_ENUM: + case TypeId.NAMED_UNION: + if (this.isCompatible()) { + serializer = this.readAnyTypeMetaSerializer(typeId, false); + } else { + const ns = this.readNamespace(); + const typeName = this.readTypeName(); + serializer = this.typeResolver.getSerializerByName( + this.buildNamedTypeKey(ns, typeName), + ); + } + break; + case TypeId.NAMED_EXT: + if (this.isCompatible()) { + serializer = this.readAnyTypeMetaSerializer(typeId, false); + } else { + const ns = this.readNamespace(); + const typeName = this.readTypeName(); + serializer = this.typeResolver.getSerializerByName( + this.buildNamedTypeKey(ns, typeName), + ); + } + break; + case TypeId.NAMED_STRUCT: + case TypeId.NAMED_COMPATIBLE_STRUCT: + if (this.isCompatible() || typeId === TypeId.NAMED_COMPATIBLE_STRUCT) { + serializer = this.readAnyTypeMetaSerializer(typeId, true); + } else { + const ns = this.readNamespace(); + const typeName = this.readTypeName(); + serializer = this.typeResolver.getSerializerByName( + this.buildNamedTypeKey(ns, typeName), + ); + } + break; + default: + serializer = this.typeResolver.getSerializerById(typeId, userTypeId); + break; + } + if (!serializer) { + throw new Error(`can't find implements of typeId: ${typeId}`); + } + return serializer; + } + private canonicalTypeId(typeId: number): number { switch (typeId) { case TypeId.NAMED_ENUM: diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index 3c122d790b..05b424b1cc 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -34,6 +34,10 @@ import { ReadContext, WriteContext } from "./context"; const DEFAULT_DEPTH_LIMIT = 50 as const; const MIN_DEPTH_LIMIT = 2 as const; +const DEFAULT_MAX_TYPE_FIELDS = 512 as const; +const DEFAULT_MAX_TYPE_META_BYTES = 4096 as const; +const DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE = 10 as const; +const DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE = 3 as const; export default class Fory { readonly typeResolver: TypeResolver; readonly anySerializer: Serializer; @@ -67,10 +71,48 @@ export default class Fory { } private initConfig(config: Partial | undefined) { + const maxTypeFields = config?.maxTypeFields ?? DEFAULT_MAX_TYPE_FIELDS; + if (!Number.isInteger(maxTypeFields) || maxTypeFields <= 0) { + throw new Error( + `maxTypeFields must be a positive integer but got ${maxTypeFields}`, + ); + } + const maxTypeMetaBytes + = config?.maxTypeMetaBytes ?? DEFAULT_MAX_TYPE_META_BYTES; + if (!Number.isInteger(maxTypeMetaBytes) || maxTypeMetaBytes <= 0) { + throw new Error( + `maxTypeMetaBytes must be a positive integer but got ${maxTypeMetaBytes}`, + ); + } + const maxSchemaVersionsPerType + = config?.maxSchemaVersionsPerType ?? DEFAULT_MAX_SCHEMA_VERSIONS_PER_TYPE; + if ( + !Number.isInteger(maxSchemaVersionsPerType) + || maxSchemaVersionsPerType <= 0 + ) { + throw new Error( + `maxSchemaVersionsPerType must be a positive integer but got ${maxSchemaVersionsPerType}`, + ); + } + const maxAverageSchemaVersionsPerType + = config?.maxAverageSchemaVersionsPerType + ?? DEFAULT_MAX_AVERAGE_SCHEMA_VERSIONS_PER_TYPE; + if ( + !Number.isInteger(maxAverageSchemaVersionsPerType) + || maxAverageSchemaVersionsPerType <= 0 + ) { + throw new Error( + `maxAverageSchemaVersionsPerType must be a positive integer but got ${maxAverageSchemaVersionsPerType}`, + ); + } return { ref: Boolean(config?.ref), useSliceString: Boolean(config?.useSliceString), maxDepth: config?.maxDepth, + maxTypeFields, + maxTypeMetaBytes, + maxSchemaVersionsPerType, + maxAverageSchemaVersionsPerType, hooks: config?.hooks || {}, compatible: config?.compatible ?? true, hps: config?.hps, diff --git a/javascript/packages/core/lib/gen/any.ts b/javascript/packages/core/lib/gen/any.ts index 2486e93f3f..806a717e3e 100644 --- a/javascript/packages/core/lib/gen/any.ts +++ b/javascript/packages/core/lib/gen/any.ts @@ -21,93 +21,13 @@ import { TypeInfo } from "../typeInfo"; import { CodecBuilder } from "./builder"; import { BaseSerializerGenerator } from "./serializer"; import { CodegenRegistry } from "./router"; -import { Serializer, TypeId } from "../type"; +import { TypeId } from "../type"; import { Scope } from "./scope"; -import { TypeMeta } from "../meta/TypeMeta"; import { ReadContext, WriteContext } from "../context"; export class AnyHelper { static detectSerializer(readContext: ReadContext) { - const reader = readContext.reader; - const typeResolver = readContext.typeResolver; - const typeId = reader.readUint8(); - let userTypeId = -1; - if (TypeId.needsUserTypeId(typeId) && typeId !== TypeId.COMPATIBLE_STRUCT) { - userTypeId = reader.readVarUint32Small7(); - } - let serializer: Serializer | undefined; - - function buildNamedTypeKey(ns: string, typeName: string) { - return `${ns}$${typeName}`; - } - - function tryUpdateSerializer(serializer: Serializer | undefined | null, typeMeta: TypeMeta) { - if (!serializer) { - return readContext.genSerializerByTypeMetaRuntime(typeMeta); - } - const hash = serializer.getHash(); - if (hash !== typeMeta.getHash()) { - return readContext.genSerializerByTypeMetaRuntime(typeMeta, serializer); - } - return serializer; - } - - switch (typeId) { - case TypeId.COMPATIBLE_STRUCT: - { - const typeMeta = readContext.readTypeMeta(); - serializer = typeResolver.getSerializerById(typeId, typeMeta.getUserTypeId()); - serializer = tryUpdateSerializer(serializer, typeMeta); - } - break; - case TypeId.NAMED_ENUM: - case TypeId.NAMED_UNION: - if (readContext.isCompatible()) { - const typeMeta = readContext.readTypeMeta(); - const ns = typeMeta.getNs(); - const typeName = typeMeta.getTypeName(); - serializer = typeResolver.getSerializerByName(buildNamedTypeKey(ns, typeName)); - } else { - const ns = readContext.readNamespace(); - const typeName = readContext.readTypeName(); - serializer = typeResolver.getSerializerByName(buildNamedTypeKey(ns, typeName)); - } - break; - case TypeId.NAMED_EXT: - if (readContext.isCompatible()) { - const typeMeta = readContext.readTypeMeta(); - const ns = typeMeta.getNs(); - const typeName = typeMeta.getTypeName(); - serializer = typeResolver.getSerializerByName(buildNamedTypeKey(ns, typeName)); - } else { - const ns = readContext.readNamespace(); - const typeName = readContext.readTypeName(); - serializer = typeResolver.getSerializerByName(buildNamedTypeKey(ns, typeName)); - } - break; - case TypeId.NAMED_STRUCT: - case TypeId.NAMED_COMPATIBLE_STRUCT: - if (readContext.isCompatible() || typeId === TypeId.NAMED_COMPATIBLE_STRUCT) { - const typeMeta = readContext.readTypeMeta(); - const ns = typeMeta.getNs(); - const typeName = typeMeta.getTypeName(); - const named = buildNamedTypeKey(ns, typeName); - const namedSerializer = typeResolver.getSerializerByName(named); - serializer = tryUpdateSerializer(namedSerializer, typeMeta); - } else { - const ns = readContext.readNamespace(); - const typeName = readContext.readTypeName(); - serializer = typeResolver.getSerializerByName(buildNamedTypeKey(ns, typeName)); - } - break; - default: - serializer = typeResolver.getSerializerById(typeId, userTypeId); - break; - } - if (!serializer) { - throw new Error(`can't find implements of typeId: ${typeId}`); - } - return serializer; + return readContext.detectAnySerializer(); } static getSerializer(writeContext: WriteContext, v: any) { diff --git a/javascript/packages/core/lib/meta/TypeMeta.ts b/javascript/packages/core/lib/meta/TypeMeta.ts index 90c4696c45..5df0ba7950 100644 --- a/javascript/packages/core/lib/meta/TypeMeta.ts +++ b/javascript/packages/core/lib/meta/TypeMeta.ts @@ -45,6 +45,8 @@ const UINT64_MASK = 0xffffffffffffffffn; const HEADER_HASH_MASK = UINT64_MASK ^ ((1n << HASH_SHIFT_BITS) - 1n); const BIG_NAME_THRESHOLD = 0b111111; const MAX_TYPE_META_NESTING = 128; +const DEFAULT_MAX_TYPE_FIELDS = 512; +const DEFAULT_MAX_TYPE_META_BYTES = 4096; const PRIMITIVE_TYPE_IDS = [ TypeId.BOOL, @@ -454,6 +456,44 @@ export class TypeMeta { reader.readSkip(metaSize); } + static matchesEncodedAfterHeader( + reader: BinaryReader, + headerLow: number, + headerHigh: number, + encoded: Uint8Array, + maxTypeMetaBytes = DEFAULT_MAX_TYPE_META_BYTES, + ): boolean { + const start = reader.readGetCursor(); + const metaSize = TypeMeta.readMetaSizeFromLow(reader, headerLow); + TypeMeta.checkTypeMetaBytes(metaSize, maxTypeMetaBytes); + const bodyStart = reader.readGetCursor(); + const afterHeaderSize = bodyStart - start + metaSize; + reader.checkReadableBytes(metaSize); + if ( + encoded.length !== 8 + afterHeaderSize + || encoded[0] !== (headerLow & 0xff) + || encoded[1] !== ((headerLow >>> 8) & 0xff) + || encoded[2] !== ((headerLow >>> 16) & 0xff) + || encoded[3] !== ((headerLow >>> 24) & 0xff) + || encoded[4] !== (headerHigh & 0xff) + || encoded[5] !== ((headerHigh >>> 8) & 0xff) + || encoded[6] !== ((headerHigh >>> 16) & 0xff) + || encoded[7] !== ((headerHigh >>> 24) & 0xff) + ) { + reader.readSetCursor(start); + return false; + } + const remote = reader.bufferRefAt(start, afterHeaderSize); + for (let i = 0; i < afterHeaderSize; i++) { + if (remote[i] !== encoded[8 + i]) { + reader.readSetCursor(start); + return false; + } + } + reader.readSetCursor(start + afterHeaderSize); + return true; + } + static fromBytes(reader: BinaryReader): TypeMeta { return TypeMeta.fromBytesAfterHeader(reader, TypeMeta.readHeader(reader)); } @@ -462,9 +502,15 @@ export class TypeMeta { * Parse the type meta body after the header has already been consumed * by readHeader(). Used by ReadContext to avoid re-reading the header. */ - static fromBytesAfterHeader(reader: BinaryReader, header: bigint): TypeMeta { + static fromBytesAfterHeader( + reader: BinaryReader, + header: bigint, + maxTypeFields = DEFAULT_MAX_TYPE_FIELDS, + maxTypeMetaBytes = DEFAULT_MAX_TYPE_META_BYTES, + ): TypeMeta { TypeMeta.validateGlobalHeader(header); const metaSize = TypeMeta.readMetaSize(reader, header); + TypeMeta.checkTypeMetaBytes(metaSize, maxTypeMetaBytes); const compressed = false; const headerHash = Number(header >> HASH_SHIFT_BITS); @@ -495,6 +541,7 @@ export class TypeMeta { if (numFields === SMALL_NUM_FIELDS_THRESHOLD) { numFields += reader.readVarUInt32(); } + TypeMeta.checkTypeFields(numFields, maxTypeFields); } else { if ((classHeader & 0b01110000) !== 0) { throw new Error("invalid TypeMeta kind header"); @@ -572,6 +619,27 @@ export class TypeMeta { return metaSize; } + private static checkTypeMetaBytes( + metaSize: number, + maxTypeMetaBytes: number, + ) { + if (metaSize > maxTypeMetaBytes) { + throw new Error( + `Type metadata body size ${metaSize} exceeds maxTypeMetaBytes ${maxTypeMetaBytes}. ` + + "The data may be malicious. If the data is not malicious, please increase maxTypeMetaBytes.", + ); + } + } + + private static checkTypeFields(numFields: number, maxTypeFields: number) { + if (numFields > maxTypeFields) { + throw new Error( + `Type metadata field count ${numFields} exceeds maxTypeFields ${maxTypeFields}. ` + + "The data may be malicious. If the data is not malicious, please increase maxTypeFields.", + ); + } + } + private static validateParsedBodyHash(header: bigint, body: Uint8Array) { const expectedHeaderHash = TypeMeta.headerHashBits( body, diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts index 5999c99b70..5b7342e75e 100644 --- a/javascript/packages/core/lib/type.ts +++ b/javascript/packages/core/lib/type.ts @@ -291,6 +291,10 @@ export interface Config { ref: boolean; useSliceString: boolean; maxDepth?: number; + maxTypeFields: number; + maxTypeMetaBytes: number; + maxSchemaVersionsPerType: number; + maxAverageSchemaVersionsPerType: number; hooks: { afterCodeGenerated?: (code: string) => string; }; diff --git a/javascript/packages/core/lib/typeResolver.ts b/javascript/packages/core/lib/typeResolver.ts index c9936be86d..749c3dcb54 100644 --- a/javascript/packages/core/lib/typeResolver.ts +++ b/javascript/packages/core/lib/typeResolver.ts @@ -87,6 +87,7 @@ export default class TypeResolver { readonly trackingRef: boolean; private internalSerializer: Serializer[] = new Array(300); private customSerializer: Map = new Map(); + private serializerByHash: Map = new Map(); private writeContext!: WriteContext; private readContext!: ReadContext; @@ -248,6 +249,12 @@ export default class TypeResolver { this.initInternalSerializer(); } + private rememberSerializerHash(serializer: Serializer | undefined) { + if (serializer?._initialized) { + this.serializerByHash.set(serializer.getHash(), serializer); + } + } + registerSerializer(typeInfo: TypeInfo, serializer: Serializer = uninitSerialize) { const typeId = this.computeTypeId(typeInfo); if (!TypeId.isNamedType(typeId)) { @@ -258,7 +265,9 @@ export default class TypeResolver { } else { this.customSerializer.set(key, { ...serializer }); } - return this.customSerializer.get(key); + const registered = this.customSerializer.get(key); + this.rememberSerializerHash(registered); + return registered; } if (typeId <= 0xFF) { if (this.internalSerializer[typeId]) { @@ -266,14 +275,18 @@ export default class TypeResolver { } else { this.internalSerializer[typeId] = { ...serializer }; } - return this.internalSerializer[typeId]; + const registered = this.internalSerializer[typeId]; + this.rememberSerializerHash(registered); + return registered; } if (this.customSerializer.has(typeId)) { Object.assign(this.customSerializer.get(typeId)!, serializer); } else { this.customSerializer.set(typeId, { ...serializer }); } - return this.customSerializer.get(typeId); + const registered = this.customSerializer.get(typeId); + this.rememberSerializerHash(registered); + return registered; } const name = typeInfo.named!; @@ -282,7 +295,9 @@ export default class TypeResolver { } else { this.customSerializer.set(name, { ...serializer }); } - return this.customSerializer.get(name); + const registered = this.customSerializer.get(name); + this.rememberSerializerHash(registered); + return registered; } generateReadSerializer(typeInfo: TypeInfo) { @@ -324,6 +339,10 @@ export default class TypeResolver { return this.customSerializer.get(typeIdOrName); } + getSerializerByHash(hash: number) { + return this.serializerByHash.get(hash); + } + getSerializerByData(v: any) { if (v === null || v === undefined) { return null; diff --git a/javascript/packages/core/package.json b/javascript/packages/core/package.json index 71112f067f..49b28b3d62 100644 --- a/javascript/packages/core/package.json +++ b/javascript/packages/core/package.json @@ -6,6 +6,7 @@ "main": "dist/index.js", "scripts": { "build": "tsc", + "test": "npm run build && node --test test/*.test.js", "prepublishOnly": "npm run build" }, "files": [ diff --git a/javascript/packages/core/test/schema-limit.test.js b/javascript/packages/core/test/schema-limit.test.js new file mode 100644 index 0000000000..37174eabbe --- /dev/null +++ b/javascript/packages/core/test/schema-limit.test.js @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +"use strict"; + +const assert = require("node:assert/strict"); +const runTest = + typeof globalThis.expect === "function" && typeof globalThis.test === "function" + ? globalThis.test + : require("node:test").test; +const { ReadContext } = require("../dist/lib/context"); +const { AnyHelper } = require("../dist/lib/gen/any"); +const { FieldInfo, TypeMeta } = require("../dist/lib/meta/TypeMeta"); +const { TypeId } = require("../dist/lib/type"); +const { Type } = require("../dist/lib/typeInfo"); + +function context(typeResolver = {}, config = {}) { + return new ReadContext( + typeResolver, + { + compatible: true, + maxTypeFields: 512, + maxTypeMetaBytes: 4096, + maxAverageSchemaVersionsPerType: 3, + maxSchemaVersionsPerType: 1, + useSliceString: false, + ...config, + }, + ); +} + +function remoteStruct( + name, + fieldName, + fieldType = Type.int32({ encoding: "fixed" }), + typeId = TypeId.NAMED_STRUCT, + userTypeId = -1, +) { + return new TypeMeta([new FieldInfo( + fieldName, + fieldType.typeId, + fieldType.userTypeId, + fieldType.trackingRef === true, + fieldType.nullable === true, + fieldType.options, + )], { + namespace: "example", + typeId, + typeName: name, + userTypeId, + }); +} + +function anyStruct(fieldName, fieldType = Type.int32({ encoding: "fixed" })) { + return remoteStruct("", fieldName, fieldType, TypeId.COMPATIBLE_STRUCT, 901); +} + +function readTypeMeta(readContext, typeMeta) { + const encoded = typeMeta.toBytes(); + const bytes = new Uint8Array(encoded.length + 1); + bytes[0] = 0; + bytes.set(encoded, 1); + readContext.reset(bytes); + return readContext.readTypeMeta(); +} + +function readChangedTypeMeta(readContext, expectedHash, original, typeMeta) { + const encoded = typeMeta.toBytes(); + const bytes = new Uint8Array(encoded.length + 1); + bytes[0] = 0; + bytes.set(encoded, 1); + readContext.reset(bytes); + return readContext.readTypeMetaIfSchemaChanged(expectedHash, original); +} + +function detectAnySerializer(readContext, typeMeta) { + const encoded = typeMeta.toBytes(); + const bytes = new Uint8Array(encoded.length + 2); + bytes[0] = TypeId.COMPATIBLE_STRUCT; + bytes[1] = 0; + bytes.set(encoded, 2); + readContext.reset(bytes); + return AnyHelper.detectSerializer(readContext); +} + +runTest("remote schema limit rejects extra versions", () => { + const readContext = context(); + readTypeMeta(readContext, remoteStruct("Shared", "first")); + assert.throws( + () => readTypeMeta(readContext, remoteStruct("Shared", "second")), + /maxSchemaVersionsPerType/, + ); +}); + +runTest("TypeMeta field limit rejects large struct metadata", () => { + const readContext = context({}, { maxTypeFields: 1 }); + const fieldType = Type.int32({ encoding: "fixed" }); + const typeMeta = new TypeMeta([ + new FieldInfo("first", fieldType.typeId, fieldType.userTypeId, false, false, fieldType.options), + new FieldInfo("second", fieldType.typeId, fieldType.userTypeId, false, false, fieldType.options), + ], { + namespace: "example", + typeId: TypeId.NAMED_STRUCT, + typeName: "TooManyFields", + userTypeId: -1, + }); + + assert.throws(() => readTypeMeta(readContext, typeMeta), /maxTypeFields/); +}); + +runTest("TypeMeta body limit rejects large metadata", () => { + const readContext = context({}, { maxTypeMetaBytes: 1 }); + + assert.throws( + () => readTypeMeta(readContext, remoteStruct("LargeMeta", "value")), + /maxTypeMetaBytes/, + ); +}); + +runTest("failed compatible TypeMeta does not consume schema limit", () => { + const localTypeInfo = Type.struct( + { namespace: "example", typeName: "Shared" }, + { value: Type.int32({ encoding: "fixed" }) }, + ); + const original = { + getTypeInfo() { + return localTypeInfo; + }, + }; + const readContext = context({ + computeTypeId(typeInfo) { + return typeInfo.typeId; + }, + getSerializerById() { + return undefined; + }, + generateReadSerializer(typeInfo) { + return { + getTypeInfo() { + return typeInfo; + }, + }; + }, + }); + const localHash = TypeMeta.fromTypeInfo(localTypeInfo).getHash(); + + assert.throws( + () => readChangedTypeMeta( + readContext, + localHash, + original, + remoteStruct("Shared", "value", Type.map(Type.string(), Type.int32({ encoding: "fixed" }))), + ), + /field schema mismatch/, + ); + assert.doesNotThrow(() => readChangedTypeMeta( + readContext, + localHash, + original, + remoteStruct("Shared", "extra"), + )); +}); + +runTest("exact local TypeMeta bypasses schema limit", () => { + const localTypeInfo = Type.struct( + { namespace: "example", typeName: "Shared" }, + { value: Type.int32({ encoding: "fixed" }) }, + ); + const original = { + getTypeInfo() { + return localTypeInfo; + }, + }; + const localHash = TypeMeta.fromTypeInfo(localTypeInfo).getHash(); + const readContext = context({ + computeTypeId(typeInfo) { + return typeInfo.typeId; + }, + getSerializerById() { + return undefined; + }, + getSerializerByHash(hash) { + return hash === localHash ? original : undefined; + }, + generateReadSerializer(typeInfo) { + return { + getTypeInfo() { + return typeInfo; + }, + }; + }, + }); + + readChangedTypeMeta(readContext, localHash, original, remoteStruct("Shared", "extra")); + assert.doesNotThrow(() => readChangedTypeMeta( + readContext, + localHash, + undefined, + TypeMeta.fromTypeInfo(localTypeInfo), + )); + assert.doesNotThrow(() => readTypeMeta( + readContext, + TypeMeta.fromTypeInfo(localTypeInfo), + )); +}); + +runTest("failed Any TypeMeta does not consume schema limit", () => { + const localTypeInfo = Type.struct( + 901, + { value: Type.int32({ encoding: "fixed" }) }, + ); + const localHash = TypeMeta.fromTypeInfo(localTypeInfo).getHash(); + const original = { + getHash() { + return localHash; + }, + getTypeInfo() { + return localTypeInfo; + }, + }; + const readContext = context({ + computeTypeId(typeInfo) { + return typeInfo.typeId; + }, + getSerializerById(typeId, userTypeId) { + return userTypeId === 901 ? original : undefined; + }, + getSerializerByName() { + return undefined; + }, + generateReadSerializer(typeInfo) { + return { + getHash() { + return TypeMeta.fromTypeInfo(typeInfo).getHash(); + }, + getTypeInfo() { + return typeInfo; + }, + }; + }, + }); + + assert.throws( + () => detectAnySerializer( + readContext, + anyStruct("value", Type.map(Type.string(), Type.int32({ encoding: "fixed" }))), + ), + /field schema mismatch/, + ); + assert.doesNotThrow(() => detectAnySerializer(readContext, anyStruct("extra"))); +}); + +runTest("exact Any TypeMeta bypasses schema limit", () => { + const localTypeInfo = Type.struct( + 901, + { value: Type.int32({ encoding: "fixed" }) }, + ); + const localHash = TypeMeta.fromTypeInfo(localTypeInfo).getHash(); + const original = { + getHash() { + return localHash; + }, + getTypeInfo() { + return localTypeInfo; + }, + }; + const readContext = context({ + computeTypeId(typeInfo) { + return typeInfo.typeId; + }, + getSerializerById(typeId, userTypeId) { + return userTypeId === 901 ? original : undefined; + }, + getSerializerByName() { + return undefined; + }, + generateReadSerializer(typeInfo) { + return { + getHash() { + return TypeMeta.fromTypeInfo(typeInfo).getHash(); + }, + getTypeInfo() { + return typeInfo; + }, + }; + }, + }); + + detectAnySerializer(readContext, anyStruct("extra")); + assert.doesNotThrow(() => detectAnySerializer( + readContext, + TypeMeta.fromTypeInfo(localTypeInfo), + )); + assert.doesNotThrow(() => readTypeMeta( + readContext, + TypeMeta.fromTypeInfo(localTypeInfo), + )); +}); + +runTest("remote schema limit keeps unknown structs separate", () => { + const readContext = context(); + assert.equal( + readTypeMeta(readContext, remoteStruct("UnknownA", "value")).getTypeName(), + "UnknownA", + ); + assert.equal( + readTypeMeta(readContext, remoteStruct("UnknownB", "value")).getTypeName(), + "UnknownB", + ); +}); diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 9414c3c7ee..f74e2ab3c0 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -135,6 +135,10 @@ def __init__( strict: bool = True, compatible: Optional[bool] = None, max_depth: int = 50, + max_type_fields: int = 512, + max_type_meta_bytes: int = 4096, + max_schema_versions_per_type: int = 10, + max_average_schema_versions_per_type: int = 3, policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, @@ -169,6 +173,16 @@ def __init__( max_depth: Maximum nesting depth for deserialization (default: 50). Raises an exception if exceeded to prevent malicious deeply-nested data attacks. + max_type_fields: Maximum accepted field count in one received struct TypeDef. + + max_type_meta_bytes: Maximum accepted body size in one received TypeDef. + + max_schema_versions_per_type: Maximum accepted remote schema versions for one + struct type. + + max_average_schema_versions_per_type: Average remote schema versions allowed + across accepted struct types. + policy: Custom deserialization policy for security checks. When provided, it controls which types can be deserialized, overriding the default policy. **Strongly recommended** when strict=False to maintain security controls. @@ -191,6 +205,14 @@ def __init__( self.compatible = compatible self.field_nullable = field_nullable self.max_depth = max_depth + if not isinstance(max_type_fields, int) or max_type_fields <= 0: + raise ValueError("max_type_fields must be a positive integer") + if not isinstance(max_type_meta_bytes, int) or max_type_meta_bytes <= 0: + raise ValueError("max_type_meta_bytes must be a positive integer") + if not isinstance(max_schema_versions_per_type, int) or max_schema_versions_per_type <= 0: + raise ValueError("max_schema_versions_per_type must be a positive integer") + if not isinstance(max_average_schema_versions_per_type, int) or max_average_schema_versions_per_type <= 0: + raise ValueError("max_average_schema_versions_per_type must be a positive integer") self.config = Config( xlang=xlang, track_ref=ref, @@ -199,6 +221,10 @@ def __init__( meta_share=compatible, scoped_meta_share_enabled=compatible, max_depth=max_depth, + max_type_fields=max_type_fields, + max_type_meta_bytes=max_type_meta_bytes, + max_schema_versions_per_type=max_schema_versions_per_type, + max_average_schema_versions_per_type=max_average_schema_versions_per_type, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py index 2c6a89d607..008211e987 100644 --- a/python/pyfory/meta/typedef_decoder.py +++ b/python/pyfory/meta/typedef_decoder.py @@ -53,7 +53,6 @@ MAX_GENERATED_CLASSES = 1000 -MAX_FIELDS_PER_CLASS = 256 _generated_class_count = 0 @@ -102,11 +101,22 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: raise ValueError("Compressed xlang TypeDef is not supported") # If meta size is at maximum, read additional size + encoded = Buffer.allocate(meta_size + 16) + encoded.write_int64(header) if meta_size == META_SIZE_MASKS: - meta_size += buffer.read_var_uint32() + extended_size = buffer.read_var_uint32() + encoded.write_var_uint32(extended_size) + meta_size += extended_size + max_type_meta_bytes = resolver.config.max_type_meta_bytes + if meta_size > max_type_meta_bytes: + raise ValueError( + f"Type metadata body size {meta_size} exceeds max_type_meta_bytes {max_type_meta_bytes}. " + "The data may be malicious. If the data is not malicious, please increase max_type_meta_bytes." + ) # Read meta data encoded_meta_data = buffer.read_bytes(meta_size) + encoded.write_bytes(encoded_meta_data) meta_data = encoded_meta_data # Create a new buffer for meta data @@ -131,8 +141,12 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: num_fields = meta_header & SMALL_NUM_FIELDS_THRESHOLD if num_fields == SMALL_NUM_FIELDS_THRESHOLD: num_fields += meta_buffer.read_var_uint32() - if num_fields > MAX_FIELDS_PER_CLASS: - raise ValueError(f"Class has {num_fields} fields, exceeding the maximum allowed {MAX_FIELDS_PER_CLASS} fields.") + max_type_fields = resolver.config.max_type_fields + if num_fields > max_type_fields: + raise ValueError( + f"Type metadata field count {num_fields} exceeds max_type_fields {max_type_fields}. " + "The data may be malicious. If the data is not malicious, please increase max_type_fields." + ) else: if meta_header & 0b01110000: raise ValueError("Invalid TypeDef kind header") @@ -176,6 +190,7 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: f"Exceeded maximum number of dynamically generated classes ({MAX_GENERATED_CLASSES}). " "This may indicate malicious data causing memory issues." ) + resolver._check_remote_struct_schema_meta(type_id, namespace, typename, user_type_id) _generated_class_count += 1 # Generate dynamic dataclass from field definitions field_definitions = [(field_info.name, Any) for field_info in field_infos] @@ -195,7 +210,7 @@ def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: type_cls, type_id, field_infos, - meta_data, + encoded.to_bytes(0, encoded.get_writer_index()), is_compressed, user_type_id=user_type_id, ) diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index a467002411..ddeec4fc86 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -144,7 +144,7 @@ NO_TYPE_ID, NO_USER_TYPE_ID, ) -from pyfory.meta.typedef import TypeDef +from pyfory.meta.typedef import TypeDef, is_named_typedef_kind, is_struct_typedef_kind from pyfory.meta.typedef_decoder import decode_typedef, skip_typedef from pyfory.meta.typedef_encoder import encode_typedef @@ -156,7 +156,7 @@ logger = logging.getLogger(__name__) namespace_decoder = MetaStringDecoder(".", "_") typename_decoder = MetaStringDecoder("$", "_") -MAX_CACHED_TYPE_DEFS = 8192 +MIN_REMOTE_STRUCT_SCHEMA_LIMIT = 8192 MAX_CACHED_ENCODED_META_STRINGS = 8192 _NO_REF_NUMERIC_TYPE_IDS = frozenset( @@ -342,6 +342,7 @@ class TypeResolver: "compatible", "field_nullable", "policy", + "config", "shared_registry", "_type_id_counter", "_types_info", @@ -359,6 +360,8 @@ class TypeResolver: "_user_type_id_to_type_info", "_used_user_type_ids", "_meta_shared_type_info", + "_remote_schema_versions_by_type", + "_total_accepted_schema_versions", "meta_share", "_internal_py_serializer_map", "_actual_type_resolver", @@ -366,6 +369,7 @@ class TypeResolver: ) def __init__(self, config, *, shared_registry): + self.config = config self.xlang = config.xlang self.track_ref = config.track_ref self.strict = config.strict @@ -388,6 +392,8 @@ def __init__(self, config, *, shared_registry): self.namespace_encoder = MetaStringEncoder(".", "_") self.namespace_decoder = MetaStringDecoder(".", "_") self._meta_shared_type_info = {} + self._remote_schema_versions_by_type = {} + self._total_accepted_schema_versions = 0 self.typename_encoder = MetaStringEncoder("$", "_") self.typename_decoder = MetaStringDecoder("$", "_") self.meta_compressor = config.meta_compressor if config.meta_compressor is not None else DeflaterMetaCompressor() @@ -1173,6 +1179,71 @@ def _build_type_info_from_typedef(self, type_def): ) return typeinfo + def _local_type_info_for_typedef(self, type_def): + if is_named_typedef_kind(type_def.type_id): + return self.get_type_info_by_name(type_def.namespace, type_def.typename) + if self.is_registered_by_id(type_id=type_def.type_id, user_type_id=type_def.user_type_id): + return self.get_type_info_by_id(type_def.type_id, user_type_id=type_def.user_type_id) + return None + + def _remote_struct_schema_key(self, type_id, namespace, typename, user_type_id): + if not is_struct_typedef_kind(type_id): + return None + if is_named_typedef_kind(type_id): + return (namespace or "", typename) + return user_type_id + + def _check_remote_struct_schema_key(self, type_key): + if type_key is None: + return None + versions_for_type = self._remote_schema_versions_by_type.get(type_key, 0) + max_schema_versions_per_type = self.config.max_schema_versions_per_type + if versions_for_type >= max_schema_versions_per_type: + raise ValueError( + f"Remote schema version limit exceeded for type {type_key}: " + f"{versions_for_type} >= {max_schema_versions_per_type}. Increase " + "max_schema_versions_per_type if this peer legitimately sends many " + "schema versions for one type." + ) + accepted_struct_type_count = ( + len(self._remote_schema_versions_by_type) + 1 if versions_for_type == 0 else len(self._remote_schema_versions_by_type) + ) + max_average_schema_versions_per_type = self.config.max_average_schema_versions_per_type + global_limit = max( + MIN_REMOTE_STRUCT_SCHEMA_LIMIT, + accepted_struct_type_count * max_average_schema_versions_per_type, + ) + if self._total_accepted_schema_versions >= global_limit: + raise ValueError( + "Remote schema version limit exceeded: " + f"{self._total_accepted_schema_versions} schemas for " + f"{accepted_struct_type_count} accepted struct types exceeds the average " + f"limit {max_average_schema_versions_per_type}. Increase " + "max_average_schema_versions_per_type if this peer legitimately sends many " + "schema versions across many types." + ) + return type_key + + def _check_remote_struct_schema_limit(self, type_def): + return self._check_remote_struct_schema_key( + self._remote_struct_schema_key( + type_def.type_id, + type_def.namespace, + type_def.typename, + type_def.user_type_id, + ) + ) + + def _check_remote_struct_schema_meta(self, type_id, namespace, typename, user_type_id): + return self._check_remote_struct_schema_key(self._remote_struct_schema_key(type_id, namespace, typename, user_type_id)) + + def _record_remote_struct_schema(self, type_key): + if type_key is None: + return + versions_for_type = self._remote_schema_versions_by_type.get(type_key, 0) + self._remote_schema_versions_by_type[type_key] = versions_for_type + 1 + self._total_accepted_schema_versions += 1 + def _read_and_build_type_info(self, buffer): """Read TypeDef inline from buffer and build TypeInfo. @@ -1187,7 +1258,14 @@ def _read_and_build_type_info(self, buffer): skip_typedef(buffer, header) return type_info type_def = decode_typedef(buffer, self, header=header) + local_type_info = self._local_type_info_for_typedef(type_def) + if local_type_info is not None: + if local_type_info.type_def is None: + self._set_type_info(local_type_info) + if local_type_info.type_def is not None and local_type_info.type_def.encoded == type_def.encoded: + return local_type_info + type_key = self._check_remote_struct_schema_limit(type_def) type_info = self._build_type_info_from_typedef(type_def) - if len(self._meta_shared_type_info) < MAX_CACHED_TYPE_DEFS: - self._meta_shared_type_info[header] = type_info + self._meta_shared_type_info[header] = type_info + self._record_remote_struct_schema(type_key) return type_info diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 715218a301..4778f13587 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -81,8 +81,6 @@ ENABLE_FORY_CYTHON_SERIALIZATION = os.environ.get( cdef int32_t NOT_NULL_BOOL_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (TypeId.BOOL << 8) cdef int32_t NOT_NULL_STRING_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (TypeId.STRING << 8) cdef int32_t NOT_NULL_FLOAT64_FLAG = (NOT_NULL_VALUE_FLAG & 0xFF) | (TypeId.FLOAT64 << 8) -cdef int32_t MAX_CACHED_TYPE_DEFS = 8192 - _PRIMITIVE_TYPE_IDS = frozenset(range(1, 21)) - {16} @@ -111,6 +109,10 @@ cdef class Config: meta_share: Enables shared type metadata on the resolver/type-info path. scoped_meta_share_enabled: Enables per-operation meta-share state. max_depth: Maximum allowed nesting depth during deserialization. + max_type_fields: Maximum accepted field count in one received struct TypeDef. + max_type_meta_bytes: Maximum accepted body size in one received TypeDef. + max_schema_versions_per_type: Maximum accepted remote schema versions for one struct type. + max_average_schema_versions_per_type: Average remote schema versions allowed across accepted struct types. field_nullable: Treats struct/dataclass fields as nullable by default. policy: Deserialization policy used for security-sensitive checks. meta_compressor: Optional typedef/meta compressor implementation. @@ -123,6 +125,10 @@ cdef class Config: cdef public bint meta_share cdef public bint scoped_meta_share_enabled cdef public int32_t max_depth + cdef public int32_t max_type_fields + cdef public int32_t max_type_meta_bytes + cdef public int32_t max_schema_versions_per_type + cdef public int32_t max_average_schema_versions_per_type cdef public bint field_nullable cdef public object policy cdef public object meta_compressor @@ -137,6 +143,10 @@ cdef class Config: meta_share, scoped_meta_share_enabled, max_depth, + max_type_fields, + max_type_meta_bytes, + max_schema_versions_per_type, + max_average_schema_versions_per_type, field_nullable, policy, meta_compressor, @@ -152,6 +162,10 @@ cdef class Config: meta_share: Enable shared type metadata on resolver/type-info paths. scoped_meta_share_enabled: Enable per-operation meta-share state. max_depth: Maximum allowed read depth before failing deserialization. + max_type_fields: Maximum accepted field count in one received struct TypeDef. + max_type_meta_bytes: Maximum accepted body size in one received TypeDef. + max_schema_versions_per_type: Maximum accepted remote schema versions for one struct type. + max_average_schema_versions_per_type: Average remote schema versions allowed across accepted struct types. field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. meta_compressor: Optional typedef/meta compressor. @@ -163,6 +177,18 @@ cdef class Config: self.meta_share = meta_share self.scoped_meta_share_enabled = scoped_meta_share_enabled self.max_depth = max_depth + if max_type_fields <= 0: + raise ValueError("max_type_fields must be a positive integer") + if max_type_meta_bytes <= 0: + raise ValueError("max_type_meta_bytes must be a positive integer") + if max_schema_versions_per_type <= 0: + raise ValueError("max_schema_versions_per_type must be a positive integer") + if max_average_schema_versions_per_type <= 0: + raise ValueError("max_average_schema_versions_per_type must be a positive integer") + self.max_type_fields = max_type_fields + self.max_type_meta_bytes = max_type_meta_bytes + self.max_schema_versions_per_type = max_schema_versions_per_type + self.max_average_schema_versions_per_type = max_average_schema_versions_per_type self.field_nullable = field_nullable self.policy = policy self.meta_compressor = meta_compressor @@ -193,6 +219,7 @@ cdef class TypeResolver: """ cdef object resolver + cdef readonly Config config cdef readonly object shared_registry cdef readonly bint xlang cdef readonly bint track_ref @@ -226,6 +253,7 @@ cdef class TypeResolver: shared_registry=shared_registry, ) self.resolver = resolver + self.config = config self.shared_registry = resolver.shared_registry self.xlang = resolver.xlang self.track_ref = resolver.track_ref @@ -515,19 +543,27 @@ cdef class TypeResolver: meta_context.read_type_infos.append(typeinfo) return typeinfo - cdef inline TypeInfo _read_and_build_type_info(self, Buffer buffer): + cdef TypeInfo _read_and_build_type_info(self, Buffer buffer): cdef int64_t header = buffer.read_int64() cdef TypeInfo typeinfo = self._meta_shared_type_info.get(header) cdef object type_def + cdef object type_key if typeinfo is not None: # Header-cache hits intentionally skip without rehashing. Entries reach this cache only # after a successful TypeDef parse and 52-bit metadata-hash validation. _skip_typedef_fast(buffer, header) return typeinfo type_def = decode_typedef(buffer, self.resolver, header=header) + typeinfo = self.resolver._local_type_info_for_typedef(type_def) + if typeinfo is not None: + if typeinfo.type_def is None: + self.resolver._set_type_info(typeinfo) + if typeinfo.type_def is not None and typeinfo.type_def.encoded == type_def.encoded: + return typeinfo + type_key = self.resolver._check_remote_struct_schema_limit(type_def) typeinfo = self.resolver._build_type_info_from_typedef(type_def) - if len(self._meta_shared_type_info) < MAX_CACHED_TYPE_DEFS: - self._meta_shared_type_info[header] = typeinfo + self._meta_shared_type_info[header] = typeinfo + self.resolver._record_remote_struct_schema(type_key) return typeinfo cdef inline TypeInfo _load_bytes_to_type_info( @@ -544,8 +580,7 @@ cdef class TypeResolver: if entry != NULL and deref(entry).second != NULL: return deref(entry).second typeinfo = self.resolver._load_metabytes_to_type_info(ns_metabytes, type_metabytes) - if self._c_meta_hash_to_type_info.size() < MAX_CACHED_TYPE_DEFS: - self._c_meta_hash_to_type_info[hash_key] = typeinfo + self._c_meta_hash_to_type_info[hash_key] = typeinfo return typeinfo @@ -798,6 +833,10 @@ cdef class Fory: strict=True, compatible=None, max_depth=50, + max_type_fields=512, + max_type_meta_bytes=4096, + max_schema_versions_per_type=10, + max_average_schema_versions_per_type=3, policy=None, field_nullable=False, meta_compressor=None, @@ -810,8 +849,12 @@ cdef class Fory: ref: Enable reference tracking for shared and circular references. strict: Require registered types on dynamic resolution paths. compatible: Enable compatible mode and meta-share type exchange. Defaults to - compatible mode. + compatible mode. max_depth: Maximum allowed read depth before rejecting payloads. + max_type_fields: Maximum accepted field count in one received struct TypeDef. + max_type_meta_bytes: Maximum accepted body size in one received TypeDef. + max_schema_versions_per_type: Maximum accepted remote schema versions for one struct type. + max_average_schema_versions_per_type: Average remote schema versions allowed across accepted struct types. policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. meta_compressor: Optional typedef/meta compressor implementation. @@ -837,6 +880,10 @@ cdef class Fory: meta_share=compatible, scoped_meta_share_enabled=compatible, max_depth=max_depth, + max_type_fields=max_type_fields, + max_type_meta_bytes=max_type_meta_bytes, + max_schema_versions_per_type=max_schema_versions_per_type, + max_average_schema_versions_per_type=max_average_schema_versions_per_type, field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, diff --git a/python/pyfory/tests/test_typedef_encoding.py b/python/pyfory/tests/test_typedef_encoding.py index 1890e06ad6..94649214c9 100644 --- a/python/pyfory/tests/test_typedef_encoding.py +++ b/python/pyfory/tests/test_typedef_encoding.py @@ -26,6 +26,7 @@ import pytest import pyfory +import pyfory.meta.typedef_decoder as typedef_decoder from pyfory.serialization import Buffer from pyfory.meta.typedef import ( TypeDef, @@ -324,24 +325,79 @@ def test_id_registered_typedef_extended_field_count_header(): assert len(decoded_typedef.fields) == 32 -def test_meta_shared_typedef_cache_is_bounded(): - fory = Fory(xlang=True, compatible=True) - fory.register(SimpleTypeDef, name="example.SimpleTypeDef") - resolver = fory.type_resolver - read_and_build = getattr(resolver, "_read_and_build_type_info", None) - if read_and_build is None: - pytest.skip("pure-Python resolver internals are not exposed by this runtime") - typedef = encode_typedef(resolver, SimpleTypeDef) - header_buffer = Buffer(typedef.encoded) - header = header_buffer.read_int64() - for i in range(8192): - resolver._meta_shared_type_info[i] = object() - - typeinfo = read_and_build(Buffer(typedef.encoded)) - - assert typeinfo.type_def.type_id == typedef.type_id - assert header not in resolver._meta_shared_type_info - assert len(resolver._meta_shared_type_info) == 8192 +@pytest.mark.parametrize("xlang", [False, True]) +def test_type_meta_field_limit_rejects_large_struct(xlang): + reader = Fory(xlang=xlang, strict=False, compatible=True, max_type_fields=1) + remote = make_dataclass("RemoteTooManyFields", [("value", int), ("extra", int)]) + type_id, typedef = _remote_typedef(xlang, "example.TooManyFields", remote) + + with pytest.raises(ValueError, match="max_type_fields"): + _read_remote_typedef(reader, type_id, typedef) + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_type_meta_body_limit_rejects_large_metadata(xlang): + reader = Fory(xlang=xlang, strict=False, compatible=True, max_type_meta_bytes=1) + remote = make_dataclass("RemoteLargeTypeMeta", [("value", int)]) + type_id, typedef = _remote_typedef(xlang, "example.LargeTypeMeta", remote) + + with pytest.raises(ValueError, match="max_type_meta_bytes"): + _read_remote_typedef(reader, type_id, typedef) + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_remote_schema_limit_rejects_extra_versions(xlang): + reader = Fory( + xlang=xlang, + strict=False, + compatible=True, + max_schema_versions_per_type=1, + ) + first = make_dataclass("RemoteLimitV1", [("value", int)]) + second = make_dataclass("RemoteLimitV2", [("value", int), ("extra", int)]) + first_type_id, first_typedef = _remote_typedef(xlang, "example.Unknown", first) + second_type_id, second_typedef = _remote_typedef(xlang, "example.Unknown", second) + + _read_remote_typedef(reader, first_type_id, first_typedef) + + generated_class_count = typedef_decoder._generated_class_count + with pytest.raises(ValueError, match="max_schema_versions_per_type"): + _read_remote_typedef(reader, second_type_id, second_typedef) + assert typedef_decoder._generated_class_count == generated_class_count + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_remote_schema_limit_keeps_unknown_types_separate(xlang): + reader = Fory( + xlang=xlang, + strict=False, + compatible=True, + max_schema_versions_per_type=1, + ) + first = make_dataclass("RemoteUnknownA", [("value", int)]) + second = make_dataclass("RemoteUnknownB", [("value", int)]) + first_type_id, first_typedef = _remote_typedef(xlang, "example.UnknownA", first) + second_type_id, second_typedef = _remote_typedef(xlang, "example.UnknownB", second) + + _read_remote_typedef(reader, first_type_id, first_typedef) + _read_remote_typedef(reader, second_type_id, second_typedef) + + +def _remote_typedef(xlang, remote_name, cls): + writer = Fory(xlang=xlang, strict=False, compatible=True) + writer.register(cls, name=remote_name) + type_id, _ = writer.type_resolver.get_registered_type_ids(cls) + return type_id, encode_typedef(writer.type_resolver, cls).encoded + + +def _read_remote_typedef(fory, type_id, encoded): + buffer = Buffer.allocate(len(encoded) + 8) + buffer.write_uint8(type_id) + buffer.write_var_uint32(0) + buffer.write_bytes(encoded) + fory.read_context.reset() + fory.read_context.prepare(Buffer(buffer.to_bytes())) + return fory.type_resolver.read_type_info(fory.read_context) def _corrupt_encoded_field_name(typedef, field_name): diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index f7152d7292..b5300d70c8 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -40,6 +40,14 @@ pub struct Config { /// When enabled, shared references and circular references are tracked /// and preserved during serialization/deserialization. pub track_ref: bool, + /// Maximum accepted field count in one received struct TypeMeta. + pub max_type_fields: usize, + /// Maximum accepted body size in one received TypeMeta. + pub max_type_meta_bytes: usize, + /// Maximum accepted remote struct schema versions for one logical type. + pub max_schema_versions_per_type: usize, + /// Maximum accepted average remote struct schema versions across logical types. + pub max_average_schema_versions_per_type: usize, } impl Default for Config { @@ -53,6 +61,10 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, + max_type_fields: 512, + max_type_meta_bytes: 4096, + max_schema_versions_per_type: 10, + max_average_schema_versions_per_type: 3, } } } @@ -110,4 +122,28 @@ impl Config { pub fn is_track_ref(&self) -> bool { self.track_ref } + + /// Get maximum accepted field count in one received struct TypeMeta. + #[inline(always)] + pub fn max_type_fields(&self) -> usize { + self.max_type_fields + } + + /// Get maximum accepted body size in one received TypeMeta. + #[inline(always)] + pub fn max_type_meta_bytes(&self) -> usize { + self.max_type_meta_bytes + } + + /// Get maximum accepted remote struct schema versions for one logical type. + #[inline(always)] + pub fn max_schema_versions_per_type(&self) -> usize { + self.max_schema_versions_per_type + } + + /// Get maximum accepted average remote struct schema versions across logical types. + #[inline(always)] + pub fn max_average_schema_versions_per_type(&self) -> usize { + self.max_average_schema_versions_per_type + } } diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index 933c80ac38..260f94ea4c 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -352,6 +352,7 @@ unsafe impl<'a> Sync for WriteContext<'a> {} pub struct ReadContext<'a> { // Replicated environment fields (direct access, no Arc indirection for flags) type_resolver: TypeResolver, + config: Config, compatible: bool, share_meta: bool, xlang: bool, @@ -380,6 +381,7 @@ impl<'a> ReadContext<'a> { pub fn new(type_resolver: TypeResolver, config: Config) -> ReadContext<'a> { ReadContext { type_resolver, + config: config.clone(), compatible: config.compatible, share_meta: config.share_meta, xlang: config.xlang, @@ -463,7 +465,7 @@ impl<'a> ReadContext<'a> { #[inline(always)] pub fn read_type_meta(&mut self) -> Result, Error> { self.meta_resolver - .read_type_meta(&mut self.reader, &self.type_resolver) + .read_type_meta(&mut self.reader, &self.type_resolver, &self.config) } pub fn read_any_type_info(&mut self) -> Result, Error> { diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index f9885fd3d6..295e6d5a96 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -298,6 +298,40 @@ impl ForyBuilder { self } + /// Sets the maximum field count accepted in one received struct TypeMeta. + pub fn max_type_fields(mut self, max_fields: usize) -> Self { + assert!(max_fields > 0, "max_type_fields must be positive"); + self.config.max_type_fields = max_fields; + self + } + + /// Sets the maximum body size accepted for one received TypeMeta. + pub fn max_type_meta_bytes(mut self, max_bytes: usize) -> Self { + assert!(max_bytes > 0, "max_type_meta_bytes must be positive"); + self.config.max_type_meta_bytes = max_bytes; + self + } + + /// Sets the maximum accepted remote struct schema versions for one logical type. + pub fn max_schema_versions_per_type(mut self, max_versions: usize) -> Self { + assert!( + max_versions > 0, + "max_schema_versions_per_type must be positive" + ); + self.config.max_schema_versions_per_type = max_versions; + self + } + + /// Sets the maximum accepted average remote struct schema versions across logical types. + pub fn max_average_schema_versions_per_type(mut self, max_versions: usize) -> Self { + assert!( + max_versions > 0, + "max_average_schema_versions_per_type must be positive" + ); + self.config.max_average_schema_versions_per_type = max_versions; + self + } + fn finish_config(self) -> Config { let mut config = self.config; if !self.compatible_set { diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index 5913bb2796..b60d5502ee 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -89,7 +89,8 @@ use std::collections::HashMap; use std::rc::Rc; const SMALL_NUM_FIELDS_THRESHOLD: usize = 0b11111; -const MAX_TYPE_META_FIELDS: usize = i16::MAX as usize; +const DEFAULT_MAX_TYPE_FIELDS: usize = 512; +const DEFAULT_MAX_TYPE_META_BYTES: usize = 4096; const MAX_COMPATIBLE_MATCHED_FIELD_INDEX: usize = (i16::MAX as usize - 1) / 2; const REGISTER_BY_NAME_FLAG: u8 = 0b0010_0000; const COMPATIBLE_TYPEDEF_FLAG: u8 = 0b0100_0000; @@ -196,6 +197,24 @@ fn read_type_meta_body_size(reader: &mut Reader, header: i64) -> Result Result<(), Error> { + if meta_size > max_type_meta_bytes { + return Err(Error::invalid_data(format!( + "Type metadata body size {meta_size} exceeds max_type_meta_bytes {max_type_meta_bytes}. The data may be malicious. If the data is not malicious, please increase max_type_meta_bytes." + ))); + } + Ok(()) +} + +fn check_type_meta_fields(num_fields: usize, max_type_fields: usize) -> Result<(), Error> { + if num_fields > max_type_fields { + return Err(Error::invalid_data(format!( + "Type metadata field count {num_fields} exceeds max_type_fields {max_type_fields}. The data may be malicious. If the data is not malicious, please increase max_type_fields." + ))); + } + Ok(()) +} + #[inline(always)] fn type_meta_hash_bits(body: &[u8], header_low_bits: u64) -> u64 { let mut hash_input = Vec::with_capacity(body.len() + 2); @@ -1130,6 +1149,7 @@ impl TypeMeta { fn from_meta_bytes( reader: &mut Reader, type_resolver: &TypeResolver, + max_type_fields: usize, ) -> Result { let meta_header = reader.read_u8()?; let is_struct = (meta_header & STRUCT_TYPEDEF_FLAG) != 0; @@ -1157,12 +1177,7 @@ impl TypeMeta { if num_fields == SMALL_NUM_FIELDS_THRESHOLD { num_fields += reader.read_var_u32()? as usize; } - if num_fields > MAX_TYPE_META_FIELDS { - return Err(Error::invalid_data(format!( - "too many fields in type meta: {}, max: {}", - num_fields, MAX_TYPE_META_FIELDS - ))); - } + check_type_meta_fields(num_fields, max_type_fields)?; } else { if (meta_header & 0b0111_0000) != 0 { return Err(Error::invalid_data("invalid TypeMeta kind header")); @@ -1264,19 +1279,28 @@ impl TypeMeta { type_resolver: &TypeResolver, ) -> Result { let header = reader.read_i64()?; - Self::from_bytes_with_header(reader, type_resolver, header) + Self::from_bytes_with_header( + reader, + type_resolver, + header, + DEFAULT_MAX_TYPE_FIELDS, + DEFAULT_MAX_TYPE_META_BYTES, + ) } pub(crate) fn from_bytes_with_header( reader: &mut Reader, type_resolver: &TypeResolver, header: i64, + max_type_fields: usize, + max_type_meta_bytes: usize, ) -> Result { validate_type_meta_header(header)?; let meta_size = read_type_meta_body_size(reader, header)?; + check_type_meta_body_size(meta_size, max_type_meta_bytes)?; let body = reader.read_bytes(meta_size)?; let mut body_reader = Reader::new(body); - let mut meta = Self::from_meta_bytes(&mut body_reader, type_resolver)?; + let mut meta = Self::from_meta_bytes(&mut body_reader, type_resolver, max_type_fields)?; if !body_reader.slice_after_cursor().is_empty() { return Err(Error::invalid_data("invalid TypeMeta metadata size")); } diff --git a/rust/fory-core/src/resolver/meta_resolver.rs b/rust/fory-core/src/resolver/meta_resolver.rs index 48b7770042..fb4a26b765 100644 --- a/rust/fory-core/src/resolver/meta_resolver.rs +++ b/rust/fory-core/src/resolver/meta_resolver.rs @@ -16,10 +16,12 @@ // under the License. use crate::buffer::{Reader, Writer}; +use crate::config::Config; use crate::error::Error; use crate::meta::TypeMeta; use crate::resolver::type_resolver::NO_USER_TYPE_ID; use crate::resolver::{TypeInfo, TypeResolver}; +use crate::type_id::{COMPATIBLE_STRUCT, NAMED_COMPATIBLE_STRUCT, NAMED_STRUCT, STRUCT}; use std::collections::HashMap; use std::rc::Rc; @@ -34,7 +36,7 @@ pub struct MetaWriterResolver { next_index: usize, } -const MAX_PARSED_NUM_TYPE_DEFS: usize = 8192; +const MIN_REMOTE_STRUCT_SCHEMA_LIMIT: usize = 8192; const NO_WRITTEN_TYPE_INDEX: usize = usize::MAX; #[allow(dead_code)] @@ -113,6 +115,8 @@ impl MetaWriterResolver { pub struct MetaReaderResolver { pub reading_type_infos: Vec>, parsed_type_infos: HashMap>, + remote_schema_versions_by_type: HashMap, + total_accepted_schema_versions: usize, last_meta_header: i64, last_type_info: Option>, } @@ -130,6 +134,7 @@ impl MetaReaderResolver { &mut self, reader: &mut Reader, type_resolver: &TypeResolver, + config: &Config, ) -> Result, Error> { let index_marker = reader.read_var_u32()?; let is_ref = (index_marker & 1) == 1; @@ -163,109 +168,163 @@ impl MetaReaderResolver { TypeMeta::skip_bytes_for_validated_header(reader, meta_header)?; Ok(type_info.clone()) } else { - let type_meta = Rc::new(TypeMeta::from_bytes_with_header( - reader, - type_resolver, - meta_header, - )?); - - // Try to find local type info - let namespace = &type_meta.get_namespace().original; - let type_name = &type_meta.get_type_name().original; - let register_by_name = !namespace.is_empty() || !type_name.is_empty(); - let type_info = if register_by_name { - // Registered by name (namespace can be empty) - if let Some(local_type_info) = - type_resolver.get_type_info_by_name(namespace, type_name) - { - // Exact schemas can reuse the local TypeInfo; changed - // schemas keep the remote metadata with the local harness. - if type_meta.get_hash() == local_type_info.get_type_meta_ref().get_hash() { - local_type_info - } else { - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - Some(local_type_info.get_harness()), - Some(local_type_info.get_type_id() as u32), - Some(local_type_info.get_user_type_id()), - )) - } - } else { - // No local type found, use stub harness - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - None, - None, - None, - )) - } + let type_def_start = reader.get_cursor() - std::mem::size_of::(); + self.read_type_meta_miss(reader, type_resolver, config, meta_header, type_def_start) + } + } + } + + #[cold] + #[inline(never)] + fn read_type_meta_miss( + &mut self, + reader: &mut Reader, + type_resolver: &TypeResolver, + config: &Config, + meta_header: i64, + type_def_start: usize, + ) -> Result, Error> { + let type_meta = Rc::new(TypeMeta::from_bytes_with_header( + reader, + type_resolver, + meta_header, + config.max_type_fields(), + config.max_type_meta_bytes(), + )?); + let remote_type_def = reader.sub_slice(type_def_start, reader.get_cursor())?; + + let namespace = type_meta.get_namespace(); + let type_name = type_meta.get_type_name(); + let register_by_name = !namespace.original.is_empty() || !type_name.original.is_empty(); + let mut remote_schema_key = None; + let type_info = if register_by_name { + if let Some(local_type_info) = + type_resolver.get_type_info_by_name(&namespace.original, &type_name.original) + { + if local_type_info.get_type_meta_ref().get_bytes() == remote_type_def { + local_type_info } else { - // Registered by ID - let type_id = type_meta.get_type_id(); - let user_type_id = type_meta.get_user_type_id(); - if user_type_id != NO_USER_TYPE_ID { - if let Some(local_type_info) = - type_resolver.get_user_type_info_by_id(user_type_id) - { - // Exact schemas can reuse the local TypeInfo; changed - // schemas keep the remote metadata with the local harness. - if type_meta.get_hash() - == local_type_info.get_type_meta_ref().get_hash() - { - local_type_info - } else { - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - Some(local_type_info.get_harness()), - Some(local_type_info.get_type_id() as u32), - Some(local_type_info.get_user_type_id()), - )) - } - } else { - // No local type found, use stub harness - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - None, - None, - None, - )) - } - } else if let Some(local_type_info) = type_resolver.get_type_info_by_id(type_id) - { - // Exact schemas can reuse the local TypeInfo; changed - // schemas keep the remote metadata with the local harness. - if type_meta.get_hash() == local_type_info.get_type_meta_ref().get_hash() { - local_type_info - } else { - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - Some(local_type_info.get_harness()), - Some(local_type_info.get_type_id() as u32), - Some(local_type_info.get_user_type_id()), - )) - } - } else { - // No local type found, use stub harness - Rc::new(TypeInfo::from_remote_meta( - type_meta.clone(), - None, - None, - None, - )) - } - }; - - if self.parsed_type_infos.len() < MAX_PARSED_NUM_TYPE_DEFS { - // avoid malicious type defs to OOM parsed_type_infos - self.parsed_type_infos - .insert(meta_header, type_info.clone()); - self.last_meta_header = meta_header; - self.last_type_info = Some(type_info.clone()); + remote_schema_key = + self.check_remote_struct_schema_limit(&type_meta, config)?; + Rc::new(TypeInfo::from_remote_meta( + type_meta.clone(), + Some(local_type_info.get_harness()), + Some(local_type_info.get_type_id() as u32), + Some(local_type_info.get_user_type_id()), + )) } - self.reading_type_infos.push(type_info.clone()); - Ok(type_info) + } else { + remote_schema_key = self.check_remote_struct_schema_limit(&type_meta, config)?; + Rc::new(TypeInfo::from_remote_meta( + type_meta.clone(), + None, + None, + None, + )) } + } else { + let type_id = type_meta.get_type_id(); + let user_type_id = type_meta.get_user_type_id(); + let local_type_info = if user_type_id != NO_USER_TYPE_ID { + type_resolver.get_user_type_info_by_id(user_type_id) + } else { + type_resolver.get_type_info_by_id(type_id) + }; + if let Some(local_type_info) = local_type_info { + if local_type_info.get_type_meta_ref().get_bytes() == remote_type_def { + local_type_info + } else { + remote_schema_key = + self.check_remote_struct_schema_limit(&type_meta, config)?; + Rc::new(TypeInfo::from_remote_meta( + type_meta.clone(), + Some(local_type_info.get_harness()), + Some(local_type_info.get_type_id() as u32), + Some(local_type_info.get_user_type_id()), + )) + } + } else { + remote_schema_key = self.check_remote_struct_schema_limit(&type_meta, config)?; + Rc::new(TypeInfo::from_remote_meta( + type_meta.clone(), + None, + None, + None, + )) + } + }; + + self.parsed_type_infos + .insert(meta_header, type_info.clone()); + self.last_meta_header = meta_header; + self.last_type_info = Some(type_info.clone()); + self.reading_type_infos.push(type_info.clone()); + self.record_remote_struct_schema(remote_schema_key); + Ok(type_info) + } + + #[cold] + #[inline(never)] + fn check_remote_struct_schema_limit( + &self, + type_meta: &TypeMeta, + config: &Config, + ) -> Result, Error> { + if !matches!( + type_meta.get_type_id(), + STRUCT | COMPATIBLE_STRUCT | NAMED_STRUCT | NAMED_COMPATIBLE_STRUCT + ) { + return Ok(None); + } + + let namespace = type_meta.get_namespace(); + let type_name = type_meta.get_type_name(); + let key = if !namespace.original.is_empty() || !type_name.original.is_empty() { + format!("n{}\0{}", namespace.original, type_name.original) + } else { + format!("i{}", type_meta.get_user_type_id()) + }; + + let versions_for_type = self + .remote_schema_versions_by_type + .get(&key) + .copied() + .unwrap_or(0); + if versions_for_type >= config.max_schema_versions_per_type() { + return Err(Error::invalid_data(format!( + "remote struct schema versions for one type exceeded max_schema_versions_per_type={}", + config.max_schema_versions_per_type() + ))); + } + + let accepted_type_count = + self.remote_schema_versions_by_type.len() + if versions_for_type == 0 { 1 } else { 0 }; + let global_limit = usize::max( + MIN_REMOTE_STRUCT_SCHEMA_LIMIT, + accepted_type_count * config.max_average_schema_versions_per_type(), + ); + if self.total_accepted_schema_versions >= global_limit { + return Err(Error::invalid_data(format!( + "remote struct schema versions exceeded global limit from max_average_schema_versions_per_type={}", + config.max_average_schema_versions_per_type() + ))); } + + Ok(Some(key)) + } + + fn record_remote_struct_schema(&mut self, key: Option) { + let Some(key) = key else { + return; + }; + let versions_for_type = self + .remote_schema_versions_by_type + .get(&key) + .copied() + .unwrap_or(0); + self.remote_schema_versions_by_type + .insert(key, versions_for_type + 1); + self.total_accepted_schema_versions += 1; } #[inline(always)] @@ -277,54 +336,210 @@ impl MetaReaderResolver { #[cfg(test)] mod tests { use super::*; - use crate::meta::MetaString; + use crate::config::Config; + use crate::meta::{FieldInfo, FieldType, MetaString}; use crate::TypeId; + fn read_type_def( + resolver: &mut MetaReaderResolver, + config: &Config, + type_def: &[u8], + ) -> Result, Error> { + let mut bytes = vec![]; + let mut writer = Writer::from_buffer(&mut bytes); + writer.write_var_u32(0); + writer.write_bytes(type_def); + let mut reader = Reader::new(&bytes); + resolver.read_type_meta(&mut reader, &TypeResolver::default(), config) + } + #[test] - fn parsed_type_info_cache_does_not_publish_after_limit() { + fn type_meta_field_limit_rejects_large_struct() { let meta = TypeMeta::new( TypeId::STRUCT as u32, 9001, MetaString::get_empty().clone(), MetaString::get_empty().clone(), false, - vec![], + vec![ + FieldInfo::new("a", FieldType::new(crate::type_id::INT32, false, vec![])), + FieldInfo::new("b", FieldType::new(crate::type_id::INT32, false, vec![])), + ], ) .unwrap(); - let type_def = meta.get_bytes().to_vec(); - let mut header_reader = Reader::new(&type_def); - let meta_header = header_reader.read_i64().unwrap(); + let config = Config { + max_type_fields: 1, + ..Default::default() + }; + let err = read_type_def( + &mut MetaReaderResolver::default(), + &config, + meta.get_bytes(), + ) + .unwrap_err() + .to_string(); + assert!(err.contains("max_type_fields")); + } - let mut resolver = MetaReaderResolver::default(); - let cached_type_info = Rc::new(TypeInfo::from_remote_meta( - Rc::new(TypeMeta::empty().unwrap()), - None, - None, - None, - )); - let mut header = 0; - while resolver.parsed_type_infos.len() < MAX_PARSED_NUM_TYPE_DEFS { - if header != meta_header { - resolver - .parsed_type_infos - .insert(header, cached_type_info.clone()); - } - header += 1; + #[test] + fn type_meta_body_limit_rejects_large_metadata() { + let meta = TypeMeta::new( + TypeId::STRUCT as u32, + 9001, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![FieldInfo::new( + "a", + FieldType::new(crate::type_id::INT32, false, vec![]), + )], + ) + .unwrap(); + let config = Config { + max_type_meta_bytes: 1, + ..Default::default() + }; + let err = read_type_def( + &mut MetaReaderResolver::default(), + &config, + meta.get_bytes(), + ) + .unwrap_err() + .to_string(); + assert!(err.contains("max_type_meta_bytes")); + } + + #[test] + fn schema_limit_tracks_unknown_struct_types_separately() { + fn type_def(user_type_id: u32, field_name: &str) -> Vec { + TypeMeta::new( + TypeId::STRUCT as u32, + user_type_id, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![FieldInfo::new( + field_name, + FieldType::new(crate::type_id::INT32, false, vec![]), + )], + ) + .unwrap() + .get_bytes() + .to_vec() } + let config = Config { + max_schema_versions_per_type: 1, + ..Default::default() + }; + + let mut resolver = MetaReaderResolver::default(); + read_type_def(&mut resolver, &config, &type_def(9001, "a")).unwrap(); + read_type_def(&mut resolver, &config, &type_def(9002, "a")).unwrap(); + + let err = read_type_def(&mut resolver, &config, &type_def(9001, "b")) + .unwrap_err() + .to_string(); + assert!(err.contains("max_schema_versions_per_type")); + } + + #[test] + fn schema_limit_rejects_extra_versions_for_type() { + let meta = TypeMeta::new( + TypeId::STRUCT as u32, + 9001, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![FieldInfo::new( + "a", + FieldType::new(crate::type_id::INT32, false, vec![]), + )], + ) + .unwrap(); + let type_def = meta.get_bytes().to_vec(); + + let config = Config { + max_schema_versions_per_type: 1, + ..Default::default() + }; + let mut resolver = MetaReaderResolver::default(); let mut bytes = vec![]; let mut writer = Writer::from_buffer(&mut bytes); writer.write_var_u32(0); writer.write_bytes(&type_def); + let mut reader = Reader::new(&bytes); + resolver + .read_type_meta(&mut reader, &TypeResolver::default(), &config) + .unwrap(); + let changed = TypeMeta::new( + TypeId::STRUCT as u32, + 9001, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![FieldInfo::new( + "b", + FieldType::new(crate::type_id::INT32, false, vec![]), + )], + ) + .unwrap(); + let mut bytes = vec![]; + let mut writer = Writer::from_buffer(&mut bytes); + writer.write_var_u32(0); + writer.write_bytes(changed.get_bytes()); let mut reader = Reader::new(&bytes); - let current = resolver - .read_type_meta(&mut reader, &TypeResolver::default()) + let err = resolver + .read_type_meta(&mut reader, &TypeResolver::default(), &config) + .unwrap_err() + .to_string(); + assert!(err.contains("max_schema_versions_per_type")); + } + + #[test] + fn schema_limit_check_is_not_recorded() { + let config = Config { + max_schema_versions_per_type: 1, + ..Default::default() + }; + let mut resolver = MetaReaderResolver::default(); + let checked = TypeMeta::new( + TypeId::STRUCT as u32, + 9001, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![FieldInfo::new( + "a", + FieldType::new(crate::type_id::INT32, false, vec![]), + )], + ) + .unwrap(); + let accepted = TypeMeta::new( + TypeId::STRUCT as u32, + 9001, + MetaString::get_empty().clone(), + MetaString::get_empty().clone(), + false, + vec![FieldInfo::new( + "b", + FieldType::new(crate::type_id::INT32, false, vec![]), + )], + ) + .unwrap(); + + resolver + .check_remote_struct_schema_limit(&checked, &config) .unwrap(); - assert_eq!(current.get_user_type_id(), 9001); - assert_eq!(resolver.parsed_type_infos.len(), MAX_PARSED_NUM_TYPE_DEFS); - assert!(!resolver.parsed_type_infos.contains_key(&meta_header)); - assert!(resolver.last_type_info.is_none()); + let mut bytes = vec![]; + let mut writer = Writer::from_buffer(&mut bytes); + writer.write_var_u32(0); + writer.write_bytes(accepted.get_bytes()); + let mut reader = Reader::new(&bytes); + resolver + .read_type_meta(&mut reader, &TypeResolver::default(), &config) + .unwrap(); } } diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index 023a08ca7c..600b3b1196 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -22,19 +22,37 @@ public struct Config { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int + public let maxTypeFields: Int + public let maxTypeMetaBytes: Int + public let maxSchemaVersionsPerType: Int + public let maxAverageSchemaVersionsPerType: Int public init( trackRef: Bool = false, compatible: Bool? = nil, checkClassVersion: Bool? = nil, - maxDepth: Int = 5 + maxDepth: Int = 5, + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096, + maxSchemaVersionsPerType: Int = 10, + maxAverageSchemaVersionsPerType: Int = 3 ) { + precondition(maxTypeFields > 0, "maxTypeFields must be positive") + precondition(maxTypeMetaBytes > 0, "maxTypeMetaBytes must be positive") + precondition(maxSchemaVersionsPerType > 0, "maxSchemaVersionsPerType must be positive") + precondition( + maxAverageSchemaVersionsPerType > 0, + "maxAverageSchemaVersionsPerType must be positive") let effectiveCompatible = compatible ?? true let effectiveCheckClassVersion = checkClassVersion ?? !effectiveCompatible self.trackRef = trackRef self.compatible = effectiveCompatible self.checkClassVersion = effectiveCheckClassVersion self.maxDepth = maxDepth + self.maxTypeFields = maxTypeFields + self.maxTypeMetaBytes = maxTypeMetaBytes + self.maxSchemaVersionsPerType = maxSchemaVersionsPerType + self.maxAverageSchemaVersionsPerType = maxAverageSchemaVersionsPerType } } @@ -53,20 +71,28 @@ public final class Fory { ref: Bool = false, compatible: Bool? = nil, checkClassVersion: Bool? = nil, - maxDepth: Int = 5 + maxDepth: Int = 5, + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096, + maxSchemaVersionsPerType: Int = 10, + maxAverageSchemaVersionsPerType: Int = 3 ) { self.init( config: Config( trackRef: ref, compatible: compatible, checkClassVersion: checkClassVersion, - maxDepth: maxDepth + maxDepth: maxDepth, + maxTypeFields: maxTypeFields, + maxTypeMetaBytes: maxTypeMetaBytes, + maxSchemaVersionsPerType: maxSchemaVersionsPerType, + maxAverageSchemaVersionsPerType: maxAverageSchemaVersionsPerType )) } public init(config: Config) { self.config = config - self.typeResolver = TypeResolver(trackRef: self.config.trackRef) + self.typeResolver = TypeResolver(config: self.config) self.writeContext = WriteContext( buffer: ByteBuffer(), typeResolver: typeResolver, @@ -79,10 +105,7 @@ public final class Fory { self.readContext = ReadContext( buffer: ByteBuffer(), typeResolver: typeResolver, - trackRef: self.config.trackRef, - compatible: self.config.compatible, - checkClassVersion: self.config.checkClassVersion, - maxDepth: self.config.maxDepth + config: self.config ) } @@ -115,7 +138,8 @@ public final class Fory { } } - public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T { + public func deserialize(from buffer: ByteBuffer, as _: T.Type = T.self) throws -> T + { try deserializeRoot( from: buffer ) { context in @@ -170,7 +194,8 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: (any Serializer).Type = (any Serializer).self) throws - -> any Serializer { + -> any Serializer + { try deserializeRoot( data: data ) { context in @@ -207,7 +232,8 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: [String: Any].Type = [String: Any].self) throws - -> [String: Any] { + -> [String: Any] + { try deserializeRoot( data: data ) { context in @@ -225,7 +251,8 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: [Int32: Any].Type = [Int32: Any].self) throws - -> [Int32: Any] { + -> [Int32: Any] + { try deserializeRoot( data: data ) { context in @@ -243,7 +270,8 @@ public final class Fory { @_disfavoredOverload public func deserialize(_ data: Data, as _: [AnyHashable: Any].Type = [AnyHashable: Any].self) - throws -> [AnyHashable: Any] { + throws -> [AnyHashable: Any] + { try deserializeRoot( data: data ) { context in @@ -286,7 +314,8 @@ public final class Fory { @_disfavoredOverload public func deserialize(from buffer: ByteBuffer, as _: AnyObject.Type = AnyObject.self) throws - -> AnyObject { + -> AnyObject + { try deserializeRoot( from: buffer ) { context in @@ -338,7 +367,8 @@ public final class Fory { @_disfavoredOverload public func deserialize(from buffer: ByteBuffer, as _: [String: Any].Type = [String: Any].self) - throws -> [String: Any] { + throws -> [String: Any] + { try deserializeRoot( from: buffer ) { context in @@ -364,7 +394,8 @@ public final class Fory { @_disfavoredOverload public func deserialize(from buffer: ByteBuffer, as _: [Int32: Any].Type = [Int32: Any].self) - throws -> [Int32: Any] { + throws -> [Int32: Any] + { try deserializeRoot( from: buffer ) { context in diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index 99c089be61..43ed3f220b 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -26,6 +26,7 @@ public final class ReadContext { public let compatible: Bool public let checkClassVersion: Bool public let maxDepth: Int + private let config: Config public let refReader: RefReader private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2) private let metaStrings = ReusableArray(defaultValue: nil, reserve: 16) @@ -38,17 +39,15 @@ public final class ReadContext { init( buffer: ByteBuffer, typeResolver: TypeResolver, - trackRef: Bool, - compatible: Bool = false, - checkClassVersion: Bool = true, - maxDepth: Int = 5 + config: Config ) { self.buffer = buffer self.typeResolver = typeResolver - self.trackRef = trackRef - self.compatible = compatible - self.checkClassVersion = checkClassVersion - self.maxDepth = maxDepth + self.trackRef = config.trackRef + self.compatible = config.compatible + self.checkClassVersion = config.checkClassVersion + self.maxDepth = config.maxDepth + self.config = config self.refReader = RefReader() } @@ -201,7 +200,8 @@ public final class ReadContext { "received name-registered type info for id-registered local type") } if namespace.value != localTypeInfo.namespace.value - || typeName.value != localTypeInfo.typeName.value { + || typeName.value != localTypeInfo.typeName.value + { let expectedTypeName = "\(localTypeInfo.namespace.value)::\(localTypeInfo.typeName.value)" let actualTypeName = "\(namespace.value)::\(typeName.value)" throw ForyError.invalidData( @@ -233,7 +233,8 @@ public final class ReadContext { if !checkClassVersion, compatibleTypeDefTypeInfos.isEmpty, !localTypeInfo.typeDefHasUserTypeFields, - let localTypeDefHeader = localTypeInfo.typeDefHeader { + let localTypeDefHeader = localTypeInfo.typeDefHeader + { let indexMarker = try buffer.readVarUInt32() if indexMarker == 0 { let headerStart = buffer.getCursor() @@ -242,22 +243,33 @@ public final class ReadContext { if bodySize == typeMetaSizeMask { bodySize += Int(try buffer.readVarUInt32()) } - if header == localTypeDefHeader { - // Header-cache hits intentionally skip without rehashing. Entries reach this - // cache only after a successful TypeDef parse and 52-bit metadata-hash validation. - compatibleTypeDefTypeInfos.push(localTypeInfo) - try buffer.skip(bodySize) - return nil - } if let cached = typeResolver.getTypeInfo(forHeader: header) { try buffer.skip(bodySize) compatibleTypeDefTypeInfos.push(cached) + if header == localTypeDefHeader, + cached.compatibleTypeMeta === localTypeInfo.typeMeta + { + return nil + } return try validateCompatibleTypeInfo(cached, for: localTypeInfo, wireTypeID: wireTypeID) } buffer.setCursor(headerStart) - let decoded = try TypeMeta.decode(buffer) - let cachedTypeInfo = try typeResolver.cacheTypeInfo(decoded, forHeader: header) + let decoded = try TypeMeta.decode( + buffer, + maxTypeFields: config.maxTypeFields, + maxTypeMetaBytes: config.maxTypeMetaBytes) + let typeMetaEnd = buffer.getCursor() + let cachedTypeInfo = try typeResolver.cacheTypeInfo( + decoded, + forHeader: header, + buffer: buffer, + typeDefStart: headerStart, + typeDefEnd: typeMetaEnd + ) compatibleTypeDefTypeInfos.push(cachedTypeInfo) + if cachedTypeInfo === localTypeInfo { + return nil + } return try validateCompatibleTypeInfo( cachedTypeInfo, for: localTypeInfo, wireTypeID: wireTypeID) } @@ -303,8 +315,18 @@ public final class ReadContext { } buffer.setCursor(typeMetaStart) - let decoded = try TypeMeta.decode(buffer) - let cachedTypeInfo = try typeResolver.cacheTypeInfo(decoded, forHeader: header) + let decoded = try TypeMeta.decode( + buffer, + maxTypeFields: config.maxTypeFields, + maxTypeMetaBytes: config.maxTypeMetaBytes) + let typeMetaEnd = buffer.getCursor() + let cachedTypeInfo = try typeResolver.cacheTypeInfo( + decoded, + forHeader: header, + buffer: buffer, + typeDefStart: typeMetaStart, + typeDefEnd: typeMetaEnd + ) compatibleTypeDefTypeInfos.push(cachedTypeInfo) return cachedTypeInfo } @@ -318,7 +340,8 @@ public final class ReadContext { let compatibleTypeDefTypeInfos = self.compatibleTypeDefTypeInfos let remoteTypeInfo: TypeInfo if compatibleTypeDefTypeInfos.isEmpty, - let localTypeDefHeader = localTypeInfo.typeDefHeader { + localTypeInfo.typeDefHeader != nil + { let indexMarker = try buffer.readVarUInt32() if indexMarker != 0 { remoteTypeInfo = try readCompatibleTypeInfo(afterMarker: indexMarker) @@ -330,22 +353,24 @@ public final class ReadContext { bodySize += Int(try buffer.readVarUInt32()) } - if header == localTypeDefHeader { - // Header-cache hits intentionally skip without rehashing. Entries reach this - // cache only after a successful TypeDef parse and 52-bit metadata-hash validation. - compatibleTypeDefTypeInfos.push(localTypeInfo) - try buffer.skip(bodySize) - return localTypeInfo - } - if let cached = typeResolver.getTypeInfo(forHeader: header) { try buffer.skip(bodySize) compatibleTypeDefTypeInfos.push(cached) remoteTypeInfo = cached } else { buffer.setCursor(headerStart) - let decoded = try TypeMeta.decode(buffer) - remoteTypeInfo = try typeResolver.cacheTypeInfo(decoded, forHeader: header) + let decoded = try TypeMeta.decode( + buffer, + maxTypeFields: config.maxTypeFields, + maxTypeMetaBytes: config.maxTypeMetaBytes) + let typeMetaEnd = buffer.getCursor() + remoteTypeInfo = try typeResolver.cacheTypeInfo( + decoded, + forHeader: header, + buffer: buffer, + typeDefStart: headerStart, + typeDefEnd: typeMetaEnd + ) compatibleTypeDefTypeInfos.push(remoteTypeInfo) } } @@ -365,7 +390,8 @@ public final class ReadContext { throw ForyError.invalidData("compatible type metadata is required") } if let localTypeMeta = localTypeInfo.typeMeta, - remoteTypeMeta === localTypeMeta { + remoteTypeMeta === localTypeMeta + { return localTypeInfo } if remoteTypeMeta.registerByName { @@ -407,7 +433,8 @@ public final class ReadContext { registerByName: localTypeInfo.registerByName, compatible: compatible, evolving: localTypeInfo.evolving - ) { + ) + { throw ForyError.typeMismatch(expected: wireTypeID.rawValue, actual: remoteTypeID) } return remoteTypeInfo @@ -571,15 +598,15 @@ public final class ReadContext { } } -public extension ReadContext { - func readAny( +extension ReadContext { + public func readAny( refMode: RefMode, readTypeInfo: Bool = true ) throws -> Any? { try SerializableAny.foryRead(self, refMode: refMode, readTypeInfo: readTypeInfo).anyValue() } - func readListOfAny( + public func readListOfAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [Any]? { @@ -591,7 +618,7 @@ public extension ReadContext { return wrapped?.map { $0.anyValueForCollection() } } - func readMapStringToAny( + public func readMapStringToAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [String: Any]? { @@ -611,7 +638,7 @@ public extension ReadContext { return map } - func readMapInt32ToAny( + public func readMapInt32ToAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [Int32: Any]? { @@ -631,7 +658,7 @@ public extension ReadContext { return map } - func readMapAnyHashableToAny( + public func readMapAnyHashableToAny( refMode: RefMode, readTypeInfo: Bool = false ) throws -> [AnyHashable: Any]? { diff --git a/swift/Sources/Fory/TypeMeta.swift b/swift/Sources/Fory/TypeMeta.swift index fe4e53a047..10e3d0f35a 100644 --- a/swift/Sources/Fory/TypeMeta.swift +++ b/swift/Sources/Fory/TypeMeta.swift @@ -325,11 +325,22 @@ public final class TypeMeta: Equatable, @unchecked Sendable { return Array(buffer.storage.prefix(buffer.count)) } - public static func decode(_ bytes: [UInt8]) throws -> TypeMeta { - try decode(ByteBuffer(bytes: bytes)) + public static func decode( + _ bytes: [UInt8], + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096 + ) throws -> TypeMeta { + try decode( + ByteBuffer(bytes: bytes), + maxTypeFields: maxTypeFields, + maxTypeMetaBytes: maxTypeMetaBytes) } - public static func decode(_ buffer: ByteBuffer) throws -> TypeMeta { + public static func decode( + _ buffer: ByteBuffer, + maxTypeFields: Int = 512, + maxTypeMetaBytes: Int = 4096 + ) throws -> TypeMeta { let header = try buffer.readUInt64() if (header & typeMetaReservedFlags) != 0 { throw ForyError.invalidData("invalid TypeMeta global header") @@ -340,6 +351,11 @@ public final class TypeMeta: Equatable, @unchecked Sendable { if metaSize == Int(typeMetaSizeMask) { metaSize += Int(try buffer.readVarUInt32()) } + if metaSize > maxTypeMetaBytes { + throw ForyError.invalidData( + "Type metadata body size \(metaSize) exceeds maxTypeMetaBytes \(maxTypeMetaBytes). The data may be malicious. If the data is not malicious, please increase maxTypeMetaBytes." + ) + } let encodedBody = try buffer.readBytes(count: metaSize) if compressed { @@ -364,6 +380,11 @@ public final class TypeMeta: Equatable, @unchecked Sendable { if numFields == smallNumFieldsThreshold { numFields += Int(try bodyReader.readVarUInt32()) } + if numFields > maxTypeFields { + throw ForyError.invalidData( + "Type metadata field count \(numFields) exceeds maxTypeFields \(maxTypeFields). The data may be malicious. If the data is not malicious, please increase maxTypeFields." + ) + } if registerByName { typeID = compatible ? TypeId.namedCompatibleStruct.rawValue : TypeId.namedStruct.rawValue } else { diff --git a/swift/Sources/Fory/TypeResolver.swift b/swift/Sources/Fory/TypeResolver.swift index 06963509a4..6800bc5efe 100644 --- a/swift/Sources/Fory/TypeResolver.swift +++ b/swift/Sources/Fory/TypeResolver.swift @@ -132,7 +132,8 @@ private func encodedTypeDefHeaderHash(_ bytes: [UInt8]) throws -> UInt64 { private func fieldNeedsTypeInfo(_ fieldType: TypeMeta.FieldType) -> Bool { if let typeID = TypeId(rawValue: fieldType.typeID), - TypeId.needsTypeInfoForField(typeID) { + TypeId.needsTypeInfoForField(typeID) + { return true } return fieldType.generics.contains { fieldNeedsTypeInfo($0) } @@ -143,7 +144,8 @@ private func encodedTypeDefHasUserTypeFields(_ fields: [TypeMeta.FieldInfo]) -> } @inline(__always) -private func readRegisteredValue(_ context: ReadContext, as type: T.Type) throws -> T { +private func readRegisteredValue(_ context: ReadContext, as type: T.Type) throws -> T +{ try T.foryRead( context, refMode: T.isRefType ? .tracking : .none, @@ -382,7 +384,8 @@ public final class TypeInfo: @unchecked Sendable { } if context.compatible && (compatibleWireTypeID == .compatibleStruct - || compatibleWireTypeID == .namedCompatibleStruct) { + || compatibleWireTypeID == .namedCompatibleStruct) + { return try compatibleReader(context, self) } if compatibleTypeMeta !== typeMeta { @@ -399,9 +402,9 @@ private struct TypeNameKey: Hashable { } final class TypeResolver { - private static let maxCachedTypeDefHeaders = 8192 + private static let minRemoteStructSchemaLimit = 8192 - private let trackRef: Bool + private let config: Config private var registrationFinished = false private var bySwiftType = UInt64Map(initialCapacity: 64) @@ -409,9 +412,11 @@ final class TypeResolver { private var byTypeName: [TypeNameKey: TypeInfo] = [:] private var builtinTypeInfoByID: [TypeInfo?] = [] private var typeInfoByHeader = UInt64Map(initialCapacity: 64) + private var remoteSchemaVersionsByType: [String: Int] = [:] + private var totalAcceptedSchemaVersions = 0 - init(trackRef: Bool = false) { - self.trackRef = trackRef + init(config: Config) { + self.config = config } func finishRegistration() { @@ -448,7 +453,7 @@ final class TypeResolver { evolving: evolving, namespace: MetaString.empty(specialChar1: ".", specialChar2: "_"), typeName: MetaString.empty(specialChar1: "$", specialChar2: "_"), - fields: T.foryFieldsInfo(trackRef: trackRef), + fields: T.foryFieldsInfo(trackRef: config.trackRef), reader: { context in try readRegisteredValue(context, as: T.self) }, @@ -464,7 +469,8 @@ final class TypeResolver { registerByName: false, evolving: evolving, typeName: (namespace: "", name: "") - ) { + ) + { return } @@ -504,7 +510,7 @@ final class TypeResolver { evolving: evolving, namespace: namespaceMeta, typeName: typeNameMeta, - fields: T.foryFieldsInfo(trackRef: trackRef), + fields: T.foryFieldsInfo(trackRef: config.trackRef), reader: { context in try readRegisteredValue(context, as: T.self) }, @@ -520,7 +526,8 @@ final class TypeResolver { registerByName: true, evolving: evolving, typeName: (namespace: namespace, name: typeName) - ) { + ) + { return } @@ -556,17 +563,25 @@ final class TypeResolver { } @inline(__always) - func cacheTypeInfo(_ typeMeta: TypeMeta, forHeader header: UInt64) throws -> TypeInfo { + func cacheTypeInfo( + _ typeMeta: TypeMeta, + forHeader header: UInt64, + buffer: ByteBuffer, + typeDefStart: Int, + typeDefEnd: Int + ) throws -> TypeInfo { if let cached = typeInfoByHeader.value(for: header) { return cached } let localTypeInfo = try requireTypeInfo(for: typeMeta) - if header == localTypeInfo.typeDefHeader { - if typeInfoByHeader.count < Self.maxCachedTypeDefHeaders { - typeInfoByHeader.set(localTypeInfo, for: header) - } + if let localTypeDefBytes = localTypeInfo.typeDefBytes, + typeDefEnd - typeDefStart == localTypeDefBytes.count, + buffer.matchesBytes(start: typeDefStart, bytes: localTypeDefBytes) + { + typeInfoByHeader.set(localTypeInfo, for: header) return localTypeInfo } + let remoteSchemaKey = try checkRemoteStructSchemaLimit(typeMeta) let canonicalTypeMeta: TypeMeta if let localTypeMeta = localTypeInfo.typeMeta { canonicalTypeMeta = try typeMeta.assigningFieldIDs(from: localTypeMeta) @@ -574,12 +589,63 @@ final class TypeResolver { canonicalTypeMeta = typeMeta } let typeInfo = TypeInfo(dynamic: localTypeInfo, compatibleTypeMeta: canonicalTypeMeta) - if typeInfoByHeader.count < Self.maxCachedTypeDefHeaders { - typeInfoByHeader.set(typeInfo, for: header) - } + typeInfoByHeader.set(typeInfo, for: header) + recordRemoteStructSchema(remoteSchemaKey) return typeInfo } + @inline(never) + private func checkRemoteStructSchemaLimit(_ typeMeta: TypeMeta) throws -> String? { + guard let rawTypeID = typeMeta.typeID, + let typeID = TypeId(rawValue: rawTypeID) + else { + return nil + } + switch typeID { + case .structType, .compatibleStruct, .namedStruct, .namedCompatibleStruct: + break + default: + return nil + } + + let key: String + if typeMeta.registerByName { + key = "n\(typeMeta.namespace.value)\0\(typeMeta.typeName.value)" + } else { + key = "i\(typeMeta.userTypeID ?? UInt32.max)" + } + + let versionsForType = remoteSchemaVersionsByType[key] ?? 0 + let maxSchemaVersionsPerType = config.maxSchemaVersionsPerType + if versionsForType >= maxSchemaVersionsPerType { + throw ForyError.invalidData( + "remote struct schema versions for one type exceeded maxSchemaVersionsPerType=\(maxSchemaVersionsPerType)" + ) + } + let acceptedStructTypeCount = + versionsForType == 0 ? remoteSchemaVersionsByType.count + 1 : remoteSchemaVersionsByType.count + let maxAverageSchemaVersionsPerType = config.maxAverageSchemaVersionsPerType + let globalLimit = max( + Self.minRemoteStructSchemaLimit, + acceptedStructTypeCount * maxAverageSchemaVersionsPerType + ) + if totalAcceptedSchemaVersions >= globalLimit { + throw ForyError.invalidData( + "remote struct schema versions exceeded global limit from maxAverageSchemaVersionsPerType=\(maxAverageSchemaVersionsPerType)" + ) + } + return key + } + + private func recordRemoteStructSchema(_ key: String?) { + guard let key else { + return + } + let versionsForType = remoteSchemaVersionsByType[key] ?? 0 + remoteSchemaVersionsByType[key] = versionsForType + 1 + totalAcceptedSchemaVersions += 1 + } + private func store( _ typeInfo: TypeInfo, for swiftTypeID: ObjectIdentifier, @@ -594,7 +660,8 @@ final class TypeResolver { byTypeName[typeNameKey] = typeInfo } if let typeMeta = typeInfo.typeMeta, - let typeDefHeader = typeInfo.typeDefHeader { + let typeDefHeader = typeInfo.typeDefHeader + { typeInfoByHeader.set( TypeInfo( dynamic: typeInfo, @@ -674,7 +741,8 @@ final class TypeResolver { ) } if existing.typeID != T.staticTypeId || existing.namespace.value != namespace - || existing.typeName.value != typeName { + || existing.typeName.value != typeName + { throw ForyError.invalidData( """ \(type) registration conflict: existing name=\(existing.namespace.value)::\(existing.typeName.value), \ diff --git a/swift/Tests/ForyTests/CollectionSerializerTests.swift b/swift/Tests/ForyTests/CollectionSerializerTests.swift index 0945e82939..c6312ae5e0 100644 --- a/swift/Tests/ForyTests/CollectionSerializerTests.swift +++ b/swift/Tests/ForyTests/CollectionSerializerTests.swift @@ -454,8 +454,8 @@ func collectionSerializersRejectMalformedPrimitivePayloads() throws { int16Buffer.writeBytes([0x01, 0x02, 0x03]) let int16Context = ReadContext( buffer: int16Buffer, - typeResolver: TypeResolver(trackRef: false), - trackRef: false + typeResolver: TypeResolver(config: Config(trackRef: false)), + config: Config(trackRef: false) ) do { let _: [Int16] = try ArrayFieldCodec.readPayload(int16Context) @@ -469,8 +469,8 @@ func collectionSerializersRejectMalformedPrimitivePayloads() throws { float64Buffer.writeBytes([0x01, 0x02, 0x03, 0x04]) let float64Context = ReadContext( buffer: float64Buffer, - typeResolver: TypeResolver(trackRef: false), - trackRef: false + typeResolver: TypeResolver(config: Config(trackRef: false)), + config: Config(trackRef: false) ) do { let _: [Double] = try ArrayFieldCodec.readPayload(float64Context) diff --git a/swift/Tests/ForyTests/DateTimeTests.swift b/swift/Tests/ForyTests/DateTimeTests.swift index 0f8bba02a5..05d3e56099 100644 --- a/swift/Tests/ForyTests/DateTimeTests.swift +++ b/swift/Tests/ForyTests/DateTimeTests.swift @@ -86,7 +86,8 @@ func localDateConvenienceMethodsExposeEpochAndCalendarViews() throws { @Test func dateAndTimestampContextHelpersUseExpectedWireProtocols() throws { let xlangWriteBuffer = ByteBuffer() - let xlangTypeResolver = TypeResolver(trackRef: false) + let xlangConfig = Config(trackRef: false, compatible: true, checkClassVersion: true, maxDepth: 5) + let xlangTypeResolver = TypeResolver(config: xlangConfig) let xlangWriteContext = WriteContext( buffer: xlangWriteBuffer, typeResolver: xlangTypeResolver, @@ -109,10 +110,7 @@ func dateAndTimestampContextHelpersUseExpectedWireProtocols() throws { let xlangReadContext = ReadContext( buffer: ByteBuffer(data: xlangWriteBuffer.copyToData()), typeResolver: xlangTypeResolver, - trackRef: false, - compatible: true, - checkClassVersion: true, - maxDepth: 5 + config: xlangConfig ) let xlangLocalDateDecoded = try xlangReadContext.readLocalDate(refMode: RefMode.nullOnly, readTypeInfo: true) #expect(xlangLocalDateDecoded == xlangLocalDate) @@ -132,10 +130,7 @@ func dateAndTimestampContextHelpersUseExpectedWireProtocols() throws { let timestampReadContext = ReadContext( buffer: ByteBuffer(data: timestampBuffer.copyToData()), typeResolver: xlangTypeResolver, - trackRef: false, - compatible: true, - checkClassVersion: true, - maxDepth: 5 + config: xlangConfig ) let timestampDecoded = try timestampReadContext.readTimestamp(refMode: RefMode.nullOnly, readTypeInfo: true) #expect(abs(timestampDecoded.timeIntervalSince1970 - instant.timeIntervalSince1970) < 0.000_001) diff --git a/swift/Tests/ForyTests/EnumTests.swift b/swift/Tests/ForyTests/EnumTests.swift index 24f69f3411..823e3a9a11 100644 --- a/swift/Tests/ForyTests/EnumTests.swift +++ b/swift/Tests/ForyTests/EnumTests.swift @@ -97,7 +97,7 @@ func unionDefaultUsesKnownCase() { @Test func unionCaseIdZeroIsKnownCase() throws { let buffer = ByteBuffer() - let typeResolver = TypeResolver(trackRef: false) + let typeResolver = TypeResolver(config: Config(trackRef: false)) let writeContext = WriteContext(buffer: buffer, typeResolver: typeResolver, trackRef: false) try ForwardStringOrLong.text("zero").foryWriteData(writeContext, hasGenerics: false) buffer.flip() @@ -106,7 +106,7 @@ func unionCaseIdZeroIsKnownCase() throws { let context = ReadContext( buffer: buffer, typeResolver: typeResolver, - trackRef: false + config: Config(trackRef: false) ) #expect(try ForwardStringOrLong.foryReadData(context) == .text("zero")) diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index 0aac15035f..f1c38e45d6 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -301,7 +301,7 @@ func floatingSpecialsRoundTrip() throws { -.infinity, .leastNonzeroMagnitude, .greatestFiniteMagnitude, - Float(bitPattern: 0x7FC0_1234) + Float(bitPattern: 0x7FC0_1234), ] for value in floatValues { let decoded: Float = try fory.deserialize(try fory.serialize(value)) @@ -315,7 +315,7 @@ func floatingSpecialsRoundTrip() throws { -.infinity, .leastNonzeroMagnitude, .greatestFiniteMagnitude, - Double(bitPattern: 0x7FF8_0000_0000_1234) + Double(bitPattern: 0x7FF8_0000_0000_1234), ] for value in doubleValues { let decoded: Double = try fory.deserialize(try fory.serialize(value)) @@ -329,7 +329,7 @@ func floatingSpecialsRoundTrip() throws { .init(bitPattern: 0xFC00), .init(bitPattern: 0x0001), .init(bitPattern: 0x7BFF), - .init(bitPattern: 0x7E11) + .init(bitPattern: 0x7E11), ] for value in float16Values { let decoded: Float16 = try fory.deserialize(try fory.serialize(value)) @@ -342,7 +342,7 @@ func floatingSpecialsRoundTrip() throws { .init(rawValue: 0x7F80), .init(rawValue: 0xFF80), .init(rawValue: 0x0001), - .init(rawValue: 0x7FC1) + .init(rawValue: 0x7FC1), ] for value in bfloat16Values { let decoded: BFloat16 = try fory.deserialize(try fory.serialize(value)) @@ -357,18 +357,47 @@ func namedInitializerBuildsConfig() { #expect(defaultConfig.config.compatible == true) #expect(defaultConfig.config.checkClassVersion == false) #expect(defaultConfig.config.maxDepth == 5) - - let explicitConfig = Fory(ref: true, compatible: true, maxDepth: 7) + #expect(defaultConfig.config.maxTypeFields == 512) + #expect(defaultConfig.config.maxTypeMetaBytes == 4096) + #expect(defaultConfig.config.maxSchemaVersionsPerType == 10) + #expect(defaultConfig.config.maxAverageSchemaVersionsPerType == 3) + + let explicitConfig = Fory( + ref: true, + compatible: true, + maxDepth: 7, + maxTypeFields: 31, + maxTypeMetaBytes: 1234, + maxSchemaVersionsPerType: 12, + maxAverageSchemaVersionsPerType: 4 + ) #expect(explicitConfig.config.trackRef == true) #expect(explicitConfig.config.compatible == true) #expect(explicitConfig.config.checkClassVersion == false) #expect(explicitConfig.config.maxDepth == 7) - - let configInit = Fory(config: .init(trackRef: false, compatible: true, maxDepth: 9)) + #expect(explicitConfig.config.maxTypeFields == 31) + #expect(explicitConfig.config.maxTypeMetaBytes == 1234) + #expect(explicitConfig.config.maxSchemaVersionsPerType == 12) + #expect(explicitConfig.config.maxAverageSchemaVersionsPerType == 4) + + let configInit = Fory( + config: .init( + trackRef: false, + compatible: true, + maxDepth: 9, + maxTypeFields: 41, + maxTypeMetaBytes: 2048, + maxSchemaVersionsPerType: 14, + maxAverageSchemaVersionsPerType: 5 + )) #expect(configInit.config.trackRef == false) #expect(configInit.config.compatible == true) #expect(configInit.config.checkClassVersion == false) #expect(configInit.config.maxDepth == 9) + #expect(configInit.config.maxTypeFields == 41) + #expect(configInit.config.maxTypeMetaBytes == 2048) + #expect(configInit.config.maxSchemaVersionsPerType == 14) + #expect(configInit.config.maxAverageSchemaVersionsPerType == 5) let schemaConsistentDirect = Fory(ref: true, compatible: false) let schemaConsistentViaConfig = Fory(config: Config(trackRef: true, compatible: false)) @@ -475,28 +504,147 @@ func primitiveArrayTypeIDs() throws { } @Test -func typeDefHeaderCacheStopsPublishingAtCapacity() throws { - let resolver = TypeResolver() +func typeMetaFieldLimitRejectsLargeStruct() throws { + let fieldType = TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo(fieldID: nil, fieldName: "first", fieldType: fieldType), + TypeMeta.FieldInfo(fieldID: nil, fieldName: "second", fieldType: fieldType) + ] + ) + let encoded = try meta.encode() + + #expect(throws: (any Error).self) { + _ = try TypeMeta.decode(encoded, maxTypeFields: 1) + } +} + +@Test +func typeMetaBodyLimitRejectsLargeMetadata() throws { + let meta = try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: "value", + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false)) + ] + ) + let encoded = try meta.encode() + + #expect(throws: (any Error).self) { + _ = try TypeMeta.decode(encoded, maxTypeMetaBytes: 1) + } +} + +@Test +func schemaLimitTracksStructTypesSeparately() throws { + let resolver = TypeResolver(config: Config(maxSchemaVersionsPerType: 1)) resolver.register(Person.self, id: 901) - let typeInfo = try resolver.requireTypeInfo(for: Person.self) - let typeMeta = try #require(typeInfo.typeMeta) - let localHeader = try #require(typeInfo.typeDefHeader) - #expect(resolver.getTypeInfo(forHeader: localHeader) != nil) - - var header = UInt64(0x0100_0000_0000_0000) - var inserted = 0 - while inserted < 8191 { - if header != localHeader { - _ = try resolver.cacheTypeInfo(typeMeta, forHeader: header) - inserted += 1 - } - header += 1 + resolver.register(Address.self, id: 902) + + func remoteTypeMeta(userTypeID: UInt32, fieldName: String) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: userTypeID, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + ) + ] + ) + } + + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + buffer: buffer, + typeDefStart: 0, + typeDefEnd: buffer.getCursor() + ) + } + + try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteA")) + try cache(remoteTypeMeta(userTypeID: 902, fieldName: "remoteA")) + #expect(throws: (any Error).self) { + try cache(remoteTypeMeta(userTypeID: 901, fieldName: "remoteB")) + } +} + +@Test +func failedSchemaDoesNotConsumeLimit() throws { + let resolver = TypeResolver(config: Config(maxSchemaVersionsPerType: 1)) + resolver.register(Person.self, id: 901) + + func remoteTypeMeta(fieldName: String, fieldType: TypeMeta.FieldType) throws -> TypeMeta { + try TypeMeta( + typeID: TypeId.structType.rawValue, + userTypeID: 901, + namespace: .empty(specialChar1: ".", specialChar2: "_"), + typeName: .empty(specialChar1: "$", specialChar2: "_"), + registerByName: false, + fields: [ + TypeMeta.FieldInfo( + fieldID: nil, + fieldName: fieldName, + fieldType: fieldType + ) + ] + ) } - let uncachedHeader = header == localHeader ? header + 1 : header - let current = try resolver.cacheTypeInfo(typeMeta, forHeader: uncachedHeader) - #expect(current.compatibleTypeMeta != nil) - #expect(resolver.getTypeInfo(forHeader: uncachedHeader) == nil) + func cache(_ typeMeta: TypeMeta) throws { + let encoded = try typeMeta.encode() + let headerReader = ByteBuffer(bytes: encoded) + let header = try headerReader.readUInt64() + let buffer = ByteBuffer(bytes: encoded) + let decoded = try TypeMeta.decode(buffer) + _ = try resolver.cacheTypeInfo( + decoded, + forHeader: header, + buffer: buffer, + typeDefStart: 0, + typeDefEnd: buffer.getCursor() + ) + } + + #expect(throws: (any Error).self) { + try cache(remoteTypeMeta( + fieldName: "id", + fieldType: TypeMeta.FieldType( + typeID: TypeId.map.rawValue, + nullable: false, + generics: [ + TypeMeta.FieldType(typeID: TypeId.string.rawValue, nullable: false), + TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false), + ] + ) + )) + } + try cache(remoteTypeMeta( + fieldName: "remoteA", + fieldType: TypeMeta.FieldType(typeID: TypeId.int32.rawValue, nullable: false) + )) } @Test @@ -585,7 +733,7 @@ func dynamicUserTypesDecodeByID() throws { @Test func duplicateNameRegistrationIsRejected() throws { - let resolver = TypeResolver(trackRef: false) + let resolver = TypeResolver(config: Config(trackRef: false)) try resolver.register(Address.self, namespace: "demo", typeName: "entity") do { @@ -596,7 +744,7 @@ func duplicateNameRegistrationIsRejected() throws { @Test func nameRegistrationSplitsLastDot() throws { - let resolver = TypeResolver(trackRef: false) + let resolver = TypeResolver(config: Config(trackRef: false)) try resolver.register(Address.self, name: "com.example.Address") let info = try resolver.requireTypeInfo(namespace: "com.example", typeName: "Address") @@ -606,7 +754,7 @@ func nameRegistrationSplitsLastDot() throws { @Test func nameRegistrationAllowsSimpleName() throws { - let resolver = TypeResolver(trackRef: false) + let resolver = TypeResolver(config: Config(trackRef: false)) try resolver.register(Address.self, name: "Address") let info = try resolver.requireTypeInfo(namespace: "", typeName: "Address") @@ -634,7 +782,7 @@ func nameRegistrationRejectsTrailingDot() throws { @Test func splitNameRegistrationRejectsDottedTypeName() throws { - let resolver = TypeResolver(trackRef: false) + let resolver = TypeResolver(config: Config(trackRef: false)) #expect(throws: ForyError.self) { try resolver.register(Address.self, namespace: "com", typeName: "example.Address") @@ -771,7 +919,7 @@ func macroDynamicAnyObjectAndAnySerializerFieldsRoundTrip() throws { items: [Int32(11), Address(street: "Nested", zip: 10002)], map: [ "age": Int64(19), - "address": Address(street: "Mapped", zip: 10003) + "address": Address(street: "Mapped", zip: 10003), ] ) let serializerData = try fory.serialize(serializerHolder) @@ -824,13 +972,13 @@ func macroAnyFieldsRoundTrip() throws { "count": Int64(3), "name": "map", "address": Address(street: "AnyMap", zip: 11003), - "empty": NSNull() + "empty": NSNull(), ], int32Map: [ 1: Int32(-9), 2: "v2", 3: Address(street: "AnyIntMap", zip: 11004), - 4: NSNull() + 4: NSNull(), ] ) let data = try fory.serialize(value) @@ -911,7 +1059,7 @@ func macroFieldOrderFollowsForyRules() throws { let second = try buffer.readVarInt64() let third = try buffer.readVarInt32() - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, trackRef: false) + let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) let fourth = try String.foryReadData(tailContext) #expect(first == value.shortValue) @@ -939,7 +1087,7 @@ func macroTaggedFieldsKeepGroupedPayloadOrder() throws { _ = try buffer.readInt32() #expect(try buffer.readVarInt32() == value.intValue) - let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, trackRef: false) + let tailContext = ReadContext(buffer: buffer, typeResolver: fory.typeResolver, config: fory.config) #expect(try String.foryReadData(tailContext) == value.textTail) } @@ -947,7 +1095,10 @@ func macroTaggedFieldsKeepGroupedPayloadOrder() throws { func macroNonPrimitiveFieldsSortByFieldIdentifier() throws { let fields = NonPrimitiveFieldOrder.foryFieldsInfo(trackRef: false) - #expect(fields.map(\.fieldName) == ["intValue", "mapValue", "stringValue", "addressValue", "binaryValue"]) + #expect( + fields.map(\.fieldName) == [ + "intValue", "mapValue", "stringValue", "addressValue", "binaryValue", + ]) #expect(fields.map(\.fieldID) == [nil, 10, 20, nil, nil]) } @@ -1013,7 +1164,7 @@ func macroReducedPrecisionFieldsUseXlangTypeIDs() { TypeId.float16.rawValue, TypeId.bfloat16.rawValue, TypeId.bfloat16Array.rawValue, - TypeId.float16Array.rawValue + TypeId.float16Array.rawValue, ]) } @@ -1088,7 +1239,7 @@ func compatibleNestedStructArrayRoundTrip() throws { let value = CompatibleNestedArrayHolder( items: [ CompatibleNestedItem(id: 1, name: "alpha"), - CompatibleNestedItem(id: 2, name: "beta") + CompatibleNestedItem(id: 2, name: "beta"), ] ) let bytes = try writer.serialize(value) @@ -1110,7 +1261,7 @@ func compatibleNestedStructOptionalArrayRoundTrip() throws { items: [ CompatibleNestedItem(id: 1, name: "alpha"), nil, - CompatibleNestedItem(id: 2, name: "beta") + CompatibleNestedItem(id: 2, name: "beta"), ] ) let bytes = try writer.serialize(value) @@ -1131,7 +1282,7 @@ func compatibleNestedStructMapRoundTrip() throws { let value = CompatibleNestedMapHolder( items: [ 1: CompatibleNestedItem(id: 10, name: "first"), - 2: CompatibleNestedItem(id: 20, name: "second") + 2: CompatibleNestedItem(id: 20, name: "second"), ] ) let bytes = try writer.serialize(value) @@ -1161,7 +1312,7 @@ func pvlVarInt64AndVarUInt64Extremes() throws { 72_057_594_037_927_935, 72_057_594_037_927_936, UInt64(Int64.max), - UInt64.max + UInt64.max, ] let intValues: [Int64] = [ Int64.min, @@ -1178,7 +1329,7 @@ func pvlVarInt64AndVarUInt64Extremes() throws { 1_000_000, 1_000_000_000_000, Int64.max - 1, - Int64.max + Int64.max, ] let writeBuffer = ByteBuffer() @@ -1264,7 +1415,7 @@ func typeMetaRoundTripByName() throws { nullable: true, generics: [ .init(typeID: TypeId.string.rawValue, nullable: false), - .init(typeID: TypeId.varint32.rawValue, nullable: true) + .init(typeID: TypeId.varint32.rawValue, nullable: true), ] ) ), @@ -1272,7 +1423,7 @@ func typeMetaRoundTripByName() throws { fieldID: 7, fieldName: "ignored_for_tag_mode", fieldType: .init(typeID: TypeId.varint32.rawValue, nullable: false) - ) + ), ] let meta = try TypeMeta( diff --git a/swift/Tests/ForyTests/StringSerializerTests.swift b/swift/Tests/ForyTests/StringSerializerTests.swift index 47f08f6fcc..d5a17c9315 100644 --- a/swift/Tests/ForyTests/StringSerializerTests.swift +++ b/swift/Tests/ForyTests/StringSerializerTests.swift @@ -40,15 +40,15 @@ private func makeStringReadContext(payload: [UInt8], encoding: ManualStringEncod buffer.writeBytes(payload) return ReadContext( buffer: buffer, - typeResolver: TypeResolver(trackRef: false), - trackRef: false + typeResolver: TypeResolver(config: Config(trackRef: false)), + config: Config(trackRef: false) ) } private func stringPayloadBytes(for value: String) throws -> [UInt8] { let context = WriteContext( buffer: ByteBuffer(), - typeResolver: TypeResolver(trackRef: false), + typeResolver: TypeResolver(config: Config(trackRef: false)), trackRef: false ) try value.foryWriteData(context, hasGenerics: false) @@ -83,8 +83,8 @@ func stringSerializerRoundTripsUnicodeAndLengthBoundaries() throws { let context = ReadContext( buffer: ByteBuffer(bytes: payload), - typeResolver: TypeResolver(trackRef: false), - trackRef: false + typeResolver: TypeResolver(config: Config(trackRef: false)), + config: Config(trackRef: false) ) #expect(try String.foryReadData(context) == value) } @@ -135,8 +135,8 @@ func stringSerializerRejectsInvalidPayloads() throws { unsupportedEncodingBuffer.writeVarUInt36Small(3) let unsupportedEncoding = ReadContext( buffer: unsupportedEncodingBuffer, - typeResolver: TypeResolver(trackRef: false), - trackRef: false + typeResolver: TypeResolver(config: Config(trackRef: false)), + config: Config(trackRef: false) ) do { _ = try String.foryReadData(unsupportedEncoding) diff --git a/swift/Tests/ForyTests/UnsignedTests.swift b/swift/Tests/ForyTests/UnsignedTests.swift index 67941645ef..5d34ae28f9 100644 --- a/swift/Tests/ForyTests/UnsignedTests.swift +++ b/swift/Tests/ForyTests/UnsignedTests.swift @@ -76,52 +76,52 @@ func unsignedFieldCodecsPreserveExpectedWireWidths() throws { let fixed32 = UInt32.max let fixed32Context = WriteContext( buffer: ByteBuffer(), - typeResolver: TypeResolver(trackRef: false), + typeResolver: TypeResolver(config: Config(trackRef: false)), trackRef: false ) UInt32FixedCodec.writePayload(fixed32, fixed32Context) #expect(fixed32Context.buffer.count == 4) let fixed32Decoded = try UInt32FixedCodec.readPayload( - ReadContext(buffer: fixed32Context.buffer, typeResolver: TypeResolver(trackRef: false), trackRef: false) + ReadContext(buffer: fixed32Context.buffer, typeResolver: TypeResolver(config: Config(trackRef: false)), config: Config(trackRef: false)) ) #expect(fixed32Decoded == fixed32) let fixed64 = UInt64.max let fixed64Context = WriteContext( buffer: ByteBuffer(), - typeResolver: TypeResolver(trackRef: false), + typeResolver: TypeResolver(config: Config(trackRef: false)), trackRef: false ) UInt64FixedCodec.writePayload(fixed64, fixed64Context) #expect(fixed64Context.buffer.count == 8) let fixed64Decoded = try UInt64FixedCodec.readPayload( - ReadContext(buffer: fixed64Context.buffer, typeResolver: TypeResolver(trackRef: false), trackRef: false) + ReadContext(buffer: fixed64Context.buffer, typeResolver: TypeResolver(config: Config(trackRef: false)), config: Config(trackRef: false)) ) #expect(fixed64Decoded == fixed64) let compactTagged = UInt64(Int32.max) let compactContext = WriteContext( buffer: ByteBuffer(), - typeResolver: TypeResolver(trackRef: false), + typeResolver: TypeResolver(config: Config(trackRef: false)), trackRef: false ) UInt64TaggedCodec.writePayload(compactTagged, compactContext) #expect(compactContext.buffer.count == 4) let compactDecoded = try UInt64TaggedCodec.readPayload( - ReadContext(buffer: compactContext.buffer, typeResolver: TypeResolver(trackRef: false), trackRef: false) + ReadContext(buffer: compactContext.buffer, typeResolver: TypeResolver(config: Config(trackRef: false)), config: Config(trackRef: false)) ) #expect(compactDecoded == compactTagged) let wideTagged = UInt64(Int32.max) + 1 let wideContext = WriteContext( buffer: ByteBuffer(), - typeResolver: TypeResolver(trackRef: false), + typeResolver: TypeResolver(config: Config(trackRef: false)), trackRef: false ) UInt64TaggedCodec.writePayload(wideTagged, wideContext) #expect(wideContext.buffer.count == 9) let wideDecoded = try UInt64TaggedCodec.readPayload( - ReadContext(buffer: wideContext.buffer, typeResolver: TypeResolver(trackRef: false), trackRef: false) + ReadContext(buffer: wideContext.buffer, typeResolver: TypeResolver(config: Config(trackRef: false)), config: Config(trackRef: false)) ) #expect(wideDecoded == wideTagged) }