Skip to content

Commit

Permalink
E2B tool - Improve description wuth uploaded files info (langchain-ai…
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubno authored and HoaNQ9 committed Feb 2, 2024
1 parent 7ce1e1c commit fa3c041
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions libs/langchain/langchain/tools/e2b_data_analysis/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.pydantic_v1 import BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field, PrivateAttr
from langchain.tools import BaseTool, Tool
from langchain.tools.e2b_data_analysis.unparse import Unparser

Expand Down Expand Up @@ -97,7 +97,7 @@ class E2BDataAnalysisTool(BaseTool):
name = "e2b_data_analysis"
args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
session: Any
uploaded_files: List[UploadedFile] = Field(default_factory=list)
_uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)

def __init__(
self,
Expand All @@ -119,7 +119,8 @@ def __init__(

# If no API key is provided, E2B will try to read it from the environment
# variable E2B_API_KEY
session = DataAnalysis(
super().__init__(description=base_description, **kwargs)
self.session = DataAnalysis(
api_key=api_key,
cwd=cwd,
env_vars=env_vars,
Expand All @@ -128,21 +129,19 @@ def __init__(
on_exit=on_exit,
on_artifact=on_artifact,
)
super().__init__(session=session, description=base_description, **kwargs)
self.uploaded_files = []

def close(self) -> None:
"""Close the cloud sandbox."""
self.uploaded_files = []
self._uploaded_files = []
self.session.close()

@property
def uploaded_files_description(self) -> str:
if len(self.uploaded_files) == 0:
if len(self._uploaded_files) == 0:
return ""
lines = ["The following files available in the sandbox:"]

for f in self.uploaded_files:
for f in self._uploaded_files:
if f.description == "":
lines.append(f"- path: `{f.remote_path}`")
else:
Expand Down Expand Up @@ -206,15 +205,19 @@ def upload_file(self, file: IO, description: str) -> UploadedFile:
remote_path=remote_path,
description=description,
)
self.uploaded_files.append(f)
self._uploaded_files.append(f)
self.description = self.description + "\n" + self.uploaded_files_description
return f

def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
"""Remove uploaded file from the sandbox."""
self.session.filesystem.remove(uploaded_file.remote_path)
self.uploaded_files = [
f for f in self.uploaded_files if f.remote_path != uploaded_file.remote_path
self._uploaded_files = [
f
for f in self._uploaded_files
if f.remote_path != uploaded_file.remote_path
]
self.description = self.description + "\n" + self.uploaded_files_description

def as_tool(self) -> Tool:
return Tool.from_function(
Expand Down

0 comments on commit fa3c041

Please sign in to comment.