Skip to content

Commit

Permalink
chore: fix mypy errors in pagers, errors, _test_api_client, and _extr…
Browse files Browse the repository at this point in the history
…a_utils modules

PiperOrigin-RevId: 734123684
  • Loading branch information
sararob authored and copybara-github committed Mar 6, 2025
1 parent 5e84ddc commit 257b435
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 55 deletions.
54 changes: 36 additions & 18 deletions google/genai/_extra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import inspect
import logging
import sys
import typing
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
import sys

import pydantic

Expand All @@ -38,7 +38,7 @@


def _create_generate_content_config_model(
config: types.GenerateContentConfigOrDict
config: types.GenerateContentConfigOrDict,
) -> types.GenerateContentConfig:
if isinstance(config, dict):
return types.GenerateContentConfig(**config)
Expand Down Expand Up @@ -78,9 +78,9 @@ def format_destination(

def get_function_map(
config: Optional[types.GenerateContentConfigOrDict] = None,
) -> dict[str, object]:
) -> dict[str, Callable]:
"""Returns a function map from the config."""
function_map: dict[str, object] = {}
function_map: dict[str, Callable] = {}
if not config:
return function_map
config_model = _create_generate_content_config_model(config)
Expand All @@ -97,6 +97,16 @@ def get_function_map(
return function_map


def convert_number_values_for_dict_function_call_args(
args: dict[str, Any],
) -> dict[str, Any]:
"""Converts float values in dict with no decimal to integers."""
return {
key: convert_number_values_for_function_call_args(value)
for key, value in args.items()
}


def convert_number_values_for_function_call_args(
args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
Expand Down Expand Up @@ -215,27 +225,35 @@ def invoke_function_from_dict_args(

def get_function_response_parts(
response: types.GenerateContentResponse,
function_map: dict[str, object],
function_map: dict[str, Callable],
) -> list[types.Part]:
"""Returns the function response parts from the response."""
func_response_parts = []
if response.candidates is not None and isinstance(response.candidates[0].content, types.Content) and response.candidates[0].content.parts is not None:
if (
response.candidates is not None
and isinstance(response.candidates[0].content, types.Content)
and response.candidates[0].content.parts is not None
):
for part in response.candidates[0].content.parts:
if not part.function_call:
continue
func_name = part.function_call.name
func = function_map[func_name]
args = convert_number_values_for_function_call_args(part.function_call.args)
func_response: dict[str, Any]
try:
func_response = {'result': invoke_function_from_dict_args(args, func)}
except Exception as e: # pylint: disable=broad-except
func_response = {'error': str(e)}
func_response_part = types.Part.from_function_response(
name=func_name, response=func_response
)

func_response_parts.append(func_response_part)
if func_name is not None and part.function_call.args is not None:
func = function_map[func_name]
args = convert_number_values_for_dict_function_call_args(
part.function_call.args
)
func_response: dict[str, Any]
try:
func_response = {
'result': invoke_function_from_dict_args(args, func)
}
except Exception as e: # pylint: disable=broad-except
func_response = {'error': str(e)}
func_response_part = types.Part.from_function_response(
name=func_name, response=func_response
)
func_response_parts.append(func_response_part)
return func_response_parts


Expand Down
9 changes: 4 additions & 5 deletions google/genai/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def pil_to_blob(img) -> types.Blob:


def t_part(
client: _api_client.BaseApiClient, part: Optional[types.PartUnionDict]
part: Optional[types.PartUnionDict]
) -> types.Part:
try:
import PIL.Image
Expand All @@ -268,16 +268,15 @@ def t_part(


def t_parts(
client: _api_client.BaseApiClient,
parts: Optional[Union[list[types.PartUnionDict], types.PartUnionDict]],
) -> list[types.Part]:
#
if parts is None or (isinstance(parts, list) and not parts):
raise ValueError('content parts are required.')
if isinstance(parts, list):
return [t_part(client, part) for part in parts]
return [t_part(part) for part in parts]
else:
return [t_part(client, parts)]
return [t_part(parts)]


def t_image_predictions(
Expand Down Expand Up @@ -408,7 +407,7 @@ def _handle_current_part(
accumulated_parts: list[types.Part],
current_part: types.PartUnionDict,
):
current_part = t_part(client, current_part)
current_part = t_part(current_part)
if _is_user_part(current_part) == _are_user_parts(accumulated_parts):
accumulated_parts.append(current_part)
else:
Expand Down
15 changes: 12 additions & 3 deletions google/genai/pagers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
class _BasePager(Generic[T]):
"""Base pager class for iterating through paginated results."""

def __init__(
def _init_page(
self,
name: PagedItem,
request: Callable[..., Any],
Expand All @@ -54,6 +54,15 @@ def __init__(

self._page_size = request_config.get('page_size', len(self._page))

def __init__(
self,
name: PagedItem,
request: Callable[..., Any],
response: Any,
config: Any,
):
self._init_page(name, request, response, config)

@property
def page(self) -> list[T]:
"""Returns the current page, which is a list of items.
Expand All @@ -72,7 +81,7 @@ def page(self) -> list[T]:
return self._page

@property
def name(self) -> str:
def name(self) -> PagedItem:
"""Returns the type of paged item (for example, ``batch_jobs``).
Usage:
Expand Down Expand Up @@ -139,7 +148,7 @@ def _init_next_page(self, response: Any) -> None:
Args:
response: The response object from the API request.
"""
self.__init__(self.name, self._request, response, self.config)
self._init_page(self.name, self._request, response, self.config)


class Pager(_BasePager[T]):
Expand Down
22 changes: 11 additions & 11 deletions google/genai/tests/transformers/test_t_part.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,54 +24,54 @@

def test_none():
with pytest.raises(ValueError):
t.t_part(None, None)
t.t_part(None)


def test_empty_string():
assert t.t_part(None, '') == types.Part(text='')
assert t.t_part('') == types.Part(text='')


def test_string():
assert t.t_part(None, 'test') == types.Part(text='test')
assert t.t_part('test') == types.Part(text='test')


def test_file():
assert t.t_part(
None, types.File(uri='gs://test', mime_type='image/png')
types.File(uri='gs://test', mime_type='image/png')
) == types.Part(
file_data=types.FileData(file_uri='gs://test', mime_type='image/png')
)


def test_file_no_uri():
with pytest.raises(ValueError):
t.t_part(None, types.File(mime_type='image/png'))
t.t_part(types.File(mime_type='image/png'))


def test_file_no_mime_type():
with pytest.raises(ValueError):
t.t_part(None, types.File(uri='gs://test'))
t.t_part(types.File(uri='gs://test'))


def test_empty_dict():
assert t.t_part(None, {}) == types.Part()
assert t.t_part({}) == types.Part()


def test_dict():
assert t.t_part(None, {'text': 'test'}) == types.Part(text='test')
assert t.t_part({'text': 'test'}) == types.Part(text='test')


def test_invalid_dict():
with pytest.raises(pydantic.ValidationError):
t.t_part(None, {'invalid_key': 'test'})
t.t_part({'invalid_key': 'test'})


def test_part():
assert t.t_part(None, types.Part(text='test')) == types.Part(text='test')
assert t.t_part(types.Part(text='test')) == types.Part(text='test')


def test_int():
try:
t.t_part(None, 1)
t.t_part(1)
except ValueError as e:
assert 'Unsupported content part type: <class \'int\'>' in str(e)
24 changes: 12 additions & 12 deletions google/genai/tests/transformers/test_t_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,41 +24,41 @@

def test_none():
with pytest.raises(ValueError):
t.t_parts(None, None)
t.t_parts(None)


def test_empty_list():
with pytest.raises(ValueError):
t.t_parts(None, [])
t.t_parts([])


def test_list():
assert t.t_parts(None, ['test1', 'test2']) == [
assert t.t_parts(['test1', 'test2']) == [
types.Part(text='test1'),
types.Part(text='test2'),
]


def test_empty_dict():
assert t.t_parts(None, {}) == [types.Part()]
assert t.t_parts({}) == [types.Part()]


def test_dict():
assert t.t_parts(None, {'text': 'test'}) == [types.Part(text='test')]
assert t.t_parts({'text': 'test'}) == [types.Part(text='test')]


def test_invalid_dict():
with pytest.raises(pydantic.ValidationError):
t.t_parts(None, {'invalid_key': 'test'})
t.t_parts({'invalid_key': 'test'})


def test_string():
assert t.t_parts(None, 'test') == [types.Part(text='test')]
assert t.t_parts('test') == [types.Part(text='test')]


def test_file():
assert t.t_parts(
None, types.File(uri='gs://test', mime_type='image/png')
types.File(uri='gs://test', mime_type='image/png')
) == [
types.Part(
file_data=types.FileData(file_uri='gs://test', mime_type='image/png')
Expand All @@ -68,20 +68,20 @@ def test_file():

def test_file_no_uri():
with pytest.raises(ValueError):
t.t_parts(None, types.File(mime_type='image/png'))
t.t_parts(types.File(mime_type='image/png'))


def test_file_no_mime_type():
with pytest.raises(ValueError):
t.t_parts(None, types.File(uri='gs://test'))
t.t_parts(types.File(uri='gs://test'))


def test_part():
assert t.t_parts(None, types.Part(text='test')) == [types.Part(text='test')]
assert t.t_parts(types.Part(text='test')) == [types.Part(text='test')]


def test_int():
try:
t.t_parts(None, 1)
t.t_parts(1)
except ValueError as e:
assert 'Unsupported content part type: <class \'int\'>' in str(e)
12 changes: 6 additions & 6 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import sys
import types as builtin_types
import typing
from typing import Any, Callable, Literal, Optional, Union, _UnionGenericAlias
from typing import Any, Callable, Literal, Optional, Union, _UnionGenericAlias # type: ignore
import pydantic
from pydantic import Field
from typing_extensions import TypedDict
Expand Down Expand Up @@ -720,7 +720,7 @@ class as well.
def __init__(self, parts: Union['PartUnionDict', list['PartUnionDict']]):
from . import _transformers as t

super().__init__(parts=t.t_parts(None, parts=parts))
super().__init__(parts=t.t_parts(parts=parts))


class ModelContent(Content):
Expand Down Expand Up @@ -748,7 +748,7 @@ class as well.
def __init__(self, parts: Union['PartUnionDict', list['PartUnionDict']]):
from . import _transformers as t

super().__init__(parts=t.t_parts(None, parts=parts))
super().__init__(parts=t.t_parts(parts=parts))


class ContentDict(TypedDict, total=False):
Expand Down Expand Up @@ -1736,7 +1736,7 @@ class FileDict(TypedDict, total=False):
if _is_pillow_image_imported:
PartUnion = Union[File, Part, PIL_Image, str]
else:
PartUnion = Union[File, Part, str]
PartUnion = Union[File, Part, str] # type: ignore[misc]


PartUnionDict = Union[PartUnion, PartDict]
Expand Down Expand Up @@ -3052,7 +3052,7 @@ def _from_response(
):

class Placeholder(pydantic.BaseModel):
placeholder: response_schema
placeholder: response_schema # type: ignore[valid-type]

try:
parsed = {'placeholder': json.loads(result.text)}
Expand Down Expand Up @@ -3082,7 +3082,7 @@ class Placeholder(pydantic.BaseModel):
try:

class Placeholder(pydantic.BaseModel): # type: ignore[no-redef]
placeholder: response_schema
placeholder: response_schema # type: ignore[valid-type]

parsed = {'placeholder': json.loads(result.text)}
placeholder = Placeholder.model_validate(parsed)
Expand Down

0 comments on commit 257b435

Please sign in to comment.