Fix pytype (google-gemini#39)
* Squashed commit of the following:

commit de62d09
Author: Mark Daoust <markdaoust@google.com>
Date:   Tue Jun 6 17:26:10 2023 -0700

    Revert discuss types (tests failing with 3.9).

commit 7e2ccd8
Author: Mark Daoust <markdaoust@google.com>
Date:   Fri Jun 2 13:59:38 2023 -0700

    format

commit 70cc6d5
Author: Mark Daoust <markdaoust@google.com>
Date:   Fri Jun 2 13:47:20 2023 -0700

    Add future annotations for py3.9

commit d74f436
Author: Mark Daoust <markdaoust@google.com>
Date:   Fri Jun 2 13:30:54 2023 -0700

    Update Unions

commit 4d2e710
Author: Mark Daoust <markdaoust@google.com>
Date:   Fri Jun 2 13:23:37 2023 -0700

    Update Unions

commit 3d93a3e
Author: Mark Daoust <markdaoust@google.com>
Date:   Thu Jun 1 15:49:15 2023 -0700

    Pytype passes with python 3.10

    + modernize some annotations

* Debug: revert to Union for non-lazy annotations

* format
MarkDaoust committed Jun 9, 2023
1 parent 8cc54d1 commit ef27fec
Showing 17 changed files with 118 additions and 98 deletions.
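The pattern behind most of these changes: PEP 604 syntax like `str | None` is only evaluable at runtime on Python 3.10+, but with `from __future__ import annotations` (PEP 563) annotations are stored as strings and never evaluated, so the new syntax also works on Python 3.9. A minimal sketch of the idea (illustrative code, not from this repo):

```python
# Without this future import, the annotation below would raise
# "TypeError: unsupported operand type(s) for |" on Python 3.9,
# because annotations are evaluated when the function is defined.
from __future__ import annotations


def greet(name: str | None = None) -> str:
    # Under PEP 563 the annotation is kept as the string "str | None"
    # and never evaluated, so this runs fine on 3.9.
    if name is None:
        return "Hello!"
    return f"Hello, {name}!"


print(greet())        # Hello!
print(greet("Mark"))  # Hello, Mark!
```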
1 change: 1 addition & 0 deletions google/generativeai/__init__.py
@@ -65,6 +65,7 @@
```
"""
+from __future__ import annotations

from google.generativeai import types
from google.generativeai import version
11 changes: 6 additions & 5 deletions google/generativeai/client.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

import os
from typing import cast, Optional, Union
@@ -35,17 +36,17 @@

def configure(
    *,
-    api_key: Optional[str] = None,
-    credentials: Union[ga_credentials.Credentials, dict, None] = None,
+    api_key: str | None = None,
+    credentials: ga_credentials.Credentials | dict | None = None,
    # The user can pass a string to choose `rest`, `grpc`, or `grpc_asyncio`.
    # See `_transport_registry` in `DiscussServiceClientMeta`.
    # Since the transport classes align with the client classes, it wouldn't make
    # sense to accept a `Transport` object here even though the client classes can.
    # We could accept a dict since all the `Transport` classes take the same args,
    # but that seems rare. Users that need it can just switch to the low-level API.
-    transport: Union[str, None] = None,
-    client_options: Union[client_options_lib.ClientOptions, dict, None] = None,
-    client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
+    transport: str | None = None,
+    client_options: client_options_lib.ClientOptions | dict | None = None,
+    client_info: gapic_v1.client_info.ClientInfo | None = None,
):
    """Captures default client configuration.
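As a usage sketch for the updated `configure` signature (the environment variable name and `transport` value here are illustrative, not something this commit prescribes):

```python
import os

import google.generativeai as genai

# Configure the SDK once at startup; every keyword is optional.
genai.configure(
    api_key=os.environ["GOOGLE_API_KEY"],  # illustrative variable name
    transport="rest",  # or "grpc" / "grpc_asyncio", per the comment above
)
```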
89 changes: 45 additions & 44 deletions google/generativeai/discuss.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

import dataclasses
import sys
@@ -150,9 +151,9 @@ def _make_examples(examples: discuss_types.ExamplesOptions) -> List[glm.Example]
def _make_message_prompt_dict(
    prompt: discuss_types.MessagePromptOptions = None,
    *,
-    context: Optional[str] = None,
-    examples: Optional[discuss_types.ExamplesOptions] = None,
-    messages: Optional[discuss_types.MessagesOptions] = None,
+    context: str | None = None,
+    examples: discuss_types.ExamplesOptions | None = None,
+    messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt:
    if prompt is None:
        prompt = dict(
@@ -196,9 +197,9 @@ def _make_message_prompt_dict(
def _make_message_prompt(
    prompt: discuss_types.MessagePromptOptions = None,
    *,
-    context: Optional[str] = None,
-    examples: Optional[discuss_types.ExamplesOptions] = None,
-    messages: Optional[discuss_types.MessagesOptions] = None,
+    context: str | None = None,
+    examples: discuss_types.ExamplesOptions | None = None,
+    messages: discuss_types.MessagesOptions | None = None,
) -> glm.MessagePrompt:
    prompt = _make_message_prompt_dict(
        prompt=prompt, context=context, examples=examples, messages=messages
@@ -208,15 +209,15 @@ def _make_message_prompt(

def _make_generate_message_request(
    *,
-    model: Optional[model_types.ModelNameOptions],
-    context: Optional[str] = None,
-    examples: Optional[discuss_types.ExamplesOptions] = None,
-    messages: Optional[discuss_types.MessagesOptions] = None,
-    temperature: Optional[float] = None,
-    candidate_count: Optional[int] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[float] = None,
-    prompt: Optional[discuss_types.MessagePromptOptions] = None,
+    model: model_types.ModelNameOptions | None,
+    context: str | None = None,
+    examples: discuss_types.ExamplesOptions | None = None,
+    messages: discuss_types.MessagesOptions | None = None,
+    temperature: float | None = None,
+    candidate_count: int | None = None,
+    top_p: float | None = None,
+    top_k: float | None = None,
+    prompt: discuss_types.MessagePromptOptions | None = None,
) -> glm.GenerateMessageRequest:
    model = model_types.make_model_name(model)

@@ -247,16 +248,16 @@ def inner(f):

def chat(
    *,
-    model: Optional[model_types.ModelNameOptions] = "models/chat-bison-001",
-    context: Optional[str] = None,
-    examples: Optional[discuss_types.ExamplesOptions] = None,
-    messages: Optional[discuss_types.MessagesOptions] = None,
-    temperature: Optional[float] = None,
-    candidate_count: Optional[int] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[float] = None,
-    prompt: Optional[discuss_types.MessagePromptOptions] = None,
-    client: Optional[glm.DiscussServiceClient] = None,
+    model: model_types.ModelNameOptions | None = "models/chat-bison-001",
+    context: str | None = None,
+    examples: discuss_types.ExamplesOptions | None = None,
+    messages: discuss_types.MessagesOptions | None = None,
+    temperature: float | None = None,
+    candidate_count: int | None = None,
+    top_p: float | None = None,
+    top_k: float | None = None,
+    prompt: discuss_types.MessagePromptOptions | None = None,
+    client: glm.DiscussServiceClient | None = None,
) -> discuss_types.ChatResponse:
    """Calls the API and returns a `types.ChatResponse` containing the response.
@@ -345,16 +346,16 @@ def chat(
@set_doc(chat.__doc__)
async def chat_async(
    *,
-    model: Optional[model_types.ModelNameOptions] = None,
-    context: Optional[str] = None,
-    examples: Optional[discuss_types.ExamplesOptions] = None,
-    messages: Optional[discuss_types.MessagesOptions] = None,
-    temperature: Optional[float] = None,
-    candidate_count: Optional[int] = None,
-    top_p: Optional[float] = None,
-    top_k: Optional[float] = None,
-    prompt: Optional[discuss_types.MessagePromptOptions] = None,
-    client: Optional[glm.DiscussServiceAsyncClient] = None,
+    model: model_types.ModelNameOptions | None = None,
+    context: str | None = None,
+    examples: discuss_types.ExamplesOptions | None = None,
+    messages: discuss_types.MessagesOptions | None = None,
+    temperature: float | None = None,
+    candidate_count: int | None = None,
+    top_p: float | None = None,
+    top_k: float | None = None,
+    prompt: discuss_types.MessagePromptOptions | None = None,
+    client: glm.DiscussServiceAsyncClient | None = None,
) -> discuss_types.ChatResponse:
    request = _make_generate_message_request(
        model=model,
@@ -380,7 +381,7 @@ async def chat_async(
@set_doc(discuss_types.ChatResponse.__doc__)
@dataclasses.dataclass(**DATACLASS_KWARGS, init=False)
class ChatResponse(discuss_types.ChatResponse):
-    _client: Optional[glm.DiscussServiceClient] = dataclasses.field(
+    _client: glm.DiscussServiceClient | None = dataclasses.field(
        default=lambda: None, repr=False
    )

@@ -390,7 +391,7 @@ def __init__(self, **kwargs):

    @property
    @set_doc(discuss_types.ChatResponse.last.__doc__)
-    def last(self) -> Optional[str]:
+    def last(self) -> str | None:
        if self.messages[-1]:
            return self.messages[-1]["content"]
        else:
@@ -445,7 +446,7 @@ async def reply_async(
def _build_chat_response(
    request: glm.GenerateMessageRequest,
    response: glm.GenerateMessageResponse,
-    client: Union[glm.DiscussServiceClient, glm.DiscussServiceAsyncClient],
+    client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient,
) -> ChatResponse:
    request = type(request).to_dict(request)
    prompt = request.pop("prompt")
@@ -473,7 +474,7 @@ def _build_chat_response(

def _generate_response(
    request: glm.GenerateMessageRequest,
-    client: Optional[glm.DiscussServiceClient] = None,
+    client: glm.DiscussServiceClient | None = None,
) -> ChatResponse:
    if client is None:
        client = get_default_discuss_client()
@@ -485,7 +486,7 @@ def _generate_response(

async def _generate_response_async(
    request: glm.GenerateMessageRequest,
-    client: Optional[glm.DiscussServiceAsyncClient] = None,
+    client: glm.DiscussServiceAsyncClient | None = None,
) -> ChatResponse:
    if client is None:
        client = get_default_discuss_async_client()
@@ -498,11 +499,11 @@ async def _generate_response_async(

def count_message_tokens(
    *,
    prompt: discuss_types.MessagePromptOptions = None,
-    context: Optional[str] = None,
-    examples: Optional[discuss_types.ExamplesOptions] = None,
-    messages: Optional[discuss_types.MessagesOptions] = None,
+    context: str | None = None,
+    examples: discuss_types.ExamplesOptions | None = None,
+    messages: discuss_types.MessagesOptions | None = None,
    model: model_types.ModelNameOptions = DEFAULT_DISCUSS_MODEL,
-    client: Optional[glm.DiscussServiceAsyncClient] = None,
+    client: glm.DiscussServiceAsyncClient | None = None,
):
    model = model_types.make_model_name(model)
    prompt = _make_message_prompt(
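For reference, a hedged example of the `chat` API whose signatures changed above; the prompt text is made up, while `.last` and `.reply` are the members shown in this diff:

```python
import google.generativeai as genai

genai.configure(api_key="...")  # placeholder

# `chat` returns a ChatResponse; `last` is the `str | None` property
# from this diff, and `reply` sends a follow-up message.
response = genai.chat(
    model="models/chat-bison-001",
    context="You answer in one sentence.",
    messages=["What does pytype check?"],
)
print(response.last)

response = response.reply("Does it support the X | Y union syntax?")
print(response.last)
```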
1 change: 1 addition & 0 deletions google/generativeai/docstring_utils.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations


def strip_oneof(docstring):
8 changes: 5 additions & 3 deletions google/generativeai/models.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations
+
import re
from typing import Optional, List

@@ -37,9 +39,9 @@ def __init__(
        self,
        *,
        page_size: int,
-        page_token: Optional[str],
+        page_token: str | None,
        models: List[model_types.Model],
-        client: Optional[glm.ModelServiceClient],
+        client: glm.ModelServiceClient | None,
    ):
        self._page_size = page_size
        self._page_token = page_token
@@ -73,7 +75,7 @@ def _list_models(page_size, page_token, client):


def list_models(
-    *, page_size: Optional[int] = None, client: Optional[glm.ModelServiceClient] = None
+    *, page_size: int | None = None, client: glm.ModelServiceClient | None = None
) -> model_types.ModelsIterable:
    """Lists available models.
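A short sketch of the listing API touched here, assuming the `Model` records expose `name` as in `model_types.Model`:

```python
import google.generativeai as genai

# `page_size` controls how many models are fetched per request; the
# iterable follows the page token across pages transparently.
for model in genai.list_models(page_size=50):
    print(model.name)
```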
2 changes: 1 addition & 1 deletion google/generativeai/notebook/__init__.py
@@ -25,7 +25,7 @@ def load_ipython_extension(ipython):
    # Since we're in an interactive environment, make the tables prettier.
    try:
        # pylint: disable-next=g-import-not-at-top
-        from google import colab
+        from google import colab  # type: ignore

        colab.data_table.enable_dataframe_formatter()
    except ImportError:
10 changes: 5 additions & 5 deletions google/generativeai/notebook/flag_def.py
@@ -35,7 +35,7 @@
import argparse
import dataclasses
import enum
-from typing import Any, Callable, Sequence, Union, Tuple
+from typing import Any, Callable, Sequence, Tuple, Union

from google.generativeai.notebook.lib import llmfn_inputs_source
from google.generativeai.notebook.lib import llmfn_outputs
@@ -49,10 +49,10 @@
_DESTTYPES = Union[
    _PARSETYPES,
    enum.Enum,
    Tuple[str, Callable[[str, str], Any]],  # For --compare_fn
    Sequence[str],  # For --ground_truth
    llmfn_inputs_source.LLMFnInputsSource,  # For --inputs
    llmfn_outputs.LLMFnOutputsSink,  # For --outputs
]

# The signature of a function that converts a command line argument from the
13 changes: 11 additions & 2 deletions google/generativeai/notebook/lib/llm_function.py
@@ -17,7 +17,16 @@

import abc
import dataclasses
-from typing import AbstractSet, Any, Callable, Iterable, Mapping, Optional, Sequence
+from typing import (
+    AbstractSet,
+    Any,
+    Callable,
+    Iterable,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+)

from google.generativeai.notebook.lib import llmfn_input_utils
from google.generativeai.notebook.lib import llmfn_output_row
@@ -117,7 +126,7 @@ def _generate_prompts(

class LLMFunction(
    Callable[
-        [Optional[llmfn_input_utils.LLMFunctionInputs]],
+        [Union[llmfn_input_utils.LLMFunctionInputs, None]],
        llmfn_outputs.LLMFnOutputs,
    ],
    metaclass=abc.ABCMeta,
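This hunk is the "Debug: revert to Union for non-lazy annotations" change from the commit message: `from __future__ import annotations` only defers *annotations*, while a base-class subscript like this `Callable[...]` is an ordinary expression evaluated at class-creation time, so PEP 604 `X | None` would still fail there on Python 3.9 and the `typing` form has to stay. A sketch of the distinction (class and function names are illustrative):

```python
from __future__ import annotations  # defers annotations, nothing else

from typing import Callable, Optional


def f(x: int | None) -> None:  # fine on 3.9: the annotation is never evaluated
    pass


# A base-class subscript is evaluated eagerly, so on Python 3.9 writing
# Callable[[int | None], None] here raises
# "TypeError: unsupported operand type(s) for |".
class Handler(Callable[[Optional[int]], None]):
    def __call__(self, x: Optional[int]) -> None:
        print(x)
```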
4 changes: 2 additions & 2 deletions google/generativeai/notebook/lib/llm_function_test.py
@@ -15,7 +15,7 @@
"""Unittest for llm_function."""
from __future__ import annotations

-from typing import Any, Callable, Optional, Mapping, Sequence
+from typing import Any, Callable, Mapping, Sequence

from absl.testing import absltest
from google.generativeai.notebook.lib import llm_function
@@ -61,7 +61,7 @@ class LLMFunctionBasicTest(absltest.TestCase):

    def _test_is_callable(
        self,
-        llm_fn: Callable[[Optional[Sequence[tuple[str, str]]]], LLMFnOutputs],
+        llm_fn: Callable[[Sequence[tuple[str, str]] | None], LLMFnOutputs],
    ) -> LLMFnOutputs:
        return llm_fn(None)

5 changes: 1 addition & 4 deletions google/generativeai/notebook/lib/llmfn_input_utils.py
@@ -24,10 +24,7 @@

_ColumnOrderValuesList = Mapping[str, Sequence[str]]

-LLMFunctionInputs = Union[
-    _ColumnOrderValuesList,
-    llmfn_inputs_source.LLMFnInputsSource,
-]
+LLMFunctionInputs = Union[_ColumnOrderValuesList, llmfn_inputs_source.LLMFnInputsSource]


def _is_column_order_values_list(inputs: Any) -> bool:
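Given the `_ColumnOrderValuesList = Mapping[str, Sequence[str]]` alias above, the collapsed `Union` means an LLM function's inputs can be either a plain mapping of column names to value lists or an inputs source; a hypothetical instance of the mapping form:

```python
# Shaped like _ColumnOrderValuesList: column name -> list of values.
inputs = {
    "word": ["hot", "cold"],
    "language": ["French", "German"],
}
```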
3 changes: 2 additions & 1 deletion google/generativeai/notebook/text_model.py
@@ -17,6 +17,7 @@

from google.api_core import retry
from google.generativeai import text
+from google.generativeai.types import text_types
from google.generativeai.notebook.lib import model as model_lib


@@ -30,7 +31,7 @@ def _generate_text(
        temperature: float | None = None,
        candidate_count: int | None = None,
        **kwargs,
-    ) -> text.Completion:
+    ) -> text_types.Completion:
        if model is not None:
            kwargs["model"] = model
        if temperature is not None: