Skip to content

Commit

Permalink
Improve python executor's error logging (#275)
Browse files Browse the repository at this point in the history
* Improve python executor's error logging
  • Loading branch information
aymeric-roucher authored Jan 20, 2025
1 parent 3c18d4d commit 7a91123
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 37 deletions.
12 changes: 2 additions & 10 deletions src/smolagents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,16 +972,8 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
]
observation += "Execution logs:\n" + execution_logs
except Exception as e:
if isinstance(e, SyntaxError):
error_msg = (
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
)
else:
error_msg = str(e)
if "Import of " in str(e) and " is not allowed" in str(e):
error_msg = str(e)
if "Import of " in error_msg and " is not allowed" in error_msg:
self.logger.log(
"[bold red]Warning to user: Code execution failed due to an unauthorized import - Consider passing said import under `additional_authorized_imports` when initializing your CodeAgent.",
level=LogLevel.INFO,
Expand Down
24 changes: 18 additions & 6 deletions src/smolagents/local_python_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def evaluate_call(
func = ERRORS[func_name]
else:
raise InterpreterError(
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
f"It is not permitted to evaluate other functions than the provided tools or functions defined/imported in previous code (tried to execute {call.func.id})."
)

elif isinstance(call.func, ast.Subscript):
Expand Down Expand Up @@ -1245,7 +1245,16 @@ def evaluate_python_code(
updated by this function to contain all variables as they are evaluated.
The print outputs will be stored in the state under the key 'print_outputs'.
"""
expression = ast.parse(code)
try:
expression = ast.parse(code)
except SyntaxError as e:
raise InterpreterError(
f"Code execution failed on line {e.lineno} due to: {type(e).__name__}\n"
f"{e.text}"
f"{' ' * (e.offset or 0)}^\n"
f"Error: {str(e)}"
)

if state is None:
state = {}
if static_tools is None:
Expand Down Expand Up @@ -1273,10 +1282,13 @@ def final_answer(value):
state["print_outputs"] = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
is_final_answer = True
return e.value, is_final_answer
except InterpreterError as e:
msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg)
except Exception as e:
exception_type = type(e).__name__
error_msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
error_msg = (
f"Code execution failed at line '{ast.get_source_segment(code, node)}' due to: {exception_type}:{str(e)}"
)
raise InterpreterError(error_msg)


class LocalPythonInterpreter:
Expand Down
35 changes: 27 additions & 8 deletions src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def validate_arguments(self):
"Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
)

json_schema = _convert_type_hints_to_json_schema(self.forward)
json_schema = _convert_type_hints_to_json_schema(
self.forward
) # This function will raise an error on missing docstrings, contrary to get_json_schema
for key, value in self.inputs.items():
if "nullable" in value:
assert key in json_schema and "nullable" in json_schema[key], (
Expand Down Expand Up @@ -885,6 +887,16 @@ def from_mcp(cls, server_parameters) -> "ToolCollection":
yield cls(tools)


def get_tool_json_schema(tool_function):
tool_json_schema = get_json_schema(tool_function)["function"]
tool_parameters = tool_json_schema["parameters"]
inputs_schema = tool_parameters["properties"]
for input_name in inputs_schema:
if "required" not in tool_parameters or input_name not in tool_parameters["required"]:
inputs_schema[input_name]["nullable"] = True
return tool_json_schema


def tool(tool_function: Callable) -> Tool:
"""
Converts a function into an instance of a Tool subclass.
Expand All @@ -893,12 +905,19 @@ def tool(tool_function: Callable) -> Tool:
tool_function: Your function. Should have type hints for each input and a type hint for the output.
Should also have a docstring description including an 'Args:' part where each argument is described.
"""
parameters = get_json_schema(tool_function)["function"]
if "return" not in parameters:
tool_json_schema = get_tool_json_schema(tool_function)
if "return" not in tool_json_schema:
raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")

class SimpleTool(Tool):
def __init__(self, name, description, inputs, output_type, function):
def __init__(
self,
name: str,
description: str,
inputs: Dict[str, Dict[str, str]],
output_type: str,
function: Callable,
):
self.name = name
self.description = description
self.inputs = inputs
Expand All @@ -907,10 +926,10 @@ def __init__(self, name, description, inputs, output_type, function):
self.is_initialized = True

simple_tool = SimpleTool(
parameters["name"],
parameters["description"],
parameters["parameters"]["properties"],
parameters["return"]["type"],
name=tool_json_schema["name"],
description=tool_json_schema["description"],
inputs=tool_json_schema["parameters"]["properties"],
output_type=tool_json_schema["return"]["type"],
function=tool_function,
)
original_signature = inspect.signature(tool_function)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_code_agent_code_errors_show_offending_lines(self):
output = agent.run("What is 2 multiplied by 3.6452?")
assert isinstance(output, AgentText)
assert output == "got an error"
assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
assert "Code execution failed at line 'print = 2' due to: InterpreterError" in str(agent.logs)

def test_code_agent_syntax_error_show_offending_lines(self):
agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error)
Expand Down Expand Up @@ -426,7 +426,7 @@ def test_code_agent_missing_import_triggers_advice_in_error_log(self):
with console.capture() as capture:
agent.run("Count to 3")
str_output = capture.get()
assert "import under `additional_authorized_imports`" in str_output
assert "Consider passing said import under" in str_output.replace("\n", "")

def test_multiagents(self):
class FakeModelMultiagentsManagerAgent:
Expand Down
34 changes: 23 additions & 11 deletions tests/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,25 +630,32 @@ def test_adding_int_to_list_raises_error(self):
assert "Cannot add non-list value 1 to a list." in str(e)

def test_error_highlights_correct_line_of_code(self):
code = """# Ok this is a very long code
# It has many commented lines
a = 1
code = """a = 1
b = 2
# Here is another piece
counts = [1, 2, 3]
counts += 1
b += 1"""
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "Code execution failed at line 'counts += 1" in str(e)

def test_error_type_returned_in_function_call(self):
code = """def error_function():
raise ValueError("error")
error_function()"""
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code)
assert "error" in str(e)
assert "ValueError" in str(e)

def test_assert(self):
code = """
assert 1 == 1
assert 1 == 2
"""
with pytest.raises(AssertionError) as e:
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code, BASE_PYTHON_TOOLS, state={})
assert "1 == 2" in str(e) and "1 == 1" not in str(e)

Expand Down Expand Up @@ -845,6 +852,13 @@ def test_for(self):
result, _ = evaluate_python_code(code, {"print": print, "map": map}, state={})
assert result == {"Worker A": "8:00 pm", "Worker B": "11:45 am"}

def test_syntax_error_points_error(self):
code = "a = ;"
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code)
assert "SyntaxError" in str(e)
assert " ^" in str(e)

