From 6e0aedb5c39071a708d367d7d5e495b5587fc529 Mon Sep 17 00:00:00 2001 From: Brad Keryan Date: Fri, 27 Sep 2024 11:28:57 -0500 Subject: [PATCH] [High Priority] Fix Path type conversion and type hints for client codegen (#931) * generator: Fix Path type conversion for client outputs * generator: Annotate array configuration parameters as Iterable[T] and array output parameters as Sequence[T] * generator: Fix lint errors * service: Fix lint errors * service: Revert debugging code * service: Simplify convert_paths_to_strings and convert_strings_to_paths * generator: Simplify client imports `from typing import big, list, of, things` results in a lot of conditional code. It's simpler to use `import typing` and reference `typing.Iterable`, etc. * generator: Use pathlib.PurePath for configuration parameters * generator: Fix lint errors * generator: Move path conversion and dict/list conversion into client_support module * service: Fix parameter name in client_support * service: Update client_support to use type specialization constants --- .../client/_support.py | 73 +++++--- .../measurement_plugin_client.py.mako | 85 +++------ .../test_non_streaming_measurement_client.py | 95 ++++++++-- .../non_streaming_data_measurement_client.py | 163 ++++++++---------- .../void_measurement_client.py | 34 ++-- .../measurement/client_support.py | 105 ++++++++++- 6 files changed, 338 insertions(+), 217 deletions(-) diff --git a/packages/generator/ni_measurement_plugin_sdk_generator/client/_support.py b/packages/generator/ni_measurement_plugin_sdk_generator/client/_support.py index c05a3e81b..8b8f1998b 100644 --- a/packages/generator/ni_measurement_plugin_sdk_generator/client/_support.py +++ b/packages/generator/ni_measurement_plugin_sdk_generator/client/_support.py @@ -13,7 +13,9 @@ from google.protobuf import descriptor_pool from google.protobuf.descriptor_pb2 import FieldDescriptorProto from google.protobuf.type_pb2 import Field -from ni_measurement_plugin_sdk_service._internal.grpc_servicer import frame_metadata_dict +from ni_measurement_plugin_sdk_service._internal.grpc_servicer import ( + frame_metadata_dict, +) from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v2 import ( measurement_service_pb2 as v2_measurement_service_pb2, measurement_service_pb2_grpc as v2_measurement_service_pb2_grpc, @@ -21,18 +23,17 @@ from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool from ni_measurement_plugin_sdk_service.measurement.client_support import ( + ParameterMetadata, create_file_descriptor, deserialize_parameters, - ParameterMetadata, ) - _V2_MEASUREMENT_SERVICE_INTERFACE = "ni.measurementlink.measurement.v2.MeasurementService" _INVALID_CHARS = "`~!@#$%^&*()-+={}[]\\|:;',<>.?/ \n" _XY_DATA_IMPORT = "from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types.xydata_pb2 import DoubleXYData" -_PATH_IMPORT = "from pathlib import Path" +_PATH_IMPORT = "import pathlib" _PROTO_DATATYPE_TO_PYTYPE_LOOKUP = { Field.TYPE_INT32: int, @@ -178,21 +179,24 @@ def get_configuration_and_output_metadata_by_index( ) configuration_metadata = frame_metadata_dict(configuration_parameter_list) output_metadata = frame_metadata_dict(output_parameter_list) - deserialized_parameters = deserialize_parameters( + + # Disable path conversion to avoid normalizing path separators and eliminate the need to + # convert Path objects to strings. + default_values = deserialize_parameters( configuration_metadata, metadata.measurement_signature.configuration_defaults.value, f"{service_class}.Configurations", + convert_paths=False, ) - for k, v in deserialized_parameters.items(): - if issubclass(type(v), Enum): - default_value = v.value - elif issubclass(type(v), list) and any(issubclass(type(e), Enum) for e in v): - default_value = [e.value for e in v] - else: - default_value = v - - configuration_metadata[k] = configuration_metadata[k]._replace(default_value=default_value) + for id, default_value in enumerate(default_values, start=1): + if isinstance(default_value, Enum): + default_value = default_value.value + elif isinstance(default_value, list) and any(isinstance(e, Enum) for e in default_value): + default_value = [e.value for e in default_value] + configuration_metadata[id] = configuration_metadata[id]._replace( + default_value=default_value + ) return configuration_metadata, output_metadata @@ -211,19 +215,21 @@ def get_configuration_parameters_with_type_and_default_values( parameter_names.append(parameter_name) default_value = metadata.default_value - parameter_type = _get_python_type_as_str(metadata.type, metadata.repeated) + parameter_type = _get_configuration_python_type_as_str(metadata.type, metadata.repeated) if isinstance(default_value, str): default_value = repr(default_value) if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "path": - parameter_type = "Path" + parameter_type = "pathlib.PurePath" built_in_import_modules.append(_PATH_IMPORT) if metadata.repeated: - formatted_value = ", ".join(f"Path({repr(value)})" for value in default_value) + formatted_value = ", ".join( + f"pathlib.PurePath({repr(value)})" for value in default_value + ) default_value = f"[{formatted_value}]" - parameter_type = f"List[{parameter_type}]" + parameter_type = f"typing.Iterable[{parameter_type}]" else: - default_value = f"Path({default_value})" + default_value = f"pathlib.PurePath({default_value})" if metadata.message_type: raise click.ClickException( @@ -243,7 +249,7 @@ def get_configuration_parameters_with_type_and_default_values( concatenated_default_value = ", ".join(values) concatenated_default_value = f"[{concatenated_default_value}]" - parameter_type = f"List[{parameter_type}]" + parameter_type = f"typing.Iterable[{parameter_type}]" default_value = concatenated_default_value else: enum_value = next((e.name for e in enum_type if e.value == default_value), None) @@ -267,27 +273,29 @@ def get_output_parameters_with_type( output_parameters_with_type: List[str] = [] for metadata in output_metadata.values(): parameter_name = _get_python_identifier(metadata.display_name) - parameter_type = _get_python_type_as_str(metadata.type, metadata.repeated) + parameter_type = _get_output_python_type_as_str(metadata.type, metadata.repeated) if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "path": - parameter_type = "Path" + parameter_type = "pathlib.Path" built_in_import_modules.append(_PATH_IMPORT) if metadata.repeated: - parameter_type = f"List[{parameter_type}]" + parameter_type = f"typing.Sequence[{parameter_type}]" if metadata.message_type and metadata.message_type == "ni.protobuf.types.DoubleXYData": parameter_type = "DoubleXYData" custom_import_modules.append(_XY_DATA_IMPORT) if metadata.repeated: - parameter_type = f"List[{parameter_type}]" + parameter_type = f"typing.Sequence[{parameter_type}]" if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum": enum_type_name = _get_enum_type( metadata.display_name, metadata.annotations["ni/enum.values"], enum_values_by_type ).__name__ - parameter_type = f"List[{enum_type_name}]" if metadata.repeated else enum_type_name + parameter_type = ( + f"typing.Sequence[{enum_type_name}]" if metadata.repeated else enum_type_name + ) output_parameters_with_type.append(f"{parameter_name}: {parameter_type}") @@ -374,14 +382,25 @@ def _get_python_identifier(input_string: str) -> str: return valid_identifier -def _get_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str: +def _get_configuration_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str: + python_type = _PROTO_DATATYPE_TO_PYTYPE_LOOKUP.get(type) + + if python_type is None: + raise TypeError(f"Invalid data type: '{type}'.") + + if is_array: + return f"typing.Iterable[{python_type.__name__}]" + return python_type.__name__ + + +def _get_output_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str: python_type = _PROTO_DATATYPE_TO_PYTYPE_LOOKUP.get(type) if python_type is None: raise TypeError(f"Invalid data type: '{type}'.") if is_array: - return f"List[{python_type.__name__}]" + return f"typing.Sequence[{python_type.__name__}]" return python_type.__name__ diff --git a/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako b/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako index 24d382dfc..1cb940c54 100644 --- a/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako +++ b/packages/generator/ni_measurement_plugin_sdk_generator/client/templates/measurement_plugin_client.py.mako @@ -19,19 +19,12 @@ from typing import Any from __future__ import annotations import logging +import pathlib import threading +import typing % if len(enum_by_class_name): from enum import Enum % endif -from pathlib import Path -<% - typing_imports = ["Any", "Generator", "List", "Optional"] - if output_metadata: - typing_imports += ["NamedTuple"] - if "from pathlib import Path" in built_in_import_modules: - typing_imports += ["Iterable"] -%>\ -from typing import ${", ".join(sorted(typing_imports))} import grpc from google.protobuf import any_pb2, descriptor_pool @@ -46,11 +39,11 @@ ${module} from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool from ni_measurement_plugin_sdk_service.measurement.client_support import ( + ParameterMetadata, create_file_descriptor, % if output_metadata: deserialize_parameters, % endif - ParameterMetadata, serialize_parameters, ) from ni_measurement_plugin_sdk_service.pin_map import PinMapClient @@ -73,7 +66,7 @@ class ${enum_name.__name__}(Enum): <% output_type = "None" %>\ % if output_metadata: -class Outputs(NamedTuple): +class Outputs(typing.NamedTuple): """Outputs for the ${display_name | repr} measurement plug-in.""" % for output_parameter in output_parameters_with_type: @@ -89,10 +82,10 @@ class ${class_name}: def __init__( self, *, - discovery_client: Optional[DiscoveryClient] = None, - pin_map_client: Optional[PinMapClient] = None, - grpc_channel: Optional[grpc.Channel] = None, - grpc_channel_pool: Optional[GrpcChannelPool] = None, + discovery_client: typing.Optional[DiscoveryClient] = None, + pin_map_client: typing.Optional[PinMapClient] = None, + grpc_channel: typing.Optional[grpc.Channel] = None, + grpc_channel_pool: typing.Optional[GrpcChannelPool] = None, ): """Initialize the Measurement Plug-In Client. @@ -111,8 +104,8 @@ class ${class_name}: self._grpc_channel_pool = grpc_channel_pool self._discovery_client = discovery_client self._pin_map_client = pin_map_client - self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None - self._measure_response: Optional[ + self._stub: typing.Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None + self._measure_response: typing.Optional[ grpc.CallIterator[v2_measurement_service_pb2.MeasureResponse] ] = None self._configuration_metadata = { @@ -170,12 +163,12 @@ class ${class_name}: self._pin_map_context = val @property - def sites(self) -> Optional[List[int]]: + def sites(self) -> typing.Optional[typing.List[int]]: """The sites where the measurement must be executed.""" return self._pin_map_context.sites @sites.setter - def sites(self, val: List[int]) -> None: + def sites(self, val: typing.List[int]) -> None: if self._pin_map_context is None: raise AttributeError( "Cannot set sites because the pin map context is None. Please provide a pin map context or register a pin map before setting sites." @@ -230,14 +223,14 @@ class ${class_name}: ) def _create_measure_request( - self, parameter_values: List[Any] + self, parameter_values: typing.List[typing.Any] ) -> v2_measurement_service_pb2.MeasureRequest: serialized_configuration = any_pb2.Any( type_url=${configuration_parameters_type_url | repr}, value=serialize_parameters( - parameter_metadata_dict=self._configuration_metadata, - parameter_values=parameter_values, - service_name=f"{self._service_class}.Configurations", + self._configuration_metadata, + parameter_values, + f"{self._service_class}.Configurations", ) ) return v2_measurement_service_pb2.MeasureRequest( @@ -249,17 +242,13 @@ class ${class_name}: def _deserialize_response( self, response: v2_measurement_service_pb2.MeasureResponse ) -> Outputs: - if self._output_metadata: - result = [None] * max(self._output_metadata.keys()) - else: - result = [] - output_values = deserialize_parameters( - self._output_metadata, response.outputs.value, f"{self._service_class}.Outputs" + return Outputs._make( + deserialize_parameters( + self._output_metadata, + response.outputs.value, + f"{self._service_class}.Outputs", + ) ) - - for k, v in output_values.items(): - result[k - 1] = v - return Outputs._make(result) % endif def measure( @@ -285,17 +274,13 @@ class ${class_name}: def stream_measure( self, ${configuration_parameters_with_type_and_default_values} - ) -> Generator[${output_type}, None, None] : + ) -> typing.Generator[${output_type}, None, None] : """Perform a streaming measurement. Returns: Stream of measurement outputs. """ - % if "from pathlib import Path" in built_in_import_modules: - parameter_values = _convert_paths_to_strings([${measure_api_parameters}]) - % else: parameter_values = [${measure_api_parameters}] - % endif with self._initialization_lock: if self._measure_response is not None: raise RuntimeError( @@ -327,7 +312,7 @@ class ${class_name}: else: return False - def register_pin_map(self, pin_map_path: Path) -> None: + def register_pin_map(self, pin_map_path: pathlib.Path) -> None: """Registers the pin map with the pin map service. Args: @@ -335,25 +320,3 @@ class ${class_name}: """ pin_map_id = self._get_pin_map_client().update_pin_map(pin_map_path) self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id) - -% if "from pathlib import Path" in built_in_import_modules: - -def _convert_paths_to_strings(parameter_values: Iterable[Any]) -> List[Any]: - result: List[Any] = [] - - for parameter_value in parameter_values: - if isinstance(parameter_value, list): - converted_list = [] - for value in parameter_value: - if isinstance(value, Path): - converted_list.append(str(value)) - else: - converted_list.append(value) - result.append(converted_list) - elif isinstance(parameter_value, Path): - result.append(str(parameter_value)) - else: - result.append(parameter_value) - return result - -% endif \ No newline at end of file diff --git a/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py b/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py index 5341d5864..1b510af65 100644 --- a/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py +++ b/packages/generator/tests/acceptance/test_non_streaming_measurement_client.py @@ -1,8 +1,9 @@ import importlib.util import pathlib +from collections.abc import Sequence from enum import Enum from types import ModuleType -from typing import Generator +from typing import Any, Generator, Tuple, Type, Union import pytest from ni_measurement_plugin_sdk_service.measurement.service import MeasurementService @@ -48,14 +49,14 @@ def test___measurement_plugin_client___measure___returns_output( "string with \ttabspace", "string with \nnewline", ], - path_out="sample\\path\\for\\test", + path_out=pathlib.Path("sample\\path\\for\\test"), path_array_out=[ - "path\\with\\forward\\slash", - "path\\with\\backslash", - "path with 'single quotes'", - 'path with "double quotes"', - "path\twith\ttabs", - "path\nwith\nnewlines", + pathlib.Path("path\\with\\forward\\slash"), + pathlib.Path("path\\with\\backslash"), + pathlib.Path("path with 'single quotes'"), + pathlib.Path('path with "double quotes"'), + pathlib.Path("path\twith\ttabs"), + pathlib.Path("path\nwith\nnewlines"), ], io_out="resource", io_array_out=["resource1", "resource2"], @@ -74,6 +75,17 @@ def test___measurement_plugin_client___measure___returns_output( assert str(response) == str(expected_output) +def test___measurement_plugin_client___measure___converts_output_types( + measurement_plugin_client_module: ModuleType, +) -> None: + test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") + measurement_plugin_client = test_measurement_client_type() + + response = measurement_plugin_client.measure() + + _verify_output_types(response, measurement_plugin_client_module) + + def test___measurement_plugin_client___stream_measure___returns_output( measurement_plugin_client_module: ModuleType, ) -> None: @@ -92,14 +104,14 @@ def test___measurement_plugin_client___stream_measure___returns_output( "string with \ttabspace", "string with \nnewline", ], - path_out="sample\\path\\for\\test", + path_out=pathlib.Path("sample\\path\\for\\test"), path_array_out=[ - "path\\with\\forward\\slash", - "path\\with\\backslash", - "path with 'single quotes'", - 'path with "double quotes"', - "path\twith\ttabs", - "path\nwith\nnewlines", + pathlib.Path("path\\with\\forward\\slash"), + pathlib.Path("path\\with\\backslash"), + pathlib.Path("path with 'single quotes'"), + pathlib.Path('path with "double quotes"'), + pathlib.Path("path\twith\ttabs"), + pathlib.Path("path\nwith\nnewlines"), ], io_out="resource", io_array_out=["resource1", "resource2"], @@ -120,6 +132,19 @@ def test___measurement_plugin_client___stream_measure___returns_output( assert str(responses[0]) == str(expected_output) +def test___measurement_plugin_client___stream_measure___converts_output_types( + measurement_plugin_client_module: ModuleType, +) -> None: + test_measurement_client_type = getattr(measurement_plugin_client_module, "TestMeasurement") + measurement_plugin_client = test_measurement_client_type() + + response_iterator = measurement_plugin_client.stream_measure() + + responses = [response for response in response_iterator] + assert len(responses) == 1 + _verify_output_types(responses[0], measurement_plugin_client_module) + + @pytest.fixture(scope="module") def measurement_client_directory( create_client: CliRunnerFunction, @@ -168,3 +193,43 @@ def measurement_service( """Test fixture that creates and hosts a Measurement Plug-In Service.""" with non_streaming_data_measurement.measurement_service.host_service() as service: yield service + + +def _verify_output_types(outputs: Any, measurement_plugin_client_module: ModuleType) -> None: + output_type = getattr(measurement_plugin_client_module, "Outputs") + enum_type = getattr(measurement_plugin_client_module, "EnumInEnum") + protobuf_enum_type = getattr(measurement_plugin_client_module, "ProtobufEnumInEnum") + + _assert_type(outputs, output_type) + _assert_type(outputs.float_out, float) + _assert_collection_type(outputs.double_array_out, Sequence, float) + _assert_type(outputs.bool_out, bool) + _assert_type(outputs.string_out, str) + _assert_collection_type(outputs.string_array_out, Sequence, str) + _assert_type(outputs.path_out, pathlib.Path) + _assert_collection_type(outputs.path_array_out, Sequence, pathlib.Path) + _assert_type(outputs.io_out, str) + _assert_collection_type(outputs.io_array_out, Sequence, str) + _assert_type(outputs.integer_out, int) + _assert_type(outputs.xy_data_out, type(None)) + _assert_type(outputs.io_out, str) + _assert_collection_type(outputs.io_array_out, Sequence, str) + _assert_type(outputs.enum_out, enum_type) + _assert_collection_type(outputs.enum_array_out, Sequence, enum_type) + _assert_type(outputs.protobuf_enum_out, protobuf_enum_type) + + +def _assert_type(value: Any, expected_type: Union[Type[Any], Tuple[Type[Any], ...]]) -> None: + assert isinstance( + value, expected_type + ), f"{value!r} has type {type(value)}, expected {expected_type}" + + +def _assert_collection_type( + value: Any, + expected_type: Union[Type[Any], Tuple[Type[Any], ...]], + expected_element_type: Union[Type[Any], Tuple[Type[Any], ...]], +) -> None: + _assert_type(value, expected_type) + for element in value: + _assert_type(element, expected_element_type) diff --git a/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py b/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py index f35b50ce8..0f8a6ee2f 100644 --- a/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py +++ b/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/non_streaming_data_measurement_client.py @@ -3,10 +3,10 @@ from __future__ import annotations import logging +import pathlib import threading +import typing from enum import Enum -from pathlib import Path -from typing import Any, Generator, Iterable, List, NamedTuple, Optional import grpc from google.protobuf import any_pb2, descriptor_pool @@ -21,9 +21,9 @@ from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool from ni_measurement_plugin_sdk_service.measurement.client_support import ( + ParameterMetadata, create_file_descriptor, deserialize_parameters, - ParameterMetadata, serialize_parameters, ) from ni_measurement_plugin_sdk_service.pin_map import PinMapClient @@ -52,22 +52,22 @@ class ProtobufEnumInEnum(Enum): BLACK = 3 -class Outputs(NamedTuple): +class Outputs(typing.NamedTuple): """Outputs for the 'Non-Streaming Data Measurement (Py)' measurement plug-in.""" float_out: float - double_array_out: List[float] + double_array_out: typing.Sequence[float] bool_out: bool string_out: str - string_array_out: List[str] - path_out: Path - path_array_out: List[Path] + string_array_out: typing.Sequence[str] + path_out: pathlib.Path + path_array_out: typing.Sequence[pathlib.Path] io_out: str - io_array_out: List[str] + io_array_out: typing.Sequence[str] integer_out: int xy_data_out: DoubleXYData enum_out: EnumInEnum - enum_array_out: List[EnumInEnum] + enum_array_out: typing.Sequence[EnumInEnum] protobuf_enum_out: ProtobufEnumInEnum @@ -77,10 +77,10 @@ class NonStreamingDataMeasurementClient: def __init__( self, *, - discovery_client: Optional[DiscoveryClient] = None, - pin_map_client: Optional[PinMapClient] = None, - grpc_channel: Optional[grpc.Channel] = None, - grpc_channel_pool: Optional[GrpcChannelPool] = None, + discovery_client: typing.Optional[DiscoveryClient] = None, + pin_map_client: typing.Optional[PinMapClient] = None, + grpc_channel: typing.Optional[grpc.Channel] = None, + grpc_channel_pool: typing.Optional[GrpcChannelPool] = None, ): """Initialize the Measurement Plug-In Client. @@ -99,8 +99,8 @@ def __init__( self._grpc_channel_pool = grpc_channel_pool self._discovery_client = discovery_client self._pin_map_client = pin_map_client - self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None - self._measure_response: Optional[ + self._stub: typing.Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None + self._measure_response: typing.Optional[ grpc.CallIterator[v2_measurement_service_pb2.MeasureResponse] ] = None self._configuration_metadata = { @@ -440,12 +440,12 @@ def pin_map_context(self, val: PinMapContext) -> None: self._pin_map_context = val @property - def sites(self) -> Optional[List[int]]: + def sites(self) -> typing.Optional[typing.List[int]]: """The sites where the measurement must be executed.""" return self._pin_map_context.sites @sites.setter - def sites(self, val: List[int]) -> None: + def sites(self, val: typing.List[int]) -> None: if self._pin_map_context is None: raise AttributeError( "Cannot set sites because the pin map context is None. Please provide a pin map context or register a pin map before setting sites." @@ -502,14 +502,14 @@ def _create_file_descriptor(self) -> None: ) def _create_measure_request( - self, parameter_values: List[Any] + self, parameter_values: typing.List[typing.Any] ) -> v2_measurement_service_pb2.MeasureRequest: serialized_configuration = any_pb2.Any( type_url="type.googleapis.com/ni.measurementlink.measurement.v2.MeasurementConfigurations", value=serialize_parameters( - parameter_metadata_dict=self._configuration_metadata, - parameter_values=parameter_values, - service_name=f"{self._service_class}.Configurations", + self._configuration_metadata, + parameter_values, + f"{self._service_class}.Configurations", ), ) return v2_measurement_service_pb2.MeasureRequest( @@ -520,25 +520,21 @@ def _create_measure_request( def _deserialize_response( self, response: v2_measurement_service_pb2.MeasureResponse ) -> Outputs: - if self._output_metadata: - result = [None] * max(self._output_metadata.keys()) - else: - result = [] - output_values = deserialize_parameters( - self._output_metadata, response.outputs.value, f"{self._service_class}.Outputs" + return Outputs._make( + deserialize_parameters( + self._output_metadata, + response.outputs.value, + f"{self._service_class}.Outputs", + ) ) - for k, v in output_values.items(): - result[k - 1] = v - return Outputs._make(result) - def measure( self, float_in: float = 0.05999999865889549, - double_array_in: List[float] = [0.1, 0.2, 0.3], + double_array_in: typing.Iterable[float] = [0.1, 0.2, 0.3], bool_in: bool = False, string_in: str = "sample string", - string_array_in: List[str] = [ + string_array_in: typing.Iterable[str] = [ "string with /forwardslash", "string with \\backslash", "string with 'single quotes'", @@ -546,20 +542,20 @@ def measure( "string with \ttabspace", "string with \nnewline", ], - path_in: Path = Path("sample\\path\\for\\test"), - path_array_in: List[Path] = [ - Path("path/with/forward/slash"), - Path("path\\with\\backslash"), - Path("path with 'single quotes'"), - Path('path with "double quotes"'), - Path("path\twith\ttabs"), - Path("path\nwith\nnewlines"), + path_in: pathlib.PurePath = pathlib.PurePath("sample\\path\\for\\test"), + path_array_in: typing.Iterable[pathlib.PurePath] = [ + pathlib.PurePath("path/with/forward/slash"), + pathlib.PurePath("path\\with\\backslash"), + pathlib.PurePath("path with 'single quotes'"), + pathlib.PurePath('path with "double quotes"'), + pathlib.PurePath("path\twith\ttabs"), + pathlib.PurePath("path\nwith\nnewlines"), ], io_in: str = "resource", - io_array_in: List[str] = ["resource1", "resource2"], + io_array_in: typing.Iterable[str] = ["resource1", "resource2"], integer_in: int = 10, enum_in: EnumInEnum = EnumInEnum.BLUE, - enum_array_in: List[EnumInEnum] = [EnumInEnum.RED, EnumInEnum.GREEN], + enum_array_in: typing.Iterable[EnumInEnum] = [EnumInEnum.RED, EnumInEnum.GREEN], protobuf_enum_in: ProtobufEnumInEnum = ProtobufEnumInEnum.BLACK, ) -> Outputs: """Perform a single measurement. @@ -589,10 +585,10 @@ def measure( def stream_measure( self, float_in: float = 0.05999999865889549, - double_array_in: List[float] = [0.1, 0.2, 0.3], + double_array_in: typing.Iterable[float] = [0.1, 0.2, 0.3], bool_in: bool = False, string_in: str = "sample string", - string_array_in: List[str] = [ + string_array_in: typing.Iterable[str] = [ "string with /forwardslash", "string with \\backslash", "string with 'single quotes'", @@ -600,44 +596,42 @@ def stream_measure( "string with \ttabspace", "string with \nnewline", ], - path_in: Path = Path("sample\\path\\for\\test"), - path_array_in: List[Path] = [ - Path("path/with/forward/slash"), - Path("path\\with\\backslash"), - Path("path with 'single quotes'"), - Path('path with "double quotes"'), - Path("path\twith\ttabs"), - Path("path\nwith\nnewlines"), + path_in: pathlib.PurePath = pathlib.PurePath("sample\\path\\for\\test"), + path_array_in: typing.Iterable[pathlib.PurePath] = [ + pathlib.PurePath("path/with/forward/slash"), + pathlib.PurePath("path\\with\\backslash"), + pathlib.PurePath("path with 'single quotes'"), + pathlib.PurePath('path with "double quotes"'), + pathlib.PurePath("path\twith\ttabs"), + pathlib.PurePath("path\nwith\nnewlines"), ], io_in: str = "resource", - io_array_in: List[str] = ["resource1", "resource2"], + io_array_in: typing.Iterable[str] = ["resource1", "resource2"], integer_in: int = 10, enum_in: EnumInEnum = EnumInEnum.BLUE, - enum_array_in: List[EnumInEnum] = [EnumInEnum.RED, EnumInEnum.GREEN], + enum_array_in: typing.Iterable[EnumInEnum] = [EnumInEnum.RED, EnumInEnum.GREEN], protobuf_enum_in: ProtobufEnumInEnum = ProtobufEnumInEnum.BLACK, - ) -> Generator[Outputs, None, None]: + ) -> typing.Generator[Outputs, None, None]: """Perform a streaming measurement. Returns: Stream of measurement outputs. """ - parameter_values = _convert_paths_to_strings( - [ - float_in, - double_array_in, - bool_in, - string_in, - string_array_in, - path_in, - path_array_in, - io_in, - io_array_in, - integer_in, - enum_in, - enum_array_in, - protobuf_enum_in, - ] - ) + parameter_values = [ + float_in, + double_array_in, + bool_in, + string_in, + string_array_in, + path_in, + path_array_in, + io_in, + io_array_in, + integer_in, + enum_in, + enum_array_in, + protobuf_enum_in, + ] with self._initialization_lock: if self._measure_response is not None: raise RuntimeError( @@ -664,7 +658,7 @@ def cancel(self) -> bool: else: return False - def register_pin_map(self, pin_map_path: Path) -> None: + def register_pin_map(self, pin_map_path: pathlib.Path) -> None: """Registers the pin map with the pin map service. Args: @@ -672,22 +666,3 @@ def register_pin_map(self, pin_map_path: Path) -> None: """ pin_map_id = self._get_pin_map_client().update_pin_map(pin_map_path) self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id) - - -def _convert_paths_to_strings(parameter_values: Iterable[Any]) -> List[Any]: - result: List[Any] = [] - - for parameter_value in parameter_values: - if isinstance(parameter_value, list): - converted_list = [] - for value in parameter_value: - if isinstance(value, Path): - converted_list.append(str(value)) - else: - converted_list.append(value) - result.append(converted_list) - elif isinstance(parameter_value, Path): - result.append(str(parameter_value)) - else: - result.append(parameter_value) - return result diff --git a/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/void_measurement_client.py b/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/void_measurement_client.py index b789f982c..285468dde 100644 --- a/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/void_measurement_client.py +++ b/packages/generator/tests/test_assets/example_renders/measurement_plugin_client/void_measurement_client.py @@ -3,9 +3,9 @@ from __future__ import annotations import logging +import pathlib import threading -from pathlib import Path -from typing import Any, Generator, List, Optional +import typing import grpc from google.protobuf import any_pb2, descriptor_pool @@ -17,8 +17,8 @@ from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool from ni_measurement_plugin_sdk_service.measurement.client_support import ( - create_file_descriptor, ParameterMetadata, + create_file_descriptor, serialize_parameters, ) from ni_measurement_plugin_sdk_service.pin_map import PinMapClient @@ -35,10 +35,10 @@ class VoidMeasurementClient: def __init__( self, *, - discovery_client: Optional[DiscoveryClient] = None, - pin_map_client: Optional[PinMapClient] = None, - grpc_channel: Optional[grpc.Channel] = None, - grpc_channel_pool: Optional[GrpcChannelPool] = None, + discovery_client: typing.Optional[DiscoveryClient] = None, + pin_map_client: typing.Optional[PinMapClient] = None, + grpc_channel: typing.Optional[grpc.Channel] = None, + grpc_channel_pool: typing.Optional[GrpcChannelPool] = None, ): """Initialize the Measurement Plug-In Client. @@ -57,8 +57,8 @@ def __init__( self._grpc_channel_pool = grpc_channel_pool self._discovery_client = discovery_client self._pin_map_client = pin_map_client - self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None - self._measure_response: Optional[ + self._stub: typing.Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None + self._measure_response: typing.Optional[ grpc.CallIterator[v2_measurement_service_pb2.MeasureResponse] ] = None self._configuration_metadata = { @@ -93,12 +93,12 @@ def pin_map_context(self, val: PinMapContext) -> None: self._pin_map_context = val @property - def sites(self) -> Optional[List[int]]: + def sites(self) -> typing.Optional[typing.List[int]]: """The sites where the measurement must be executed.""" return self._pin_map_context.sites @sites.setter - def sites(self, val: List[int]) -> None: + def sites(self, val: typing.List[int]) -> None: if self._pin_map_context is None: raise AttributeError( "Cannot set sites because the pin map context is None. Please provide a pin map context or register a pin map before setting sites." @@ -155,14 +155,14 @@ def _create_file_descriptor(self) -> None: ) def _create_measure_request( - self, parameter_values: List[Any] + self, parameter_values: typing.List[typing.Any] ) -> v2_measurement_service_pb2.MeasureRequest: serialized_configuration = any_pb2.Any( type_url="type.googleapis.com/ni.measurementlink.measurement.v2.MeasurementConfigurations", value=serialize_parameters( - parameter_metadata_dict=self._configuration_metadata, - parameter_values=parameter_values, - service_name=f"{self._service_class}.Configurations", + self._configuration_metadata, + parameter_values, + f"{self._service_class}.Configurations", ), ) return v2_measurement_service_pb2.MeasureRequest( @@ -180,7 +180,7 @@ def measure(self, integer_in: int = 10) -> None: for response in stream_measure_response: pass - def stream_measure(self, integer_in: int = 10) -> Generator[None, None, None]: + def stream_measure(self, integer_in: int = 10) -> typing.Generator[None, None, None]: """Perform a streaming measurement. Returns: @@ -213,7 +213,7 @@ def cancel(self) -> bool: else: return False - def register_pin_map(self, pin_map_path: Path) -> None: + def register_pin_map(self, pin_map_path: pathlib.Path) -> None: """Registers the pin map with the pin map service. Args: diff --git a/packages/service/ni_measurement_plugin_sdk_service/measurement/client_support.py b/packages/service/ni_measurement_plugin_sdk_service/measurement/client_support.py index 26b2b6e8c..4198957a7 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/measurement/client_support.py +++ b/packages/service/ni_measurement_plugin_sdk_service/measurement/client_support.py @@ -1,11 +1,24 @@ """Support functions for the Measurement Plug-In Client.""" -from ni_measurement_plugin_sdk_service._internal.parameter.decoder import deserialize_parameters -from ni_measurement_plugin_sdk_service._internal.parameter.encoder import serialize_parameters -from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ParameterMetadata +from pathlib import Path +from typing import Any, Dict, Sequence + +from google.protobuf.descriptor_pb2 import FieldDescriptorProto + +from ni_measurement_plugin_sdk_service._annotations import TYPE_SPECIALIZATION_KEY +from ni_measurement_plugin_sdk_service._internal.parameter.decoder import ( + deserialize_parameters as _internal_deserialize_parameters, +) +from ni_measurement_plugin_sdk_service._internal.parameter.encoder import ( + serialize_parameters as _internal_serialize_parameters, +) +from ni_measurement_plugin_sdk_service._internal.parameter.metadata import ( + ParameterMetadata, +) from ni_measurement_plugin_sdk_service._internal.parameter.serialization_descriptors import ( create_file_descriptor, ) +from ni_measurement_plugin_sdk_service.measurement.info import TypeSpecialization __all__ = [ "create_file_descriptor", @@ -13,3 +26,89 @@ "ParameterMetadata", "serialize_parameters", ] + + +def deserialize_parameters( + parameter_metadata_dict: Dict[int, ParameterMetadata], + parameter_bytes: bytes, + message_name: str, + *, + convert_paths: bool = True +) -> Sequence[Any]: + """Deserialize parameter bytes into separate parameter values. + + Args: + parameter_metadata_dict: Parameter metadata by ID. + + parameter_byte: Byte string to deserialize. + + message_name: gRPC message name (e.g. f"{service_class}.Outputs"). + + convert_paths: Specifies whether to convert path parameters to pathlib.Path. + + Returns: + Deserialized parameter values, ordered by ID. + """ + parameter_values = _internal_deserialize_parameters( + parameter_metadata_dict, parameter_bytes, message_name + ) + + for id in parameter_values.keys(): + metadata = parameter_metadata_dict[id] + if ( + convert_paths + and metadata.type == FieldDescriptorProto.TYPE_STRING + and metadata.annotations + and metadata.annotations.get(TYPE_SPECIALIZATION_KEY) == TypeSpecialization.Path.value + ): + if metadata.repeated: + parameter_values[id] = [Path(value) for value in parameter_values[id]] + else: + parameter_values[id] = Path(parameter_values[id]) + + if parameter_metadata_dict: + result = [None] * max(parameter_metadata_dict.keys()) + else: + result = [] + + for k, v in parameter_values.items(): + result[k - 1] = v + + return result + + +def serialize_parameters( + parameter_metadata_dict: Dict[int, ParameterMetadata], + parameter_values: Sequence[Any], + message_name: str, +) -> bytes: + """Serialize parameter values into a parameter byte string. + + Args: + parameter_metadata_dict: Parameter metadata by ID. + + parameter_values: Parameter values to serialize, ordered by ID. + + message_name: gRPC message name (e.g. f"{service_class}.Configurations"). + + Returns: + Serialized byte string containing parameter values. + """ + new_parameter_values = list(parameter_values) + + for id in parameter_metadata_dict.keys(): + index = id - 1 + metadata = parameter_metadata_dict[id] + if ( + metadata.type == FieldDescriptorProto.TYPE_STRING + and metadata.annotations + and metadata.annotations.get(TYPE_SPECIALIZATION_KEY) == TypeSpecialization.Path.value + ): + if metadata.repeated: + new_parameter_values[index] = [str(value) for value in parameter_values[index]] + else: + new_parameter_values[index] = str(parameter_values[index]) + + return _internal_serialize_parameters( + parameter_metadata_dict, new_parameter_values, message_name + )