Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;

import com.google.common.base.CaseFormat;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand All @@ -29,7 +27,6 @@
import com.google.protobuf.BytesValue;
import com.google.protobuf.DoubleValue;
import com.google.protobuf.Duration;
import com.google.protobuf.FieldMask;
import com.google.protobuf.FloatValue;
import com.google.protobuf.Int32Value;
import com.google.protobuf.Int64Value;
Expand All @@ -44,8 +41,6 @@
import dev.cel.common.annotations.Internal;
import dev.cel.common.internal.ProtoTimeUtils;
import dev.cel.common.internal.WellKnownProto;
import java.util.ArrayList;
import java.util.List;

/**
* {@code BaseProtoCelValueConverter} contains the common logic for converting between native Java
Expand Down Expand Up @@ -103,15 +98,6 @@ protected Object fromWellKnownProto(MessageLiteOrBuilder message, WellKnownProto
return UnsignedLong.valueOf(((UInt32Value) message).getValue());
case UINT64_VALUE:
return UnsignedLong.fromLongBits(((UInt64Value) message).getValue());
case FIELD_MASK:
FieldMask fieldMask = (FieldMask) message;
List<String> paths = new ArrayList<>(fieldMask.getPathsCount());
for (String path : fieldMask.getPathsList()) {
if (!path.isEmpty()) {
paths.add(CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, path));
}
}
return normalizePrimitive(Joiner.on(",").join(paths));
case EMPTY:
return ImmutableMap.of();
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ protected Object fromWellKnownProto(MessageLiteOrBuilder msg, WellKnownProto wel
"Unpacking failed for message: " + message.getDescriptorForType().getFullName(), e);
}
return toRuntimeValue(unpackedMessage);
case FIELD_MASK:
return ProtoMessageValue.create(
(Message) message, celDescriptorPool, this, celOptions.enableJsonFieldNames());
default:
return super.fromWellKnownProto(message, wellKnownProto);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.MessageLite;
import com.google.protobuf.MessageLiteOrBuilder;
import com.google.protobuf.WireFormat;
import dev.cel.common.annotations.Internal;
import dev.cel.common.internal.CelLiteDescriptorPool;
Expand Down Expand Up @@ -178,12 +179,27 @@ public Object toRuntimeValue(Object value) {
return ProtoMessageLiteValue.create(msg, descriptor.getProtoTypeName(), this);
}

return super.fromWellKnownProto(msg, wellKnownProto);
return fromWellKnownProto(msg, wellKnownProto);
}

return super.toRuntimeValue(value);
}

@Override
protected Object fromWellKnownProto(MessageLiteOrBuilder msg, WellKnownProto wellKnownProto) {
if (wellKnownProto == WellKnownProto.FIELD_MASK) {
MessageLite message = (MessageLite) msg;
MessageLiteDescriptor descriptor =
descriptorPool
.findDescriptor(message)
.orElseThrow(
() -> new NoSuchElementException("Could not find a descriptor for: " + message));
return ProtoMessageLiteValue.create(message, descriptor.getProtoTypeName(), this);
}

return super.fromWellKnownProto(msg, wellKnownProto);
}

private Object getDefaultValue(FieldLiteDescriptor fieldDescriptor) {
EncodingType encodingType = fieldDescriptor.getEncodingType();
switch (encodingType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.protobuf.DoubleValue;
import com.google.protobuf.Duration;
import com.google.protobuf.ExtensionRegistryLite;
import com.google.protobuf.FieldMask;
import com.google.protobuf.FloatValue;
import com.google.protobuf.Int32Value;
import com.google.protobuf.Int64Value;
Expand Down Expand Up @@ -104,6 +105,17 @@ public void fromProtoMessageToCelValue_withWellKnownProto_convertsToPrimitivesFr
assertThat(adaptedValue).isEqualTo(testCase.value);
}

@Test
public void fromProtoMessageToCelValue_fieldMask_returnsProtoMessageLiteValue() {
FieldMask fieldMask = FieldMask.newBuilder().addPaths("foo").addPaths("bar").build();

Object adaptedValue = PROTO_LITE_CEL_VALUE_CONVERTER.toRuntimeValue(fieldMask);

assertThat(adaptedValue).isInstanceOf(ProtoMessageLiteValue.class);
assertThat(((ProtoMessageLiteValue) adaptedValue).select("paths"))
.isEqualTo(ImmutableList.of("foo", "bar"));
}

/** Test cases for repeated_int64: 1L,2L,3L */
@SuppressWarnings("ImmutableEnumChecker") // Test only
private enum RepeatedFieldBytesTestCase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,24 @@ public void selectField_durationOutOfRange_success(int seconds, int nanos) {
.isEqualTo(Duration.ofSeconds(seconds, nanos));
}

@Test
public void selectField_fieldMask_returnsProtoMessageValue() {
TestAllTypes testAllTypes =
TestAllTypes.newBuilder()
.setFieldMask(
com.google.protobuf.FieldMask.newBuilder().addPaths("foo").addPaths("bar"))
.build();

ProtoMessageValue protoMessageValue =
ProtoMessageValue.create(
testAllTypes, DefaultDescriptorPool.INSTANCE, PROTO_CEL_VALUE_CONVERTER, false);

Object selected = protoMessageValue.select("field_mask");
assertThat(selected).isInstanceOf(ProtoMessageValue.class);
assertThat(((ProtoMessageValue) selected).select("paths"))
.isEqualTo(ImmutableList.of("foo", "bar"));
}

@SuppressWarnings("ImmutableEnumChecker") // Test only
private enum SelectFieldJsonValueTestCase {
NULL(Value.newBuilder().build(), NullValue.NULL_VALUE),
Expand Down
Loading