def test_fix_final_answer_code(self):
test_cases = [
(
Expand Down Expand Up @@ -890,18 +904,16 @@ def test_dangerous_subpackage_access_blocked(self):

# Import of whitelisted modules should succeed but dangerous submodules should not exist
code = "import random;random._os.system('echo bad command passed')"
with pytest.raises(AttributeError) as e:
with pytest.raises(InterpreterError) as e:
evaluate_python_code(code)
assert "module 'random' has no attribute '_os'" in str(e)
assert "AttributeError:module 'random' has no attribute '_os'" in str(e)

code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
with pytest.raises(AttributeError):
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["doctest"])

def test_close_matches_subscript(self):
code = 'capitals = {"Czech Republic": "Prague", "Monaco": "Monaco", "Bhutan": "Thimphu"};capitals["Butan"]'
with pytest.raises(Exception) as e:
evaluate_python_code(code)
assert "Maybe you meant one of these indexes instead" in str(
e
) and "['Bhutan']" in str(e).replace("\\", "")
assert "Maybe you meant one of these indexes instead" in str(e) and "['Bhutan']" in str(e).replace("\\", "")
14 changes: 14 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,20 @@ def forward(self, location, celsius: str) -> str:
GetWeatherTool3()
assert "Nullable" in str(e)

def test_tool_default_parameters_is_nullable(self):
@tool
def get_weather(location: str, celsius: bool = False) -> str:
"""
Get weather in the next days at given location.
Args:
location: the location
celsius: is the temperature given in celsius
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

assert get_weather.inputs["celsius"]["nullable"]


@pytest.fixture
def mock_server_parameters():
Expand Down

0 comments on commit 7a91123

Please sign in to comment.