diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 05324875..5cec03bd 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2054,19 +2054,27 @@ def _compute_column_type(self, column): sample_value: Representative value for type inference and modified_row. min_val: Minimum for integers (None otherwise). max_val: Maximum for integers (None otherwise). + max_decimal_formatted_len: Maximum len(format(d, 'f')) across all + Decimal values in the column (0 when no Decimals are present). + Used by executemany to correct the SQL_VARCHAR column size when + the sample value's formatted string is shorter than another + value's (e.g. positive sample vs negative row value) (GH-557). """ non_nulls = [v for v in column if v is not None] if not non_nulls: - return None, None, None + return None, None, None, 0 int_values = [v for v in non_nulls if isinstance(v, int)] if int_values: min_val, max_val = min(int_values), max(int_values) sample_value = max(int_values, key=abs) - return sample_value, min_val, max_val + return sample_value, min_val, max_val, 0 sample_value = None + max_decimal_formatted_len = 0 for v in non_nulls: + if isinstance(v, decimal.Decimal): + max_decimal_formatted_len = max(max_decimal_formatted_len, len(format(v, "f"))) if not sample_value: sample_value = v elif isinstance(v, (str, bytes, bytearray)) and isinstance( @@ -2120,7 +2128,7 @@ def _compute_column_type(self, column): # If comparing Decimal to non-Decimal, prefer Decimal for better type inference sample_value = v - return sample_value, None, None + return sample_value, None, None, max_decimal_formatted_len def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-statements self, operation: str, seq_of_parameters: Union[List[Sequence[Any]], List[Mapping[str, Any]]] @@ -2225,7 +2233,7 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s if hasattr(seq_of_parameters, "__getitem__") else [] ) - sample_value, min_val, max_val = self._compute_column_type(column) + sample_value, min_val, max_val, _ = self._compute_column_type(column) if self._inputsizes and col_index < len(self._inputsizes): # Use explicitly set input sizes @@ -2301,7 +2309,7 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s if hasattr(seq_of_parameters, "__getitem__") else [] ) - sample_value, min_val, max_val = self._compute_column_type(column) + sample_value, min_val, max_val, max_decimal_len = self._compute_column_type(column) dummy_row = list(sample_row) paraminfo = self._create_parameter_types_list( @@ -2322,6 +2330,17 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s paraminfo.paramSQLType = ddbc_sql_const.SQL_VARCHAR.value paraminfo.columnSize = 1 + # Correct column size for Decimal columns sent as SQL_VARCHAR (GH-557). + # The sample value's formatted string may be shorter than another + # row's (e.g. positive sample "1.0" = 3 chars vs negative "-0.1" = 4). + # max_decimal_len was already computed during _compute_column_type + # so no extra iteration is needed. + if ( + paraminfo.paramSQLType == ddbc_sql_const.SQL_VARCHAR.value + and max_decimal_len > paraminfo.columnSize + ): + paraminfo.columnSize = max_decimal_len + # Special handling for binary data in auto-detected types if paraminfo.paramSQLType in ( ddbc_sql_const.SQL_BINARY.value, diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 17e06961..214964e6 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -2304,6 +2304,54 @@ def test_executemany_Decimal_list(cursor, db_connection): db_connection.commit() +def test_executemany_decimal_sign_change(cursor, db_connection): + """Test executemany with decimals that change signs (GH-557). + + When the sample value chosen for column sizing is shorter than a negative + value in the batch, the formatted string (with a leading '-') can exceed + the allocated column_size, causing a RuntimeError. + """ + try: + cursor.execute("CREATE TABLE #pytest_decimal_sign (col_1 DECIMAL(28, 14))") + + # Case 1: negative first, then positive — previously worked + data1 = [(decimal.Decimal("-0.1"),), (decimal.Decimal("1.0"),)] + cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data1) + + # Case 2: positive first, then negative — previously failed + data2 = [(decimal.Decimal("0.1"),), (decimal.Decimal("-0.1"),)] + cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data2) + + # Case 3: positive then negative with different integer parts + data3 = [(decimal.Decimal("1.0"),), (decimal.Decimal("-0.1"),)] + cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data3) + + # Case 4: multiple sign changes in a single batch + data4 = [ + (decimal.Decimal("100.5"),), + (decimal.Decimal("-0.001"),), + (decimal.Decimal("0.5"),), + (decimal.Decimal("-999.99"),), + ] + cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data4) + + db_connection.commit() + + # Verify row count + cursor.execute("SELECT COUNT(*) FROM #pytest_decimal_sign") + count = cursor.fetchone()[0] + assert count == 10 + + # Verify data correctness for the originally-failing case + cursor.execute("SELECT col_1 FROM #pytest_decimal_sign ORDER BY col_1") + rows = [row[0] for row in cursor.fetchall()] + assert decimal.Decimal("-999.99") in [r.quantize(decimal.Decimal("0.01")) for r in rows] + assert decimal.Decimal("0.1") in [r.quantize(decimal.Decimal("0.1")) for r in rows] + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_sign") + db_connection.commit() + + def test_executemany_DecimalString_list(cursor, db_connection): """Test executemany with an string of decimal parameter list.""" try: