Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[releases/2.1] Cherry-pick: [High Priority] Fix Path type conversion and type hints for client codegen #952

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,27 @@
from google.protobuf import descriptor_pool
from google.protobuf.descriptor_pb2 import FieldDescriptorProto
from google.protobuf.type_pb2 import Field
from ni_measurement_plugin_sdk_service._internal.grpc_servicer import frame_metadata_dict
from ni_measurement_plugin_sdk_service._internal.grpc_servicer import (
frame_metadata_dict,
)
from ni_measurement_plugin_sdk_service._internal.stubs.ni.measurementlink.measurement.v2 import (
measurement_service_pb2 as v2_measurement_service_pb2,
measurement_service_pb2_grpc as v2_measurement_service_pb2_grpc,
)
from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient
from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool
from ni_measurement_plugin_sdk_service.measurement.client_support import (
ParameterMetadata,
create_file_descriptor,
deserialize_parameters,
ParameterMetadata,
)


_V2_MEASUREMENT_SERVICE_INTERFACE = "ni.measurementlink.measurement.v2.MeasurementService"

_INVALID_CHARS = "`~!@#$%^&*()-+={}[]\\|:;',<>.?/ \n"

_XY_DATA_IMPORT = "from ni_measurement_plugin_sdk_service._internal.stubs.ni.protobuf.types.xydata_pb2 import DoubleXYData"
_PATH_IMPORT = "from pathlib import Path"
_PATH_IMPORT = "import pathlib"

_PROTO_DATATYPE_TO_PYTYPE_LOOKUP = {
Field.TYPE_INT32: int,
Expand Down Expand Up @@ -178,21 +179,24 @@ def get_configuration_and_output_metadata_by_index(
)
configuration_metadata = frame_metadata_dict(configuration_parameter_list)
output_metadata = frame_metadata_dict(output_parameter_list)
deserialized_parameters = deserialize_parameters(

# Disable path conversion to avoid normalizing path separators and eliminate the need to
# convert Path objects to strings.
default_values = deserialize_parameters(
configuration_metadata,
metadata.measurement_signature.configuration_defaults.value,
f"{service_class}.Configurations",
convert_paths=False,
)

for k, v in deserialized_parameters.items():
if issubclass(type(v), Enum):
default_value = v.value
elif issubclass(type(v), list) and any(issubclass(type(e), Enum) for e in v):
default_value = [e.value for e in v]
else:
default_value = v

configuration_metadata[k] = configuration_metadata[k]._replace(default_value=default_value)
for id, default_value in enumerate(default_values, start=1):
if isinstance(default_value, Enum):
default_value = default_value.value
elif isinstance(default_value, list) and any(isinstance(e, Enum) for e in default_value):
default_value = [e.value for e in default_value]
configuration_metadata[id] = configuration_metadata[id]._replace(
default_value=default_value
)

return configuration_metadata, output_metadata

Expand All @@ -211,19 +215,21 @@ def get_configuration_parameters_with_type_and_default_values(
parameter_names.append(parameter_name)

default_value = metadata.default_value
parameter_type = _get_python_type_as_str(metadata.type, metadata.repeated)
parameter_type = _get_configuration_python_type_as_str(metadata.type, metadata.repeated)
if isinstance(default_value, str):
default_value = repr(default_value)

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "path":
parameter_type = "Path"
parameter_type = "pathlib.PurePath"
built_in_import_modules.append(_PATH_IMPORT)
if metadata.repeated:
formatted_value = ", ".join(f"Path({repr(value)})" for value in default_value)
formatted_value = ", ".join(
f"pathlib.PurePath({repr(value)})" for value in default_value
)
default_value = f"[{formatted_value}]"
parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Iterable[{parameter_type}]"
else:
default_value = f"Path({default_value})"
default_value = f"pathlib.PurePath({default_value})"

if metadata.message_type:
raise click.ClickException(
Expand All @@ -243,7 +249,7 @@ def get_configuration_parameters_with_type_and_default_values(
concatenated_default_value = ", ".join(values)
concatenated_default_value = f"[{concatenated_default_value}]"

parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Iterable[{parameter_type}]"
default_value = concatenated_default_value
else:
enum_value = next((e.name for e in enum_type if e.value == default_value), None)
Expand All @@ -267,27 +273,29 @@ def get_output_parameters_with_type(
output_parameters_with_type: List[str] = []
for metadata in output_metadata.values():
parameter_name = _get_python_identifier(metadata.display_name)
parameter_type = _get_python_type_as_str(metadata.type, metadata.repeated)
parameter_type = _get_output_python_type_as_str(metadata.type, metadata.repeated)

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "path":
parameter_type = "Path"
parameter_type = "pathlib.Path"
built_in_import_modules.append(_PATH_IMPORT)

if metadata.repeated:
parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Sequence[{parameter_type}]"

if metadata.message_type and metadata.message_type == "ni.protobuf.types.DoubleXYData":
parameter_type = "DoubleXYData"
custom_import_modules.append(_XY_DATA_IMPORT)

if metadata.repeated:
parameter_type = f"List[{parameter_type}]"
parameter_type = f"typing.Sequence[{parameter_type}]"

if metadata.annotations and metadata.annotations.get("ni/type_specialization") == "enum":
enum_type_name = _get_enum_type(
metadata.display_name, metadata.annotations["ni/enum.values"], enum_values_by_type
).__name__
parameter_type = f"List[{enum_type_name}]" if metadata.repeated else enum_type_name
parameter_type = (
f"typing.Sequence[{enum_type_name}]" if metadata.repeated else enum_type_name
)

output_parameters_with_type.append(f"{parameter_name}: {parameter_type}")

Expand Down Expand Up @@ -374,14 +382,25 @@ def _get_python_identifier(input_string: str) -> str:
return valid_identifier


def _get_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str:
def _get_configuration_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str:
python_type = _PROTO_DATATYPE_TO_PYTYPE_LOOKUP.get(type)

if python_type is None:
raise TypeError(f"Invalid data type: '{type}'.")

if is_array:
return f"typing.Iterable[{python_type.__name__}]"
return python_type.__name__


def _get_output_python_type_as_str(type: Field.Kind.ValueType, is_array: bool) -> str:
python_type = _PROTO_DATATYPE_TO_PYTYPE_LOOKUP.get(type)

if python_type is None:
raise TypeError(f"Invalid data type: '{type}'.")

if is_array:
return f"List[{python_type.__name__}]"
return f"typing.Sequence[{python_type.__name__}]"
return python_type.__name__


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,12 @@ from typing import Any
from __future__ import annotations

import logging
import pathlib
import threading
import typing
% if len(enum_by_class_name):
from enum import Enum
% endif
from pathlib import Path
<%
typing_imports = ["Any", "Generator", "List", "Optional"]
if output_metadata:
typing_imports += ["NamedTuple"]
if "from pathlib import Path" in built_in_import_modules:
typing_imports += ["Iterable"]
%>\
from typing import ${", ".join(sorted(typing_imports))}

import grpc
from google.protobuf import any_pb2, descriptor_pool
Expand All @@ -46,11 +39,11 @@ ${module}
from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient
from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool
from ni_measurement_plugin_sdk_service.measurement.client_support import (
ParameterMetadata,
create_file_descriptor,
% if output_metadata:
deserialize_parameters,
% endif
ParameterMetadata,
serialize_parameters,
)
from ni_measurement_plugin_sdk_service.pin_map import PinMapClient
Expand All @@ -73,7 +66,7 @@ class ${enum_name.__name__}(Enum):
<% output_type = "None" %>\
% if output_metadata:

class Outputs(NamedTuple):
class Outputs(typing.NamedTuple):
"""Outputs for the ${display_name | repr} measurement plug-in."""

% for output_parameter in output_parameters_with_type:
Expand All @@ -89,10 +82,10 @@ class ${class_name}:
def __init__(
self,
*,
discovery_client: Optional[DiscoveryClient] = None,
pin_map_client: Optional[PinMapClient] = None,
grpc_channel: Optional[grpc.Channel] = None,
grpc_channel_pool: Optional[GrpcChannelPool] = None,
discovery_client: typing.Optional[DiscoveryClient] = None,
pin_map_client: typing.Optional[PinMapClient] = None,
grpc_channel: typing.Optional[grpc.Channel] = None,
grpc_channel_pool: typing.Optional[GrpcChannelPool] = None,
):
"""Initialize the Measurement Plug-In Client.

Expand All @@ -111,8 +104,8 @@ class ${class_name}:
self._grpc_channel_pool = grpc_channel_pool
self._discovery_client = discovery_client
self._pin_map_client = pin_map_client
self._stub: Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: Optional[
self._stub: typing.Optional[v2_measurement_service_pb2_grpc.MeasurementServiceStub] = None
self._measure_response: typing.Optional[
grpc.CallIterator[v2_measurement_service_pb2.MeasureResponse]
] = None
self._configuration_metadata = {
Expand Down Expand Up @@ -170,12 +163,12 @@ class ${class_name}:
self._pin_map_context = val

@property
def sites(self) -> Optional[List[int]]:
def sites(self) -> typing.Optional[typing.List[int]]:
"""The sites where the measurement must be executed."""
return self._pin_map_context.sites

@sites.setter
def sites(self, val: List[int]) -> None:
def sites(self, val: typing.List[int]) -> None:
if self._pin_map_context is None:
raise AttributeError(
"Cannot set sites because the pin map context is None. Please provide a pin map context or register a pin map before setting sites."
Expand Down Expand Up @@ -230,14 +223,14 @@ class ${class_name}:
)

def _create_measure_request(
self, parameter_values: List[Any]
self, parameter_values: typing.List[typing.Any]
) -> v2_measurement_service_pb2.MeasureRequest:
serialized_configuration = any_pb2.Any(
type_url=${configuration_parameters_type_url | repr},
value=serialize_parameters(
parameter_metadata_dict=self._configuration_metadata,
parameter_values=parameter_values,
service_name=f"{self._service_class}.Configurations",
self._configuration_metadata,
parameter_values,
f"{self._service_class}.Configurations",
)
)
return v2_measurement_service_pb2.MeasureRequest(
Expand All @@ -249,17 +242,13 @@ class ${class_name}:
def _deserialize_response(
self, response: v2_measurement_service_pb2.MeasureResponse
) -> Outputs:
if self._output_metadata:
result = [None] * max(self._output_metadata.keys())
else:
result = []
output_values = deserialize_parameters(
self._output_metadata, response.outputs.value, f"{self._service_class}.Outputs"
return Outputs._make(
deserialize_parameters(
self._output_metadata,
response.outputs.value,
f"{self._service_class}.Outputs",
)
)

for k, v in output_values.items():
result[k - 1] = v
return Outputs._make(result)
% endif

def measure(
Expand All @@ -285,17 +274,13 @@ class ${class_name}:
def stream_measure(
self,
${configuration_parameters_with_type_and_default_values}
) -> Generator[${output_type}, None, None] :
) -> typing.Generator[${output_type}, None, None] :
"""Perform a streaming measurement.

Returns:
Stream of measurement outputs.
"""
% if "from pathlib import Path" in built_in_import_modules:
parameter_values = _convert_paths_to_strings([${measure_api_parameters}])
% else:
parameter_values = [${measure_api_parameters}]
% endif
with self._initialization_lock:
if self._measure_response is not None:
raise RuntimeError(
Expand Down Expand Up @@ -327,33 +312,11 @@ class ${class_name}:
else:
return False

def register_pin_map(self, pin_map_path: Path) -> None:
def register_pin_map(self, pin_map_path: pathlib.Path) -> None:
"""Registers the pin map with the pin map service.

Args:
pin_map_path: Absolute path of the pin map file.
"""
pin_map_id = self._get_pin_map_client().update_pin_map(pin_map_path)
self._pin_map_context = self._pin_map_context._replace(pin_map_id=pin_map_id)

% if "from pathlib import Path" in built_in_import_modules:

def _convert_paths_to_strings(parameter_values: Iterable[Any]) -> List[Any]:
result: List[Any] = []

for parameter_value in parameter_values:
if isinstance(parameter_value, list):
converted_list = []
for value in parameter_value:
if isinstance(value, Path):
converted_list.append(str(value))
else:
converted_list.append(value)
result.append(converted_list)
elif isinstance(parameter_value, Path):
result.append(str(parameter_value))
else:
result.append(parameter_value)
return result

% endif
Loading