Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2054,19 +2054,29 @@ def _compute_column_type(self, column):
sample_value: Representative value for type inference and modified_row.
min_val: Minimum for integers (None otherwise).
max_val: Maximum for integers (None otherwise).
max_decimal_formatted_len: Maximum len(format(d, 'f')) across all
Decimal values in the column (0 when no Decimals are present).
Used by executemany to correct the SQL_VARCHAR column size when
the sample value's formatted string is shorter than another
value's (e.g. positive sample vs negative row value) (GH-557).
"""
non_nulls = [v for v in column if v is not None]
if not non_nulls:
return None, None, None
return None, None, None, 0

int_values = [v for v in non_nulls if isinstance(v, int)]
if int_values:
min_val, max_val = min(int_values), max(int_values)
sample_value = max(int_values, key=abs)
return sample_value, min_val, max_val
return sample_value, min_val, max_val, 0

sample_value = None
max_decimal_formatted_len = 0
for v in non_nulls:
if isinstance(v, decimal.Decimal):
max_decimal_formatted_len = max(
max_decimal_formatted_len, len(format(v, "f"))
)
if not sample_value:
sample_value = v
elif isinstance(v, (str, bytes, bytearray)) and isinstance(
Expand Down Expand Up @@ -2120,7 +2130,7 @@ def _compute_column_type(self, column):
# If comparing Decimal to non-Decimal, prefer Decimal for better type inference
sample_value = v

return sample_value, None, None
return sample_value, None, None, max_decimal_formatted_len

def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-statements
self, operation: str, seq_of_parameters: Union[List[Sequence[Any]], List[Mapping[str, Any]]]
Expand Down Expand Up @@ -2225,7 +2235,7 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
if hasattr(seq_of_parameters, "__getitem__")
else []
)
sample_value, min_val, max_val = self._compute_column_type(column)
sample_value, min_val, max_val, _ = self._compute_column_type(column)

if self._inputsizes and col_index < len(self._inputsizes):
# Use explicitly set input sizes
Expand Down Expand Up @@ -2301,7 +2311,9 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
if hasattr(seq_of_parameters, "__getitem__")
else []
)
sample_value, min_val, max_val = self._compute_column_type(column)
sample_value, min_val, max_val, max_decimal_len = self._compute_column_type(
column
)

dummy_row = list(sample_row)
paraminfo = self._create_parameter_types_list(
Expand All @@ -2322,6 +2334,17 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
paraminfo.paramSQLType = ddbc_sql_const.SQL_VARCHAR.value
paraminfo.columnSize = 1

# Correct column size for Decimal columns sent as SQL_VARCHAR (GH-557).
# The sample value's formatted string may be shorter than another
# row's (e.g. positive sample "1.0" = 3 chars vs negative "-0.1" = 4).
# max_decimal_len was already computed during _compute_column_type
# so no extra iteration is needed.
if (
paraminfo.paramSQLType == ddbc_sql_const.SQL_VARCHAR.value
and max_decimal_len > paraminfo.columnSize
):
paraminfo.columnSize = max_decimal_len

# Special handling for binary data in auto-detected types
if paraminfo.paramSQLType in (
ddbc_sql_const.SQL_BINARY.value,
Expand Down
48 changes: 48 additions & 0 deletions tests/test_004_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2304,6 +2304,54 @@ def test_executemany_Decimal_list(cursor, db_connection):
db_connection.commit()


def test_executemany_decimal_sign_change(cursor, db_connection):
"""Test executemany with decimals that change signs (GH-557).

When the sample value chosen for column sizing is shorter than a negative
value in the batch, the formatted string (with a leading '-') can exceed
the allocated column_size, causing a RuntimeError.
"""
try:
cursor.execute("CREATE TABLE #pytest_decimal_sign (col_1 DECIMAL(28, 14))")

# Case 1: negative first, then positive — previously worked
data1 = [(decimal.Decimal("-0.1"),), (decimal.Decimal("1.0"),)]
cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data1)

# Case 2: positive first, then negative — previously failed
data2 = [(decimal.Decimal("0.1"),), (decimal.Decimal("-0.1"),)]
cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data2)

# Case 3: positive then negative with different integer parts
data3 = [(decimal.Decimal("1.0"),), (decimal.Decimal("-0.1"),)]
cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data3)

# Case 4: multiple sign changes in a single batch
data4 = [
(decimal.Decimal("100.5"),),
(decimal.Decimal("-0.001"),),
(decimal.Decimal("0.5"),),
(decimal.Decimal("-999.99"),),
]
cursor.executemany("INSERT INTO #pytest_decimal_sign VALUES (?)", data4)

db_connection.commit()

# Verify row count
cursor.execute("SELECT COUNT(*) FROM #pytest_decimal_sign")
count = cursor.fetchone()[0]
assert count == 10

# Verify data correctness for the originally-failing case
cursor.execute("SELECT col_1 FROM #pytest_decimal_sign ORDER BY col_1")
rows = [row[0] for row in cursor.fetchall()]
assert decimal.Decimal("-999.99") in [r.quantize(decimal.Decimal("0.01")) for r in rows]
assert decimal.Decimal("0.1") in [r.quantize(decimal.Decimal("0.1")) for r in rows]
finally:
cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_sign")
db_connection.commit()


def test_executemany_DecimalString_list(cursor, db_connection):
"""Test executemany with an string of decimal parameter list."""
try:
Expand Down
Loading