diff --git a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java index 44131a01b..d43f4d6ed 100644 --- a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java +++ b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java @@ -30,11 +30,14 @@ import java.lang.reflect.InvocationTargetException; import java.math.BigDecimal; import java.nio.ByteBuffer; +import java.sql.Blob; +import java.sql.Clob; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.Struct; import java.sql.Timestamp; import java.sql.Types; import java.time.LocalDateTime; @@ -106,7 +109,7 @@ record = recordBuilder.build(); @Override protected void handleField(ResultSet resultSet, StructuredRecord.Builder recordBuilder, Schema.Field field, int columnIndex, int sqlType, int sqlPrecision, int sqlScale) throws SQLException { - if (OracleSourceSchemaReader.ORACLE_TYPES.contains(sqlType) || sqlType == Types.NCLOB) { + if (OracleSourceSchemaReader.ORACLE_TYPES.contains(sqlType) || sqlType == Types.NCLOB || sqlType == Types.STRUCT) { handleOracleSpecificType(resultSet, recordBuilder, field, columnIndex, sqlType, sqlPrecision, sqlScale); } else { setField(resultSet, recordBuilder, field, columnIndex, sqlType, sqlPrecision, sqlScale); @@ -257,6 +260,31 @@ private byte[] getBfileBytes(ResultSet resultSet, String columnName) throws SQLE } } + private byte[] getBfileBytes(Object bfile) throws SQLException { + if (bfile == null) { + return null; + } + try { + ClassLoader classLoader = bfile.getClass().getClassLoader(); + Class oracleBfileClass = classLoader.loadClass("oracle.jdbc.OracleBfile"); + boolean isFileExist = (boolean) oracleBfileClass.getMethod("fileExists").invoke(bfile); + if (!isFileExist) { + return null; + } + + oracleBfileClass.getMethod("openFile").invoke(bfile); + InputStream binaryStream = (InputStream) oracleBfileClass.getMethod("getBinaryStream").invoke(bfile); + byte[] bytes = ByteStreams.toByteArray(binaryStream); + oracleBfileClass.getMethod("closeFile").invoke(bfile); + return bytes; + } catch (ClassNotFoundException | InvocationTargetException | NoSuchMethodException | IllegalAccessException e) { + throw new InvalidStageException("Field is of type 'BFILE', which is not supported " + + "with this version of the JDBC driver.", e); + } catch (IOException e) { + throw new InvalidStageException("Error reading the contents of the BFILE.", e); + } + } + private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Builder recordBuilder, Schema.Field field, int columnIndex, int sqlType, int precision, int scale) throws SQLException { @@ -341,6 +369,12 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil case OracleSourceSchemaReader.LONG_RAW: recordBuilder.set(field.getName(), resultSet.getBytes(columnIndex)); break; + case Types.STRUCT: + Struct structValue = (Struct) resultSet.getObject(columnIndex); + if (structValue != null) { + recordBuilder.set(field.getName(), convertStructToRecord(structValue, nonNullSchema, resultSet)); + } + break; case Types.DECIMAL: case Types.NUMERIC: // This is the only way to differentiate FLOAT/REAL columns from other numeric columns, that based on NUMBER. @@ -371,6 +405,87 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil } } + private StructuredRecord convertStructToRecord(Struct struct, Schema schema, ResultSet resultSet) + throws SQLException { + Object[] attributes = struct.getAttributes(); + List fields = schema.getFields(); + StructuredRecord.Builder builder = StructuredRecord.builder(schema); + + for (int index = 0; index < attributes.length; index++) { + Schema.Field field = fields.get(index); + Object attrValue = attributes[index]; + + if (attrValue == null) { + builder.set(field.getName(), null); + continue; + } + // If it is an internal nested STRUCT, recurse down + if (attrValue instanceof Struct) { + Schema fieldSchema = field.getSchema().isNullable() ? field.getSchema().getNonNullable() : field.getSchema(); + builder.set(field.getName(), convertStructToRecord((Struct) attrValue, fieldSchema, resultSet)); + continue; + } + + String attrClassName = attrValue.getClass().getName(); + Schema fieldSchema = field.getSchema().isNullable() ? field.getSchema().getNonNullable() : field.getSchema(); + if (attrValue instanceof BigDecimal) { + if (Schema.LogicalType.DECIMAL.equals(fieldSchema.getLogicalType())) { + builder.setDecimal(field.getName(), ((BigDecimal) attrValue).setScale(getScale(field.getSchema()), + java.math.RoundingMode.HALF_UP)); + } else if (Schema.Type.DOUBLE.equals(fieldSchema.getType())) { + builder.set(field.getName(), ((BigDecimal) attrValue).doubleValue()); + } else if (Schema.Type.FLOAT.equals(fieldSchema.getType())) { + builder.set(field.getName(), ((BigDecimal) attrValue).floatValue()); + } else if (Schema.Type.INT.equals(fieldSchema.getType())) { + builder.set(field.getName(), ((BigDecimal) attrValue).intValue()); + } else if (Schema.Type.LONG.equals(fieldSchema.getType())) { + builder.set(field.getName(), ((BigDecimal) attrValue).longValue()); + } else { + builder.set(field.getName(), attrValue.toString()); + } + } else if (attrValue instanceof Timestamp) { + Timestamp timestamp = (Timestamp) attrValue; + if (Schema.LogicalType.DATETIME.equals(fieldSchema.getLogicalType())) { + builder.setDateTime(field.getName(), timestamp.toLocalDateTime()); + } else if (Schema.LogicalType.DATE.equals(fieldSchema.getLogicalType())) { + builder.setDate(field.getName(), timestamp.toLocalDateTime().toLocalDate()); + } else { + builder.set(field.getName(), attrValue.toString()); + } + } else if (attrValue instanceof OffsetDateTime || attrValue instanceof ZonedDateTime) { + ZonedDateTime zonedDateTime = (attrValue instanceof OffsetDateTime) + ? ((OffsetDateTime) attrValue).atZoneSameInstant(ZoneId.of("UTC")) + : ((ZonedDateTime) attrValue).withZoneSameInstant(ZoneId.of("UTC")); + if (fieldSchema.getLogicalType() != null && + (Schema.LogicalType.TIMESTAMP_MICROS.equals(fieldSchema.getLogicalType()) || + Schema.LogicalType.TIMESTAMP_MILLIS.equals(fieldSchema.getLogicalType()))) { + builder.setTimestamp(field.getName(), zonedDateTime); + } else if (Schema.Type.LONG.equals(fieldSchema.getType())) { + builder.set(field.getName(), zonedDateTime.toInstant().toEpochMilli()); + } else { + builder.set(field.getName(), zonedDateTime.toString()); + } + } else if (attrValue instanceof Clob) { + Clob clob = (Clob) attrValue; + builder.set(field.getName(), clob.getSubString(1, (int) clob.length())); + } else if (attrValue instanceof Blob) { + Blob blob = (Blob) attrValue; + builder.set(field.getName(), blob.getBytes(1, (int) blob.length())); + } else if ("oracle.jdbc.OracleBfile".equals(attrClassName)) { + builder.set(field.getName(), getBfileBytes(attrValue)); + } else if (attrValue instanceof byte[]) { + byte[] bytesValue = (byte[]) attrValue; + builder.set(field.getName(), bytesValue); + } else if ("oracle.sql.INTERVALDS".equals(attrClassName) + || "oracle.sql.INTERVALYM".equals(attrClassName)) { + builder.set(field.getName(), attrValue.toString()); + } else { + builder.set(field.getName(), attrValue); + } + } + return builder.build(); + } + /** * Get the scale set in Non-nullable schema associated with the schema * */ diff --git a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java index 208b70410..1e095ffa8 100644 --- a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java +++ b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java @@ -18,14 +18,22 @@ import com.google.common.collect.ImmutableSet; import io.cdap.cdap.api.data.schema.Schema; +import io.cdap.cdap.api.exception.ErrorCategory; +import io.cdap.cdap.api.exception.ErrorType; +import io.cdap.cdap.api.exception.ErrorUtils; import io.cdap.plugin.db.CommonSchemaReader; import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Types; +import java.util.ArrayList; +import java.util.List; import java.util.Set; import javax.annotation.Nullable; @@ -70,6 +78,7 @@ public class OracleSourceSchemaReader extends CommonSchemaReader { private final Boolean isTimestampOldBehavior; private final Boolean isPrecisionlessNumAsDecimal; private final Boolean isTimestampLtzFieldTimestamp; + private Connection connection; public OracleSourceSchemaReader() { this(null, false, false, false); @@ -136,11 +145,143 @@ public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLExcepti } return Schema.decimalOf(precision, scale); } + case Types.STRUCT: + if (connection == null) { + throw new SQLException("Cannot resolve STRUCT schema without a database connection. " + + "Use getSchemaFields(ResultSet) to enable STRUCT type resolution."); + } + String typeName = metadata.getColumnTypeName(index); + String owner = typeName.substring(0, typeName.lastIndexOf('.')); + return getStructSchema(connection, typeName, owner); default: return super.getSchema(metadata, index); } } + @Override + public List getSchemaFields(ResultSet resultSet) throws SQLException { + this.connection = resultSet.getStatement().getConnection(); + return super.getSchemaFields(resultSet); + } + + /** + * Builds a CDAP RECORD schema for an Oracle STRUCT type by querying the + * database metadata + * for the type's attributes. + * + * @param connection the database connection + * @param typeName the Oracle type name (e.g., "ADDRESS_TYPE") + * @return a CDAP RECORD schema with fields corresponding to the STRUCT's + * attributes + */ + private Schema getStructSchema(Connection connection, String typeName, String owner) throws SQLException { + List fields = new ArrayList<>(); + String sql = "SELECT * FROM ALL_TYPE_ATTRS WHERE TYPE_NAME = ? AND OWNER = ? ORDER BY ATTR_NO"; + + try (PreparedStatement stmt = connection.prepareStatement(sql)) { + + stmt.setString(1, typeName.substring(typeName.lastIndexOf('.') + 1)); + stmt.setString(2, owner); + + try (ResultSet attrRs = stmt.executeQuery()) { + while (attrRs.next()) { + String attrName = attrRs.getString("ATTR_NAME"); + String attrTypeName = attrRs.getString("ATTR_TYPE_NAME"); + int attrSize = attrRs.getInt("PRECISION"); + int attrScale = attrRs.getInt("SCALE"); + + Schema attrSchema = mapPrimitiveOracleType(attrTypeName, attrSize, attrScale, attrName); + if (attrSchema != null) { + fields.add(Schema.Field.of(attrName, attrSchema)); + } else { + String nestedStructOwner = attrRs.getString("ATTR_TYPE_OWNER"); + Schema nestedSchema = getStructSchema(connection, attrTypeName, nestedStructOwner); + fields.add(Schema.Field.of(attrName, nestedSchema)); + } + } + } + } + if (fields.isEmpty()) { + throw new SQLException(String.format( + "No attributes found for Oracle STRUCT type '%s'. " + + "Ensure the type exists and is accessible.", + typeName)); + } + + return Schema.recordOf(typeName, fields); + } + + private Schema mapPrimitiveOracleType(String typeName, int precision, int scale, String columnName) { + switch (typeName) { + case "TIMESTAMP WITH TZ": + return isTimestampOldBehavior ? Schema.of(Schema.Type.STRING) : Schema.of(Schema.LogicalType.TIMESTAMP_MICROS); + case "TIMESTAMP WITH LTZ": + return getTimestampLtzSchema(); + case "TIMESTAMP": + return isTimestampOldBehavior ? + Schema.of(Schema.LogicalType.TIMESTAMP_MICROS) : Schema.of(Schema.LogicalType.DATETIME); + case "DATE": + return Schema.of(Schema.LogicalType.DATE); + case "TIME": + return Schema.of(Schema.LogicalType.TIME_MICROS); + case "BINARY FLOAT": + case "REAL": + case "FLOAT": + return Schema.of(Schema.Type.FLOAT); + case "BINARY DOUBLE": + case "DOUBLE": + return Schema.of(Schema.Type.DOUBLE); + case "BFILE": + case "BLOB": + case "RAW": + case "LONG RAW": + return Schema.of(Schema.Type.BYTES); + case "INTERVAL DAY TO SECOND": + case "INTERVAL YEAR TO MONTH": + case "VARCHAR2": + case "VARCHAR": + case "CHAR": + case "CHAR2": + case "CLOB": + case "NCLOB": + case "LONG": + return Schema.of(Schema.Type.STRING); + case "INTEGER": + return Schema.of(Schema.Type.INT); + case "NUMBER": + case "DECIMAL": + if (Double.class.getTypeName().equals(typeName)) { + return Schema.of(Schema.Type.DOUBLE); + } else { + if (precision == 0) { + if (isPrecisionlessNumAsDecimal) { + precision = 38; + scale = 0; + LOG.warn(String.format("%s type with undefined precision and scale is detected, " + + "there may be a precision loss while running the pipeline. " + + "Please define an output precision and scale for field to avoid " + + "precision loss.", typeName)); + return Schema.decimalOf(precision, scale); + } else { + LOG.warn(String.format("%s type without precision and scale, " + + "converting into STRING type to avoid any precision loss.", + typeName)); + return Schema.of(Schema.Type.STRING); + } + } + return Schema.decimalOf(precision, scale); + } + case "ARRAY": + case "OTHER": + case "XML": + String errorMessage = String.format("Column %s has unsupported SQL type of %s.", columnName, typeName); + throw ErrorUtils.getProgramFailureException(new ErrorCategory(ErrorCategory.ErrorCategoryEnum.PLUGIN), + errorMessage, errorMessage, ErrorType.SYSTEM, true, null); + default: + return null; + } + } + private @NotNull Schema getTimestampLtzSchema() { return isTimestampOldBehavior || isTimestampLtzFieldTimestamp ? Schema.of(Schema.LogicalType.TIMESTAMP_MICROS) diff --git a/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java b/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java index 1ff77c533..8898e8ab4 100644 --- a/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java +++ b/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java @@ -20,13 +20,15 @@ import io.cdap.cdap.api.data.schema.Schema; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.Mockito; -import org.mockito.junit.MockitoJUnitRunner; +import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.Statement; +import java.sql.Types; import java.util.List; public class OracleSchemaReaderTest { @@ -37,6 +39,12 @@ public void getSchema_timestampLTZFieldTrue_returnTimestamp() throws SQLExceptio ResultSet resultSet = Mockito.mock(ResultSet.class); ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); + + Mockito.when(resultSet.getMetaData()).thenReturn(metadata); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); Mockito.when(resultSet.getMetaData()).thenReturn(metadata); @@ -68,9 +76,12 @@ public void getSchema_timestampLTZFieldFalse_returnDatetime() throws SQLExceptio ResultSet resultSet = Mockito.mock(ResultSet.class); ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); Mockito.when(resultSet.getMetaData()).thenReturn(metadata); - + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); Mockito.when(metadata.getColumnCount()).thenReturn(2); // -101 is for TIMESTAMP_TZ Mockito.when(metadata.getColumnType(1)).thenReturn(-101); @@ -91,4 +102,44 @@ public void getSchema_timestampLTZFieldFalse_returnDatetime() throws SQLExceptio Assert.assertEquals(expectedSchemaFields.get(1).getName(), actualSchemaFields.get(1).getName()); Assert.assertEquals(expectedSchemaFields.get(1).getSchema(), actualSchemaFields.get(1).getSchema()); } + + @Test + public void getSchemaFields_structType_returnRecord() throws SQLException { + OracleSourceSchemaReader schemaReader = new OracleSourceSchemaReader(); + ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); + PreparedStatement stmt = Mockito.mock(PreparedStatement.class); + ResultSet attrRs = Mockito.mock(ResultSet.class); + Mockito.when(resultSet.getMetaData()).thenReturn(metadata); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); + Mockito.when(connection.prepareStatement(Mockito.anyString())).thenReturn(stmt); + Mockito.when(stmt.executeQuery()).thenReturn(attrRs); + Mockito.when(metadata.getColumnCount()).thenReturn(1); + Mockito.when(metadata.getColumnType(1)).thenReturn(Types.STRUCT); + Mockito.when(metadata.getColumnName(1)).thenReturn("address"); + Mockito.when(metadata.getColumnTypeName(1)).thenReturn("CS_ITN.ADDRESS_TYPE"); + Mockito.when(metadata.getSchemaName(1)).thenReturn("TEST_SCHEMA"); + Mockito.when(attrRs.next()).thenReturn(true, true, false); + Mockito.when(attrRs.getString("ATTR_NAME")).thenReturn("STREET", "CITY"); + Mockito.when(attrRs.getString("ATTR_TYPE_NAME")).thenReturn("VARCHAR2", "VARCHAR2"); + Mockito.when(attrRs.getInt("PRECISION")).thenReturn(0, 0); + Mockito.when(attrRs.getInt("SCALE")).thenReturn(0, 0); + + List actualFields = schemaReader.getSchemaFields(resultSet); + + Schema.Field addressField = actualFields.get(0); + Schema addressSchema = addressField.getSchema().isNullable() + ? addressField.getSchema().getNonNullable() : addressField.getSchema(); + List structFields = addressSchema.getFields(); + Assert.assertEquals(1, actualFields.size()); + Assert.assertEquals("address", addressField.getName()); + Assert.assertEquals(Schema.Type.RECORD, addressSchema.getType()); + Assert.assertEquals("CS_ITN.ADDRESS_TYPE", addressSchema.getRecordName()); + Assert.assertEquals(2, structFields.size()); + Assert.assertEquals("STREET", structFields.get(0).getName()); + Assert.assertEquals("CITY", structFields.get(1).getName()); + } } diff --git a/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSourceDBRecordUnitTest.java b/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSourceDBRecordUnitTest.java index 77136e841..244e84ea6 100644 --- a/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSourceDBRecordUnitTest.java +++ b/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSourceDBRecordUnitTest.java @@ -25,6 +25,7 @@ import org.mockito.junit.MockitoJUnitRunner; import java.math.BigDecimal; +import java.sql.Date; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.Timestamp; @@ -234,4 +235,5 @@ public void validateTimestampTZTypeNullHandling() throws Exception { StructuredRecord record = builder.build(); Assert.assertNull(record.get("field1")); } + }