Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from pydantic import TypeAdapter
from typing_extensions import override
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -180,9 +181,9 @@ class LlmAgent(BaseAgent):
"""

# Controlled input/output configurations - Start
input_schema: Optional[type[BaseModel]] = None
input_schema: Optional[Any] = None
"""The input schema when agent is used as a tool."""
output_schema: Optional[type[BaseModel]] = None
output_schema: Optional[Any] = None
"""The output schema when agent replies.

NOTE:
Expand Down Expand Up @@ -470,9 +471,11 @@ def __maybe_save_output_to_state(self, event: Event):
# Do not attempt to parse it as JSON.
if not result.strip():
return
result = self.output_schema.model_validate_json(result).model_dump(
exclude_none=True
)
validated_result = TypeAdapter(self.output_schema).validate_json(result)
if isinstance(validated_result, BaseModel):
result = validated_result.model_dump(exclude_none=True)
else:
result = validated_result
event.actions.state_delta[self.output_key] = result

@model_validator(mode='after')
Expand Down
8 changes: 0 additions & 8 deletions src/google/adk/tools/_function_parameter_parse_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,6 @@ def _is_builtin_primitive_or_compound(
return annotation in _py_builtin_type_to_schema_type.keys()


def _raise_for_any_of_if_mldev(schema: types.Schema):
if schema.any_of:
raise ValueError(
'AnyOf is not supported in function declaration schema for Google AI.'
)


def _update_for_default_if_mldev(schema: types.Schema):
if schema.default is not None:
# TODO(kech): Remove this workaround once mldev supports default value.
Expand All @@ -74,7 +67,6 @@ def _raise_if_schema_unsupported(
variant: GoogleLLMVariant, schema: types.Schema
):
if variant == GoogleLLMVariant.GEMINI_API:
_raise_for_any_of_if_mldev(schema)
_update_for_default_if_mldev(schema)


Expand Down
11 changes: 8 additions & 3 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import model_validator
from pydantic import TypeAdapter
from typing_extensions import override

from . import _automatic_function_calling_util
Expand Down Expand Up @@ -113,7 +114,7 @@ async def run_async(
tool_context.actions.skip_summarization = True

if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
input_value = self.agent.input_schema.model_validate(args)
input_value = TypeAdapter(self.agent.input_schema).validate_python(args)
content = types.Content(
role='user',
parts=[
Expand Down Expand Up @@ -157,9 +158,13 @@ async def run_async(
return ''
merged_text = '\n'.join(p.text for p in last_event.content.parts if p.text)
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
tool_result = self.agent.output_schema.model_validate_json(
validated_result = TypeAdapter(self.agent.output_schema).validate_json(
merged_text
).model_dump(exclude_none=True)
)
if isinstance(validated_result, BaseModel):
tool_result = validated_result.model_dump(exclude_none=True)
else:
tool_result = validated_result
else:
tool_result = merged_text
return tool_result
Expand Down
55 changes: 54 additions & 1 deletion tests/unittests/agents/test_llm_agent_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@

from typing import Any
from typing import cast
from typing import Literal
from typing import Optional
from typing import Union

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.loop_agent import LoopAgent
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.events.event import Event
from google.adk.models.llm_request import LlmRequest
from google.adk.models.registry import LLMRegistry
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
from google.genai.types import Part
from pydantic import BaseModel
import pytest

from .. import testing_utils


async def _create_readonly_context(
agent: LlmAgent, state: Optional[dict[str, Any]] = None
Expand Down Expand Up @@ -279,3 +284,51 @@ def test_allow_transfer_by_default():

assert not agent.disallow_transfer_to_parent
assert not agent.disallow_transfer_to_peers


def test_output_schema_with_union():
"""Tests if agent can have a Union type in output_schema."""

class CustomOutput1(BaseModel):
custom_output1: str

class CustomOutput2(BaseModel):
custom_output2: str

agent = LlmAgent(
name='test_agent',
output_schema=Union[CustomOutput1, CustomOutput2, Literal['option3']],
output_key='test_output',
)

# Test with the first type
event1 = Event(
author='test_agent',
content=types.Content(
parts=[Part(text='{"custom_output1": "response1"}')]
),
)
agent._LlmAgent__maybe_save_output_to_state(event1)
assert event1.actions.state_delta['test_output'] == {
'custom_output1': 'response1'
}

# Test with the second type
event2 = Event(
author='test_agent',
content=types.Content(
parts=[Part(text='{"custom_output2": "response2"}')]
),
)
agent._LlmAgent__maybe_save_output_to_state(event2)
assert event2.actions.state_delta['test_output'] == {
'custom_output2': 'response2'
}

# Test with the literal type
event3 = Event(
author='test_agent',
content=types.Content(parts=[Part(text='"option3"')]),
)
agent._LlmAgent__maybe_save_output_to_state(event3)
assert event3.actions.state_delta['test_output'] == 'option3'
Loading