diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py index 02b8d911a31a..d1d532fcb166 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py @@ -393,6 +393,17 @@ def count_tokens_openai( elif field == "description": tool_tokens += 2 tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore + elif field == "anyOf": + tool_tokens -= 3 + for o in v["anyOf"]: + tool_tokens += 3 + tool_tokens += len(encoding.encode(o["type"])) + elif field == "default": + tool_tokens += 2 + tool_tokens += len(encoding.encode(json.dumps(v["default"]))) + elif field == "title": + tool_tokens += 2 + tool_tokens += len(encoding.encode(v["title"])) elif field == "enum": tool_tokens -= 3 for o in v["enum"]: # pyright: ignore @@ -404,7 +415,9 @@ def count_tokens_openai( if len(parameters["properties"]) == 0: # pyright: ignore tool_tokens -= 2 num_tokens += tool_tokens - num_tokens += 12 + + if oai_tools: + num_tokens += 12 return num_tokens diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py index 58558cceb5f4..4b59b70b7d4e 100644 --- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py @@ -2,7 +2,7 @@ import json import logging import os -from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar +from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar, Optional from unittest.mock import AsyncMock, MagicMock import httpx @@ -450,11 +450,30 @@ def tool1(test: str, test2: str) -> str: def tool2(test1: int, test2: List[int]) -> str: return str(test1) + str(test2) - tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")] + def tool3( + test1: Annotated[Optional[str], "example"] = None, + test2: Literal["1", "2"] = "2" + ) -> str: + return str(test1) + str(test2) + + tools = [ + FunctionTool(tool1, description="example tool 1"), + FunctionTool(tool2, description="example tool 2"), + FunctionTool(tool3, description="example tool 3") + ] mockcalculate_vision_tokens = MagicMock() monkeypatch.setattr("autogen_ext.models.openai._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens) + # Test count_tokens without tools + num_tokens = client.count_tokens(messages) + assert num_tokens + + # Check that calculate_vision_tokens was called + mockcalculate_vision_tokens.assert_called_once() + mockcalculate_vision_tokens.reset_mock() + + # Test count_tokens with tools num_tokens = client.count_tokens(messages, tools=tools) assert num_tokens