diff --git a/common/src/main/java/dev/cel/common/values/BaseProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/BaseProtoCelValueConverter.java index 6851deed5..9fc218abe 100644 --- a/common/src/main/java/dev/cel/common/values/BaseProtoCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/BaseProtoCelValueConverter.java @@ -17,8 +17,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import com.google.common.base.CaseFormat; -import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -29,7 +27,6 @@ import com.google.protobuf.BytesValue; import com.google.protobuf.DoubleValue; import com.google.protobuf.Duration; -import com.google.protobuf.FieldMask; import com.google.protobuf.FloatValue; import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; @@ -44,8 +41,6 @@ import dev.cel.common.annotations.Internal; import dev.cel.common.internal.ProtoTimeUtils; import dev.cel.common.internal.WellKnownProto; -import java.util.ArrayList; -import java.util.List; /** * {@code BaseProtoCelValueConverter} contains the common logic for converting between native Java @@ -103,15 +98,6 @@ protected Object fromWellKnownProto(MessageLiteOrBuilder message, WellKnownProto return UnsignedLong.valueOf(((UInt32Value) message).getValue()); case UINT64_VALUE: return UnsignedLong.fromLongBits(((UInt64Value) message).getValue()); - case FIELD_MASK: - FieldMask fieldMask = (FieldMask) message; - List paths = new ArrayList<>(fieldMask.getPathsCount()); - for (String path : fieldMask.getPathsList()) { - if (!path.isEmpty()) { - paths.add(CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, path)); - } - } - return normalizePrimitive(Joiner.on(",").join(paths)); case EMPTY: return ImmutableMap.of(); default: diff --git a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java index 948df759c..89d1b708f 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoCelValueConverter.java @@ -71,6 +71,9 @@ protected Object fromWellKnownProto(MessageLiteOrBuilder msg, WellKnownProto wel "Unpacking failed for message: " + message.getDescriptorForType().getFullName(), e); } return toRuntimeValue(unpackedMessage); + case FIELD_MASK: + return ProtoMessageValue.create( + (Message) message, celDescriptorPool, this, celOptions.enableJsonFieldNames()); default: return super.fromWellKnownProto(message, wellKnownProto); } diff --git a/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java b/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java index 3fbb0ad75..64d6ec1d4 100644 --- a/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java +++ b/common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java @@ -28,6 +28,7 @@ import com.google.protobuf.CodedInputStream; import com.google.protobuf.ExtensionRegistryLite; import com.google.protobuf.MessageLite; +import com.google.protobuf.MessageLiteOrBuilder; import com.google.protobuf.WireFormat; import dev.cel.common.annotations.Internal; import dev.cel.common.internal.CelLiteDescriptorPool; @@ -178,12 +179,27 @@ public Object toRuntimeValue(Object value) { return ProtoMessageLiteValue.create(msg, descriptor.getProtoTypeName(), this); } - return super.fromWellKnownProto(msg, wellKnownProto); + return fromWellKnownProto(msg, wellKnownProto); } return super.toRuntimeValue(value); } + @Override + protected Object fromWellKnownProto(MessageLiteOrBuilder msg, WellKnownProto wellKnownProto) { + if (wellKnownProto == WellKnownProto.FIELD_MASK) { + MessageLite message = (MessageLite) msg; + MessageLiteDescriptor descriptor = + descriptorPool + .findDescriptor(message) + .orElseThrow( + () -> new NoSuchElementException("Could not find a descriptor for: " + message)); + return ProtoMessageLiteValue.create(message, descriptor.getProtoTypeName(), this); + } + + return super.fromWellKnownProto(msg, wellKnownProto); + } + private Object getDefaultValue(FieldLiteDescriptor fieldDescriptor) { EncodingType encodingType = fieldDescriptor.getEncodingType(); switch (encodingType) { diff --git a/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java b/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java index cec3e0fbf..3b66171e4 100644 --- a/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java +++ b/common/src/test/java/dev/cel/common/values/ProtoLiteCelValueConverterTest.java @@ -27,6 +27,7 @@ import com.google.protobuf.DoubleValue; import com.google.protobuf.Duration; import com.google.protobuf.ExtensionRegistryLite; +import com.google.protobuf.FieldMask; import com.google.protobuf.FloatValue; import com.google.protobuf.Int32Value; import com.google.protobuf.Int64Value; @@ -104,6 +105,17 @@ public void fromProtoMessageToCelValue_withWellKnownProto_convertsToPrimitivesFr assertThat(adaptedValue).isEqualTo(testCase.value); } + @Test + public void fromProtoMessageToCelValue_fieldMask_returnsProtoMessageLiteValue() { + FieldMask fieldMask = FieldMask.newBuilder().addPaths("foo").addPaths("bar").build(); + + Object adaptedValue = PROTO_LITE_CEL_VALUE_CONVERTER.toRuntimeValue(fieldMask); + + assertThat(adaptedValue).isInstanceOf(ProtoMessageLiteValue.class); + assertThat(((ProtoMessageLiteValue) adaptedValue).select("paths")) + .isEqualTo(ImmutableList.of("foo", "bar")); + } + /** Test cases for repeated_int64: 1L,2L,3L */ @SuppressWarnings("ImmutableEnumChecker") // Test only private enum RepeatedFieldBytesTestCase { diff --git a/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java b/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java index 365dd32b4..c25a8ea06 100644 --- a/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java +++ b/common/src/test/java/dev/cel/common/values/ProtoMessageValueTest.java @@ -303,6 +303,24 @@ public void selectField_durationOutOfRange_success(int seconds, int nanos) { .isEqualTo(Duration.ofSeconds(seconds, nanos)); } + @Test + public void selectField_fieldMask_returnsProtoMessageValue() { + TestAllTypes testAllTypes = + TestAllTypes.newBuilder() + .setFieldMask( + com.google.protobuf.FieldMask.newBuilder().addPaths("foo").addPaths("bar")) + .build(); + + ProtoMessageValue protoMessageValue = + ProtoMessageValue.create( + testAllTypes, DefaultDescriptorPool.INSTANCE, PROTO_CEL_VALUE_CONVERTER, false); + + Object selected = protoMessageValue.select("field_mask"); + assertThat(selected).isInstanceOf(ProtoMessageValue.class); + assertThat(((ProtoMessageValue) selected).select("paths")) + .isEqualTo(ImmutableList.of("foo", "bar")); + } + @SuppressWarnings("ImmutableEnumChecker") // Test only private enum SelectFieldJsonValueTestCase { NULL(Value.newBuilder().build(), NullValue.NULL_VALUE),