diff --git a/python/google/protobuf/field_mask.py b/python/google/protobuf/field_mask.py new file mode 100644 index 000000000000..d336c7246b44 --- /dev/null +++ b/python/google/protobuf/field_mask.py @@ -0,0 +1,76 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Contains the FieldMask helper APIs.""" + +from typing import Optional + +from google.protobuf import descriptor + +from google.protobuf.field_mask_pb2 import FieldMask + + +def to_json_string(mask_msg: FieldMask) -> str: + """Converts FieldMask to string according to proto3 JSON spec.""" + return mask_msg.ToJsonString() + + +def from_json_string(value: str) -> FieldMask: + """Converts string to FieldMask according to proto3 JSON spec.""" + msg = FieldMask() + msg.FromJsonString(value) + return msg + + +def is_valid_for_descriptor( + mask_msg: FieldMask, message_descriptor: descriptor.Descriptor +) -> bool: + """Checks whether the FieldMask is valid for Message Descriptor.""" + return mask_msg.IsValidForDescriptor(message_descriptor) + + +def all_fields_from_descriptor( + message_descriptor: descriptor.Descriptor, +) -> FieldMask: + """Gets all direct fields of Message Descriptor to FieldMask.""" + msg = FieldMask() + msg.AllFieldsFromDescriptor(message_descriptor) + return msg + + +def canonical_form_from_mask(mask_msg: FieldMask) -> FieldMask: + """Converts a FieldMask to the canonical form FieldMask.""" + msg = FieldMask() + msg.CanonicalFormFromMask(mask_msg) + return msg + + +def union(mask1: FieldMask, mask2: FieldMask) -> FieldMask: + """Merges mask1 and mask2 into a new FieldMask.""" + msg = FieldMask() + msg.Union(mask1, mask2) + return msg + + +def intersect(mask1: FieldMask, mask2: FieldMask) -> FieldMask: + """Intersects mask1 and mask2 into a new FieldMask.""" + msg = FieldMask() + msg.Intersect(mask1, mask2) + return msg + + +def merge_message( + mask_msg: FieldMask, + source: FieldMask, + destination: FieldMask, + replace_message_field: Optional[bool] = False, + replace_repeated_field: Optional[bool] = False, +) -> None: + """Merges fields specified in field_mask from source to destination.""" + return mask_msg.MergeMessage( + source, destination, replace_message_field, replace_repeated_field + ) diff --git a/python/google/protobuf/internal/field_mask_test.py b/python/google/protobuf/internal/field_mask_test.py index ed7c9ef60e26..e73af8346d49 100644 --- a/python/google/protobuf/internal/field_mask_test.py +++ b/python/google/protobuf/internal/field_mask_test.py @@ -9,10 +9,12 @@ import unittest -from google.protobuf import field_mask_pb2 +from google.protobuf import descriptor +from google.protobuf import field_mask as field_mask_nextgen from google.protobuf.internal import field_mask from google.protobuf.internal import test_util -from google.protobuf import descriptor + +from google.protobuf import field_mask_pb2 from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 @@ -372,6 +374,46 @@ def testCamelCaseToSnakeCase(self): 'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.', field_mask._CamelCaseToSnakeCase, 'foo_bar') + def test_field_mask_nextgen(self): + mask_msg = field_mask_pb2.FieldMask() + mask_msg.paths.append('foo') + self.assertEqual('foo', field_mask_nextgen.to_json_string(mask_msg)) + + msg = field_mask_nextgen.from_json_string('fooBar,barQuz') + self.assertEqual(['foo_bar', 'bar_quz'], msg.paths) + + msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + msg = field_mask_nextgen.all_fields_from_descriptor(msg_descriptor) + self.assertEqual(79, len(msg.paths)) + self.assertTrue( + field_mask_nextgen.is_valid_for_descriptor(msg, msg_descriptor) + ) + + mask1 = field_mask_nextgen.from_json_string('foo,bar,foo') + mask2 = field_mask_nextgen.canonical_form_from_mask(mask1) + self.assertEqual('bar,foo', field_mask_nextgen.to_json_string(mask2)) + + mask1 = field_mask_nextgen.from_json_string('foo,baz.bb') + mask2 = field_mask_nextgen.from_json_string('baz.bb,quz') + union_mask = field_mask_nextgen.union(mask1, mask2) + intersect_mask = field_mask_nextgen.intersect(mask1, mask2) + self.assertEqual( + 'baz.bb,foo,quz', field_mask_nextgen.to_json_string(union_mask) + ) + self.assertEqual( + 'baz.bb', field_mask_nextgen.to_json_string(intersect_mask) + ) + + new_msg = unittest_pb2.TestOneof2() + dst = unittest_pb2.TestOneof2() + dst.foo_message.moo_int = 1 + mask = field_mask_nextgen.from_json_string( + 'fooMessage,fooLazyMessage.mooInt' + ) + field_mask_nextgen.merge_message(mask, new_msg, dst) + self.assertIn('foo_message', dst) + self.assertNotIn('foo_lazy_message', dst) + if __name__ == '__main__': unittest.main()