Skip to content

Commit

Permalink
Nextgen Proto Pythonic API: Add 'in' operator
Browse files Browse the repository at this point in the history
The “in” operator will be consistent with HasField but a little different with Proto Plus.

The detail behavior of “in” operator in Nextgen for Struct (to be consist with old Struct behavior):
-Raise TypeError if not pass a string
-Check if the key is in the struct.fields

The detail behavior of “in” operator in Nextgen(for other message):
-Raise ValueError if not pass a string
-Raise ValueError if the string is not a field
-For Oneof: Check any field under the oneof is set
-For has-presence field: check if set
-For non-has-presence field (include repeated fields): raise ValueError

PiperOrigin-RevId: 621240977
  • Loading branch information
anandolee authored and copybara-github committed Apr 2, 2024
1 parent 3a2cd26 commit de8e550
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 4 deletions.
31 changes: 31 additions & 0 deletions python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,24 @@ def __eq__(self, other):
self.assertNotEqual(m, ComparesWithFoo())
self.assertNotEqual(ComparesWithFoo(), m)

def testIn(self, message_module):
m = message_module.TestAllTypes()
self.assertNotIn('optional_nested_message', m)
self.assertNotIn('oneof_bytes', m)
self.assertNotIn('oneof_string', m)
with self.assertRaises(ValueError) as e:
'repeated_int32' in m
with self.assertRaises(ValueError) as e:
'repeated_nested_message' in m
with self.assertRaises(ValueError) as e:
1 in m
with self.assertRaises(ValueError) as e:
'not_a_field' in m
test_util.SetAllFields(m)
self.assertIn('optional_nested_message', m)
self.assertIn('oneof_bytes', m)
self.assertNotIn('oneof_string', m)


# Class to test proto2-only features (required, extensions, etc.)
@testing_refleaks.TestCase
Expand Down Expand Up @@ -1345,6 +1363,9 @@ def testFieldPresence(self):
self.assertTrue(message.HasField('optional_int32'))
self.assertTrue(message.HasField('optional_bool'))
self.assertTrue(message.HasField('optional_nested_message'))
self.assertIn('optional_int32', message)
self.assertIn('optional_bool', message)
self.assertIn('optional_nested_message', message)

# Set the fields to non-default values.
message.optional_int32 = 5
Expand All @@ -1363,6 +1384,9 @@ def testFieldPresence(self):
self.assertFalse(message.HasField('optional_int32'))
self.assertFalse(message.HasField('optional_bool'))
self.assertFalse(message.HasField('optional_nested_message'))
self.assertNotIn('optional_int32', message)
self.assertNotIn('optional_bool', message)
self.assertNotIn('optional_nested_message', message)
self.assertEqual(0, message.optional_int32)
self.assertEqual(False, message.optional_bool)
self.assertEqual(0, message.optional_nested_message.bb)
Expand Down Expand Up @@ -1689,6 +1713,12 @@ def testFieldPresence(self):
with self.assertRaises(ValueError):
message.HasField('repeated_nested_message')

# Can not test "in" operator.
with self.assertRaises(ValueError):
'repeated_int32' in message
with self.assertRaises(ValueError):
'repeated_nested_message' in message

# Fields should default to their type-specific default.
self.assertEqual(0, message.optional_int32)
self.assertEqual(0, message.optional_float)
Expand All @@ -1699,6 +1729,7 @@ def testFieldPresence(self):
# Setting a submessage should still return proper presence information.
message.optional_nested_message.bb = 0
self.assertTrue(message.HasField('optional_nested_message'))
self.assertIn('optional_nested_message', message)

# Set the fields to non-default values.
message.optional_int32 = 5
Expand Down
12 changes: 12 additions & 0 deletions python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,17 @@ def __unicode__(self):
cls.__unicode__ = __unicode__


def _AddContainsMethod(message_descriptor, cls):

def __contains__(self, field_or_key):
if (message_descriptor.full_name == 'google.protobuf.Struct'):
return field_or_key in self.fields
else:
return self.HasField(field_or_key)

cls.__contains__ = __contains__


def _BytesForNonRepeatedElement(value, field_number, field_type):
"""Returns the number of bytes needed to serialize a non-repeated element.
The returned byte count includes space for tag information and any
Expand Down Expand Up @@ -1394,6 +1405,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddStrMethod(message_descriptor, cls)
_AddReprMethod(message_descriptor, cls)
_AddUnicodeMethod(message_descriptor, cls)
_AddContainsMethod(message_descriptor, cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
_AddSerializePartialToStringMethod(message_descriptor, cls)
Expand Down
3 changes: 0 additions & 3 deletions python/google/protobuf/internal/well_known_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,9 +497,6 @@ class Struct(object):
def __getitem__(self, key):
return _GetStructValue(self.fields[key])

def __contains__(self, item):
return item in self.fields

def __setitem__(self, key, value):
_SetStructValue(self.fields[key], value)

Expand Down
9 changes: 9 additions & 0 deletions python/google/protobuf/internal/well_known_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,15 @@ def testStruct(self):
self.assertEqual([6, True, False, None, inner_struct],
list(struct['key5'].items()))

def testInOperator(self):
struct = struct_pb2.Struct()
struct['key'] = 5

self.assertIn('key', struct)
self.assertNotIn('fields', struct)
with self.assertRaises(TypeError) as e:
1 in struct

def testStructAssignment(self):
# Tests struct assignment from another struct
s1 = struct_pb2.Struct()
Expand Down
23 changes: 23 additions & 0 deletions python/google/protobuf/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,29 @@ def __unicode__(self):
"""Outputs a human-readable representation of the message."""
raise NotImplementedError

def __contains__(self, field_name):
"""Checks if a certain field is set for the message.
Has presence fields return true if the field is set, false if the field is
not set. Fields without presence do raise `ValueError` (this includes
repeated fields, map fields, and implicit presence fields).
If field_name is not defined in the message descriptor, `ValueError` will
be raised.
Note: WKT Struct checks if the key is contained in fields.
Args:
field_name (str): The name of the field to check for presence.
Returns:
bool: Whether a value has been set for the named field.
Raises:
ValueError: if the `field_name` is not a member of this message or
`field_name` is not a string.
"""
raise NotImplementedError

def MergeFrom(self, other_msg):
"""Merges the contents of the specified message into current message.
Expand Down
55 changes: 54 additions & 1 deletion python/google/protobuf/pyext/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "google/protobuf/pyext/message.h"

