Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
from toolbox_core.utils import params_to_pydantic_model


def _get_tool_description(core_tool: ToolboxCoreTool) -> str:
description = core_tool._description
if isinstance(description, str):
return description
return core_tool.__doc__ or ""


# This class is an internal implementation detail and is not exposed to the
# end-user. It should not be used directly by external code. Changes to this
# class will not be considered breaking changes to the public API.
Expand All @@ -44,7 +51,7 @@ def __init__(
# BaseTool class before assigning values to member variables.
super().__init__(
name=core_tool.__name__,
description=core_tool.__doc__,
description=_get_tool_description(core_tool),
args_schema=params_to_pydantic_model(core_tool._name, core_tool._params),
)
self.__core_tool = core_tool
Expand Down
9 changes: 8 additions & 1 deletion packages/toolbox-langchain/src/toolbox_langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from toolbox_core.utils import params_to_pydantic_model


def _get_tool_description(core_tool: ToolboxCoreSyncTool) -> str:
description = core_tool._description
if isinstance(description, str):
return description
return core_tool.__doc__ or ""


class ToolboxTool(BaseTool):
"""
A subclass of LangChain's BaseTool that supports features specific to
Expand All @@ -42,7 +49,7 @@ def __init__(
# BaseTool class before assigning values to member variables.
super().__init__(
name=core_tool.__name__,
description=core_tool.__doc__,
description=_get_tool_description(core_tool),
args_schema=params_to_pydantic_model(core_tool._name, core_tool._params),
)
self.__core_tool = core_tool
Expand Down
3 changes: 2 additions & 1 deletion packages/toolbox-langchain/tests/test_async_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ async def test_toolbox_tool_init(self, tool_schema_dict):
)
tool = AsyncToolboxTool(core_tool=core_tool_instance)
assert tool.name == "test_tool"
assert tool.description == core_tool_instance.__doc__
assert tool.description == core_tool_instance._description
assert "Args:" not in tool.description

@pytest.mark.parametrize(
"params_to_bind",
Expand Down
21 changes: 18 additions & 3 deletions packages/toolbox-langchain/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,13 @@ def mock_core_tool(self, tool_schema_dict):
sync_mock = Mock(spec=ToolboxCoreSyncTool)

sync_mock.__name__ = "test_tool_name_for_langchain"
sync_mock.__doc__ = tool_schema_dict["description"]
sync_mock._description = tool_schema_dict["description"]
sync_mock.__doc__ = (
f"{tool_schema_dict['description']}\n\n"
"Args:\n"
" param1 (str): Param 1\n"
" param2 (int): Param 2"
)
sync_mock._name = "TestToolPydanticModel"
sync_mock._params = [
CoreParameterSchema(**p) for p in tool_schema_dict["parameters"]
Expand All @@ -123,6 +129,7 @@ def mock_core_tool(self, tool_schema_dict):

new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool)
new_mock_instance_for_methods.__name__ = sync_mock.__name__
new_mock_instance_for_methods._description = sync_mock._description
new_mock_instance_for_methods.__doc__ = sync_mock.__doc__
new_mock_instance_for_methods._name = sync_mock._name
new_mock_instance_for_methods._params = sync_mock._params
Expand All @@ -145,7 +152,13 @@ def mock_core_tool(self, tool_schema_dict):
def mock_core_sync_auth_tool(self, auth_tool_schema_dict):
sync_mock = Mock(spec=ToolboxCoreSyncTool)
sync_mock.__name__ = "test_auth_tool_lc_name"
sync_mock.__doc__ = auth_tool_schema_dict["description"]
sync_mock._description = auth_tool_schema_dict["description"]
sync_mock.__doc__ = (
f"{auth_tool_schema_dict['description']}\n\n"
"Args:\n"
" param1 (str): Param 1\n"
" param2 (int): Param 2"
)
sync_mock._name = "TestAuthToolPydanticModel"
sync_mock._params = [
CoreParameterSchema(**p) for p in auth_tool_schema_dict["parameters"]
Expand All @@ -159,6 +172,7 @@ def mock_core_sync_auth_tool(self, auth_tool_schema_dict):

new_mock_instance_for_methods = Mock(spec=ToolboxCoreSyncTool)
new_mock_instance_for_methods.__name__ = sync_mock.__name__
new_mock_instance_for_methods._description = sync_mock._description
new_mock_instance_for_methods.__doc__ = sync_mock.__doc__
new_mock_instance_for_methods._name = sync_mock._name
new_mock_instance_for_methods._params = sync_mock._params
Expand Down Expand Up @@ -188,7 +202,8 @@ def test_toolbox_tool_init(self, mock_core_tool):
tool = ToolboxTool(core_tool=mock_core_tool)

assert tool.name == mock_core_tool.__name__
assert tool.description == mock_core_tool.__doc__
assert tool.description == mock_core_tool._description
assert "Args:" not in tool.description
assert tool._ToolboxTool__core_tool == mock_core_tool

expected_args_schema = params_to_pydantic_model(
Expand Down
Loading