Skip to content

Commit

Permalink
[releases/2.1] Cherry-pick: [High Priority] Fix Path type conversion …
Browse files Browse the repository at this point in the history
…and type hints for client codegen (#952)

[High Priority] Fix Path type conversion and type hints for client codegen (#931)

* generator: Fix Path type conversion for client outputs

* generator: Annotate array configuration parameters as Iterable[T] and array output parameters as Sequence[T]

* generator: Fix lint errors

* service: Fix lint errors

* service: Revert debugging code

* service: Simplify convert_paths_to_strings and convert_strings_to_paths

* generator: Simplify client imports

`from typing import big, list, of, things` results in a lot of conditional code. It's simpler to use `import typing` and reference `typing.Iterable`, etc.

* generator: Use pathlib.PurePath for configuration parameters

* generator: Fix lint errors

* generator: Move path conversion and dict/list conversion into client_support module

* service: Fix parameter name in client_support

* service: Update client_support to use type specialization constants

(cherry picked from commit 6e0aedb)
  • Loading branch information
bkeryan authored Sep 27, 2024
1 parent 7c19d3c commit 7df18c2
Show file tree
Hide file tree
Showing 6 changed files with 338 additions and 217 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,27 @@
from google.protobuf import descriptor_pool
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
from google.protobuf.type_pb2 import Field
from ni_measurement_plugin_sdk_service._internal.grpc_servicer import frame_metadata_dict
from ni_measurement_plugin_sdk_service._internal.grpc_servicer import (
frame_metadata_dict,
)
from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v2 import (
measurement_service_pb2 as v2_measurement_service_pb2,
measurement_service_pb2_grpc as v2_measurement_service_pb2_grpc,
)
from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient
from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool
from ni_measurement_plugin_sdk_service.measurement.client_support import (
ParameterMetadata,
create_file_descriptor,
deserialize_parameters,
ParameterMetadata,
)


_V2_MEASUREMENT_SERVICE_INTERFACE = "ni.measurementlink.measurement.v2.MeasurementService"

_INVALID_CHARS = "`~!@#$%^&*()-+={}[]\\|:;',<>.?/ \n"

_XY_DATA_IMPORT = "from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types.xydata_pb2 import DoubleXYData"
_PATH_IMPORT = "from pathlib import Path"
_PATH_IMPORT = "import pathlib"

_PROTO_DATATYPE_TO_PYTYPE_LOOKUP = {
Field.TYPE_INT32: int,
Expand Down Expand Up @@ -178,21 +179,24 @@ def get_configuration_and_output_metadata_by_index(
)
configuration_metadata = frame_metadata_dict(configuration_parameter_list)
output_metadata = frame_metadata_dict(output_parameter_list)
deserialized_parameters = deserialize_parameters(

# Disable path conversion to avoid normalizing path separators and eliminate the need to
# convert Path objects to strings.
default_values = deserialize_parameters(
configuration_metadata,
metadata.measurement_signature.configuration_defaults.value,
f"{service_class}.Configurations",
convert_paths=False,
)

for k, v in deserialized_parameters.items():
if issubclass(type(v), Enum):
default_value = v.value
elif issubclass(type(v), list) and any(issubclass(type(e), Enum) for e in v):
default_value = [e.value for e in v]
else:
default_value = v

configuration_metadata[k] = configuration_metadata[k]._replace(default_value=default_value)
for id, default_value in enumerate(default_values, start=1):
if isinstance(default_value, Enum):
default_value = default_value.value
elif isinstance(default_value, list) and any(isinstance(e, Enum) for e in default_value):
default_value = [e.value for e in default_value]
configuration_metadata[id] = configuration_metadata[id]._replace(
default_value=default_value
)

return configuration_metadata, output_metadata

Expand All @@ -211,19 +215,21 @@ def get_configuration_parameters_with_type_and_default_values(
parameter_names.append(parameter_name)

default_value = metadata.default_value
parameter_type = _get_python_type_as_str(metadata.type, metadata.repeated)
parameter_type = _get_configuration_python_type_as_str(metadata.type, metadata.repeated)
if isinstance(default_value, str):
default_value = repr(default_value)

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "path":
parameter_type = "Path"
parameter_type = "pathlib.PurePath"
built_in_import_modules.append(_PATH_IMPORT)
if metadata.repeated:
formatted_value = ", ".join(f"Path({repr(value)})" for value in default_value)
formatted_value = ", ".join(
f"pathlib.PurePath({repr(value)})" for value in default_value
)
default_value = f"[{formatted_value}]"
parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Iterable[{parameter_type}]"
else:
default_value = f"Path({default_value})"
default_value = f"pathlib.PurePath({default_value})"

if metadata.message_type:
raise click.ClickException(
Expand All @@ -243,7 +249,7 @@ def get_configuration_parameters_with_type_and_default_values(
concatenated_default_value = ", ".join(values)
concatenated_default_value = f"[{concatenated_default_value}]"

parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Iterable[{parameter_type}]"
default_value = concatenated_default_value
else:
enum_value = next((e.name for e in enum_type if e.value == default_value), None)
Expand All @@ -267,27 +273,29 @@ def get_output_parameters_with_type(
output_parameters_with_type: List[str] = []
for metadata in output_metadata.values():
parameter_name = _get_python_identifier(metadata.display_name)
parameter_type = _get_python_type_as_str(metadata.type, metadata.repeated)
parameter_type = _get_output_python_type_as_str(metadata.type, metadata.repeated)

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "path":
parameter_type = "Path"
parameter_type = "pathlib.Path"
built_in_import_modules.append(_PATH_IMPORT)

if metadata.repeated:
parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Sequence[{parameter_type}]"

if metadata.message_type and metadata.message_type == "ni.protobuf.types.DoubleXYData":
parameter_type = "DoubleXYData"
custom_import_modules.append(_XY_DATA_IMPORT)

if metadata.repeated:
parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Sequence[{parameter_type}]"

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type_name = _get_enum_type(
metadata.display_name, metadata.annotations["ni/enum.values"], enum_values_by_type
).__name__
parameter_type = f"List[{enum_type_name}]" if metadata.repeated else enum_type_name
parameter_type = (
f"typing.Sequence[{enum_type_name}]" if metadata.repeated else enum_type_name
)

output_parameters_with_type.append(f"{parameter_name}: {parameter_type}")

Expand Down Expand Up @@ -374,14 +382,25 @@ def _get_python_identifier(input_string: str) -> str:
return valid_identifier


def _get_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str:
def _get_configuration_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str:
python_type = _PROTO_DATATYPE_TO_PYTYPE_LOOKUP.get(type)

if python_type is None:
raise TypeError(f"Invalid data type: '{type}'.")

if is_array:
return f"typing.Iterable[{python_type.__name__}]"
return python_type.__name__


def _get_output_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str:
python_type = _PROTO_DATATYPE_TO_PYTYPE_LOOKUP.get(type)

if python_type is None:
raise TypeError(f"Invalid data type: '{type}'.")

if is_array:
return f"List[{python_type.__name__}]"
return f"typing.Sequence[{python_type.__name__}]"
return python_type.__name__


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,12 @@ from typing import Any
from __future__ import annotations

import logging
import pathlib
import threading
import typing
% if len(enum_by_class_name):
from enum import Enum
% endif
from pathlib import Path
<%
typing_imports = ["Any", "Generator", "List", "Optional"]
if output_metadata:
typing_imports += ["NamedTuple"]
if "from pathlib import Path" in built_in_import_modules:
typing_imports += ["Iterable"]
%>\
from typing import ${", ".join(sorted(typing_imports))}

import grpc
from google.protobuf import any_pb2, descriptor_pool
Expand All @@ -46,11 +39,11 @@ ${module}
from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient
from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool
from ni_measurement_plugin_sdk_service.measurement.client_support import (
ParameterMetadata,
create_file_descriptor,
% if output_metadata:
deserialize_parameters,
% endif
ParameterMetadata,
serialize_parameters,
)
from ni_measurement_plugin_sdk_service.pin_map import PinMapClient
Expand All @@ -73,7 +66,7 @@ class ${enum_name.__name__}(Enum):
<% output_type = "None" %>\
% if output_metadata:

class Outputs(NamedTuple):
class Outputs(typing.NamedTuple):
"""Outputs for the ${display_name | repr} measurement plug-in."""

% for output_parameter in output_parameters_with_type:
Expand All @@ -89,10 +82,10 @@ class ${class_name}:
def __init__(
self,
*,
discovery_client: Optional[DiscoveryClient] = None,
pin_map_client: Optional[PinMapClient] = None,
grpc_channel: Optional[grpc.Channel] = None,
grpc_channel_pool: Optional[GrpcChannelPool] = None,
discovery_client: typing.Optional[DiscoveryClient] = None,
pin_map_client: typing.Optional[PinMapClient] = None,
grpc_channel: typing.Optional[grpc.Channel] = None,
grpc_channel_pool: typing.Optional[GrpcChannelPool] = None,
):
"""Initialize the Measurement Plug-In Client.

Expand All @@ -111,8 +104,8 @@ class ${class_name}:
self._grpc_channel_pool = grpc_channel_pool
self._discovery_client = discovery_client
self._pin_map_client = pin_map_client
self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: Optional[
self._stub: typing.Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: typing.Optional[
grpc.CallIterator[v2_measurement_service_pb2.MeasureResponse]
] = None
self._configuration_metadata = {
Expand Down Expand Up @@ -170,12 +163,12 @@ class ${class_name}:
self._pin_map_context = val

@property
def sites(self) -> Optional[List[int]]:
def sites(self) -> typing.Optional[typing.List[int]]:
"""The sites where the measurement must be executed."""
return self._pin_map_context.sites

@sites.setter
def sites(self, val: List[int]) -> None:
def sites(self, val: typing.List[int]) -> None:
if self._pin_map_context is None:
raise AttributeError(
"Cannot set sites because the pin map context is None. Please provide a pin map context or register a pin map before setting sites."
Expand Down Expand Up @@ -230,14 +223,14 @@ class ${class_name}:
)

def _create_measure_request(
self, parameter_values: List[Any]
self, parameter_values: typing.List[typing.Any]
) -> v2_measurement_service_pb2.MeasureRequest:
serialized_configuration = any_pb2.Any(
type_url=${configuration_parameters_type_url | repr},
value=serialize_parameters(
parameter_metadata_dict=self._configuration_metadata,
parameter_values=parameter_values,
service_name=f"{self._service_class}.Configurations",
self._configuration_metadata,
parameter_values,
f"{self._service_class}.Configurations",
)
)
return v2_measurement_service_pb2.MeasureRequest(
Expand All @@ -249,17 +242,13 @@ class ${class_name}:
def _deserialize_response(
self, response: v2_measurement_service_pb2.MeasureResponse
) -> Outputs:
if self._output_metadata:
result = [None] * max(self._output_metadata.keys())
else:
result = []
output_values = deserialize_parameters(
self._output_metadata, response.outputs.value, f"{self._service_class}.Outputs"
return Outputs._make(
deserialize_parameters(
self._output_metadata,
response.outputs.value,
f"{self._service_class}.Outputs",
)
)

for k, v in output_values.items():
result[k - 1] = v
return Outputs._make(result)
% endif

def measure(
Expand All @@ -285,17 +274,13 @@ class ${class_name}:
def stream_measure(
self,
${configuration_parameters_with_type_and_default_values}
) -> Generator[${output_type}, None, None] :
) -> typing.Generator[${output_type}, None, None] :
"""Perform a streaming measurement.

Returns:
Stream of measurement outputs.
"""
% if "from pathlib import Path" in built_in_import_modules:
parameter_values = _convert_paths_to_strings([${measure_api_parameters}])
% else:
parameter_values = [${measure_api_parameters}]
% endif
with self._initialization_lock:
if self._measure_response is not None:
raise RuntimeError(
Expand Down Expand Up @@ -327,33 +312,11 @@ class ${class_name}:
else:
return False

def register_pin_map(self, pin_map_path: Path) -> None:
def register_pin_map(self, pin_map_path: pathlib.Path) -> None:
"""Registers the pin map with the pin map service.

Args:
pin_map_path: Absolute path of the pin map file.
"""
pin_map_id = self._get_pin_map_client().update_pin_map(pin_map_path)
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)

% if "from pathlib import Path" in built_in_import_modules:

def _convert_paths_to_strings(parameter_values: Iterable[Any]) -> List[Any]:
result: List[Any] = []

for parameter_value in parameter_values:
if isinstance(parameter_value, list):
converted_list = []
for value in parameter_value:
if isinstance(value, Path):
converted_list.append(str(value))
else:
converted_list.append(value)
result.append(converted_list)
elif isinstance(parameter_value, Path):
result.append(str(parameter_value))
else:
result.append(parameter_value)
return result

% endif
Loading

0 comments on commit 7df18c2

Please sign in to comment.