#include <Python.h>
#include <structmember.h> // A Python header file.

#include <cstdint>
Expand All @@ -36,6 +37,7 @@
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/strtod.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map_field.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/unknown_field_set.h"
Expand Down Expand Up @@ -85,6 +87,12 @@ class MessageReflectionFriend {
return reflection->IsLazyField(field) ||
reflection->IsLazyExtension(message, field);
}
static bool ContainsMapKey(const Reflection* reflection,
const Message& message,
const FieldDescriptor* field,
const MapKey& map_key) {
return reflection->ContainsMapKey(message, field, map_key);
}
};

static PyObject* kDESCRIPTOR;
Expand Down Expand Up @@ -1293,11 +1301,16 @@ PyObject* HasField(CMessage* self, PyObject* arg) {
char* field_name;
Py_ssize_t size;
field_name = const_cast<char*>(PyUnicode_AsUTF8AndSize(arg, &size));
Message* message = self->message;

if (!field_name) {
PyErr_Format(PyExc_ValueError,
"The field name passed to message %s"
" is not a str.",
message->GetDescriptor()->name().c_str());
return nullptr;
}

Message* message = self->message;
bool is_in_oneof;
const FieldDescriptor* field_descriptor = FindFieldWithOneofs(
message, absl::string_view(field_name, size), &is_in_oneof);
Expand Down Expand Up @@ -2290,6 +2303,44 @@ PyObject* ToUnicode(CMessage* self) {
return decoded;
}

PyObject* Contains(CMessage* self, PyObject* arg) {
Message* message = self->message;
const Descriptor* descriptor = message->GetDescriptor();
// For WKT Struct, check if the key is in the fields.
if (descriptor->full_name() == "google.protobuf.Struct") {
const Reflection* reflection = message->GetReflection();
const FieldDescriptor* map_field = descriptor->FindFieldByName("fields");
const FieldDescriptor* key_field = map_field->message_type()->map_key();
PyObject* py_string = CheckString(arg, key_field);
if (!py_string) {
PyErr_SetString(PyExc_TypeError,
"The key passed to Struct message must be a str.");
return nullptr;
}
char* value;
Py_ssize_t value_len;
if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) {
Py_DECREF(py_string);
Py_RETURN_FALSE;
}
std::string key_str;
key_str.assign(value, value_len);
Py_DECREF(py_string);

MapKey map_key;
map_key.SetStringValue(key_str);
if (MessageReflectionFriend::ContainsMapKey(reflection, *message, map_field,
map_key)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}

// For other messages, check with HasField.
return HasField(self, arg);
}

// CMessage static methods:
PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
PyObject* unused_arg) {
Expand Down Expand Up @@ -2338,6 +2389,8 @@ static PyMethodDef Methods[] = {
"Makes a deep copy of the class."},
{"__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
"Outputs a unicode representation of the message."},
{"__contains__", (PyCFunction)Contains, METH_O,
"Checks if a message field is set."},
{"ByteSize", (PyCFunction)ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)Clear, METH_NOARGS, "Clears the message."},
Expand Down
20 changes: 20 additions & 0 deletions python/message.c
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,24 @@ static PyObject* PyUpb_Message_HasField(PyObject* _self, PyObject* arg) {
NULL);
}

static PyObject* PyUpb_Message_Contains(PyObject* _self, PyObject* arg) {
const upb_MessageDef* msgdef = PyUpb_Message_GetMsgdef(_self);
// For WKT Struct, check if the key is in the fields.
if (strcmp(upb_MessageDef_FullName(msgdef), "google.protobuf.Struct") == 0) {
PyUpb_Message* self = (void*)_self;
upb_Message* msg = PyUpb_Message_GetMsg(self);
const upb_FieldDef* f = upb_MessageDef_FindFieldByName(msgdef, "fields");
const upb_Map* map = upb_Message_GetFieldByDef(msg, f).map_val;
const upb_MessageDef* entry_m = upb_FieldDef_MessageSubDef(f);
const upb_FieldDef* key_f = upb_MessageDef_Field(entry_m, 0);
upb_MessageValue u_key;
if (!PyUpb_PyToUpb(arg, key_f, &u_key, NULL)) return NULL;
return PyBool_FromLong(upb_Map_Get(map, u_key, NULL));
}
// For other messages, check with HasField.
return PyUpb_Message_HasField(_self, arg);
}

static PyObject* PyUpb_Message_FindInitializationErrors(PyObject* _self,
PyObject* arg);

Expand Down Expand Up @@ -1640,6 +1658,8 @@ static PyMethodDef PyUpb_Message_Methods[] = {
// TODO
//{ "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS,
// "Outputs a unicode representation of the message." },
{"__contains__", PyUpb_Message_Contains, METH_O,
"Checks if a message field is set."},
{"ByteSize", (PyCFunction)PyUpb_Message_ByteSize, METH_NOARGS,
"Returns the size of the message in bytes."},
{"Clear", (PyCFunction)PyUpb_Message_Clear, METH_NOARGS,
Expand Down

0 comments on commit de8e550

Please sign in to comment.