diff --git a/lib/mcp/tool.rb b/lib/mcp/tool.rb index 37c5d927..d7cce0ca 100644 --- a/lib/mcp/tool.rb +++ b/lib/mcp/tool.rb @@ -93,23 +93,30 @@ def icons(value = NOT_SET) def input_schema(value = NOT_SET) if value == NOT_SET input_schema_value - elsif value.is_a?(Hash) - @input_schema_value = InputSchema.new(value) - elsif value.is_a?(InputSchema) - @input_schema_value = value + elsif (schema = coerce_schema(value, InputSchema)) + @input_schema_value = schema end end def output_schema(value = NOT_SET) if value == NOT_SET output_schema_value - elsif value.is_a?(Hash) - @output_schema_value = OutputSchema.new(value) - elsif value.is_a?(OutputSchema) - @output_schema_value = value + elsif (schema = coerce_schema(value, OutputSchema)) + @output_schema_value = schema end end + def coerce_schema(value, schema_class) + case value + when Hash + schema_class.new(value) + when schema_class + value + end + end + + private :coerce_schema + def meta(value = NOT_SET) if value == NOT_SET @meta_value diff --git a/lib/mcp/tool/input_schema.rb b/lib/mcp/tool/input_schema.rb index 724f9bbc..d2e50bab 100644 --- a/lib/mcp/tool/input_schema.rb +++ b/lib/mcp/tool/input_schema.rb @@ -18,10 +18,7 @@ def missing_required_arguments(arguments) end def validate_arguments(arguments) - errors = fully_validate(arguments) - if errors.any? - raise ValidationError, "Invalid arguments: #{errors.join(", ")}" - end + fully_validate!(arguments, "arguments") end end end diff --git a/lib/mcp/tool/output_schema.rb b/lib/mcp/tool/output_schema.rb index 8bf7ed93..5f4167e1 100644 --- a/lib/mcp/tool/output_schema.rb +++ b/lib/mcp/tool/output_schema.rb @@ -8,10 +8,7 @@ class OutputSchema < Schema class ValidationError < StandardError; end def validate_result(result) - errors = fully_validate(result) - if errors.any? - raise ValidationError, "Invalid result: #{errors.join(", ")}" - end + fully_validate!(result, "result") end end end diff --git a/lib/mcp/tool/schema.rb b/lib/mcp/tool/schema.rb index 98a137ab..406fd81f 100644 --- a/lib/mcp/tool/schema.rb +++ b/lib/mcp/tool/schema.rb @@ -31,8 +31,11 @@ def to_h private - def fully_validate(data) - JSON::Validator.fully_validate(schema_for_validation, data) + def fully_validate!(payload, label) + errors = JSON::Validator.fully_validate(schema_for_validation, payload) + if errors.any? + raise self.class::ValidationError, "Invalid #{label}: #{errors.join(", ")}" + end end def validate_schema!