Skip to content

Commit 5983fa0

Browse files
committed
🚀 multiple improvements
- booth sync/async methods - status output - code/output logging in response
1 parent 31b239b commit 5983fa0

File tree

1 file changed

+177
-25
lines changed

1 file changed

+177
-25
lines changed

codeinterpreterapi/session.py

Lines changed: 177 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,22 @@
2222
from langchain.tools import BaseTool, StructuredTool
2323

2424
from codeinterpreterapi.agents import OpenAIFunctionsAgent
25-
from codeinterpreterapi.chains import get_file_modifications, remove_download_link
25+
from codeinterpreterapi.chains import (
26+
aget_file_modifications,
27+
aremove_download_link,
28+
get_file_modifications,
29+
remove_download_link,
30+
)
2631
from codeinterpreterapi.config import settings
32+
from codeinterpreterapi.parser import CodeAgentOutputParser, CodeChatAgentOutputParser
2733
from codeinterpreterapi.prompts import code_interpreter_system_message
2834
from codeinterpreterapi.schema import (
2935
CodeInput,
3036
CodeInterpreterResponse,
3137
File,
38+
SessionStatus,
3239
UserRequest,
3340
)
34-
from codeinterpreterapi.utils import (
35-
CodeAgentOutputParser,
36-
CodeCallbackHandler,
37-
CodeChatAgentOutputParser,
38-
)
3941

4042

4143
class CodeInterpreterSession:
@@ -45,19 +47,20 @@ def __init__(
4547
additional_tools: list[BaseTool] = [],
4648
**kwargs,
4749
) -> None:
48-
self.codebox = CodeBox()
50+
self.codebox = CodeBox(**kwargs)
4951
self.verbose = kwargs.get("verbose", settings.VERBOSE)
5052
self.tools: list[BaseTool] = self._tools(additional_tools)
5153
self.llm: BaseLanguageModel = llm or self._choose_llm(**kwargs)
5254
self.agent_executor: AgentExecutor = self._agent_executor()
5355
self.input_files: list[File] = []
5456
self.output_files: list[File] = []
57+
self.code_log: list[tuple[str, str]] = []
5558

56-
def start(self) -> None:
57-
self.codebox.start()
59+
def start(self) -> SessionStatus:
60+
return SessionStatus.from_codebox_status(self.codebox.start())
5861

59-
async def astart(self) -> None:
60-
await self.codebox.astart()
62+
async def astart(self) -> SessionStatus:
63+
return SessionStatus.from_codebox_status(await self.codebox.astart())
6164

6265
def _tools(self, additional_tools: list[BaseTool]) -> list[BaseTool]:
6366
return additional_tools + [
@@ -128,7 +131,6 @@ def _choose_agent(self) -> BaseSingleActionAgent:
128131
def _agent_executor(self) -> AgentExecutor:
129132
return AgentExecutor.from_agent_and_tools(
130133
agent=self._choose_agent(),
131-
callbacks=[CodeCallbackHandler(self)],
132134
max_iterations=9,
133135
tools=self.tools,
134136
verbose=self.verbose,
@@ -137,18 +139,67 @@ def _agent_executor(self) -> AgentExecutor:
137139
),
138140
)
139141

140-
async def show_code(self, code: str) -> None:
142+
def show_code(self, code: str) -> None:
143+
if self.verbose:
144+
print(code)
145+
146+
async def ashow_code(self, code: str) -> None:
141147
"""Callback function to show code to the user."""
142148
if self.verbose:
143149
print(code)
144150

145151
def _run_handler(self, code: str):
146-
raise NotImplementedError("Use arun_handler for now.")
152+
"""Run code in container and send the output to the user"""
153+
self.show_code(code)
154+
output: CodeBoxOutput = self.codebox.run(code)
155+
self.code_log.append((code, output.content))
156+
157+
if not isinstance(output.content, str):
158+
raise TypeError("Expected output.content to be a string.")
159+
160+
if output.type == "image/png":
161+
filename = f"image-{uuid.uuid4()}.png"
162+
file_buffer = BytesIO(base64.b64decode(output.content))
163+
file_buffer.name = filename
164+
self.output_files.append(File(name=filename, content=file_buffer.read()))
165+
return f"Image {filename} got send to the user."
166+
167+
elif output.type == "error":
168+
if "ModuleNotFoundError" in output.content:
169+
if package := re.search(
170+
r"ModuleNotFoundError: No module named '(.*)'", output.content
171+
):
172+
self.codebox.install(package.group(1))
173+
return (
174+
f"{package.group(1)} was missing but "
175+
"got installed now. Please try again."
176+
)
177+
else:
178+
# TODO: preanalyze error to optimize next code generation
179+
pass
180+
if self.verbose:
181+
print("Error:", output.content)
182+
183+
elif modifications := get_file_modifications(code, self.llm):
184+
for filename in modifications:
185+
if filename in [file.name for file in self.input_files]:
186+
continue
187+
fileb = self.codebox.download(filename)
188+
if not fileb.content:
189+
continue
190+
file_buffer = BytesIO(fileb.content)
191+
file_buffer.name = filename
192+
self.output_files.append(
193+
File(name=filename, content=file_buffer.read())
194+
)
195+
196+
return output.content
147197

148198
async def _arun_handler(self, code: str):
149199
"""Run code in container and send the output to the user"""
150-
print("Running code in container...", code)
200+
await self.ashow_code(code)
151201
output: CodeBoxOutput = await self.codebox.arun(code)
202+
self.code_log.append((code, output.content))
152203

153204
if not isinstance(output.content, str):
154205
raise TypeError("Expected output.content to be a string.")
@@ -176,7 +227,7 @@ async def _arun_handler(self, code: str):
176227
if self.verbose:
177228
print("Error:", output.content)
178229

179-
elif modifications := await get_file_modifications(code, self.llm):
230+
elif modifications := await aget_file_modifications(code, self.llm):
180231
for filename in modifications:
181232
if filename in [file.name for file in self.input_files]:
182233
continue
@@ -191,7 +242,22 @@ async def _arun_handler(self, code: str):
191242

192243
return output.content
193244

194-
async def _input_handler(self, request: UserRequest):
245+
def _input_handler(self, request: UserRequest) -> None:
246+
"""Callback function to handle user input."""
247+
if not request.files:
248+
return
249+
if not request.content:
250+
request.content = (
251+
"I uploaded, just text me back and confirm that you got the file(s)."
252+
)
253+
request.content += "\n**The user uploaded the following files: **\n"
254+
for file in request.files:
255+
self.input_files.append(file)
256+
request.content += f"[Attachment: {file.name}]\n"
257+
self.codebox.upload(file.name, file.content)
258+
request.content += "**File(s) are now available in the cwd. **\n"
259+
260+
async def _ainput_handler(self, request: UserRequest):
195261
# TODO: variables as context to the agent
196262
# TODO: current files as context to the agent
197263
if not request.files:
@@ -207,7 +273,7 @@ async def _input_handler(self, request: UserRequest):
207273
await self.codebox.aupload(file.name, file.content)
208274
request.content += "**File(s) are now available in the cwd. **\n"
209275

210-
async def _output_handler(self, final_response: str) -> CodeInterpreterResponse:
276+
def _output_handler(self, final_response: str) -> CodeInterpreterResponse:
211277
"""Embed images in the response"""
212278
for file in self.output_files:
213279
if str(file.name) in final_response:
@@ -216,25 +282,98 @@ async def _output_handler(self, final_response: str) -> CodeInterpreterResponse:
216282

217283
if self.output_files and re.search(r"\n\[.*\]\(.*\)", final_response):
218284
try:
219-
final_response = await remove_download_link(final_response, self.llm)
285+
final_response = remove_download_link(final_response, self.llm)
220286
except Exception as e:
221287
if self.verbose:
222288
print("Error while removing download links:", e)
223289

224-
return CodeInterpreterResponse(content=final_response, files=self.output_files)
290+
output_files = self.output_files
291+
code_log = self.code_log
292+
self.output_files = []
293+
self.code_log = []
294+
295+
return CodeInterpreterResponse(
296+
content=final_response, files=output_files, code_log=code_log
297+
)
298+
299+
async def _aoutput_handler(self, final_response: str) -> CodeInterpreterResponse:
300+
"""Embed images in the response"""
301+
for file in self.output_files:
302+
if str(file.name) in final_response:
303+
# rm ![Any](file.name) from the response
304+
final_response = re.sub(r"\n\n!\[.*\]\(.*\)", "", final_response)
305+
306+
if self.output_files and re.search(r"\n\[.*\]\(.*\)", final_response):
307+
try:
308+
final_response = await aremove_download_link(final_response, self.llm)
309+
except Exception as e:
310+
if self.verbose:
311+
print("Error while removing download links:", e)
312+
313+
output_files = self.output_files
314+
code_log = self.code_log
315+
self.output_files = []
316+
self.code_log = []
317+
318+
return CodeInterpreterResponse(
319+
content=final_response, files=output_files, code_log=code_log
320+
)
321+
322+
def generate_response_sync(
323+
self,
324+
user_msg: str,
325+
files: list[File] = [],
326+
detailed_error: bool = False,
327+
) -> CodeInterpreterResponse:
328+
"""Generate a Code Interpreter response based on the user's input."""
329+
user_request = UserRequest(content=user_msg, files=files)
330+
try:
331+
self._input_handler(user_request)
332+
response = self.agent_executor.run(input=user_request.content)
333+
return self._output_handler(response)
334+
except Exception as e:
335+
if self.verbose:
336+
traceback.print_exc()
337+
if detailed_error:
338+
return CodeInterpreterResponse(
339+
content="Error in CodeInterpreterSession: "
340+
f"{e.__class__.__name__} - {e}"
341+
)
342+
else:
343+
return CodeInterpreterResponse(
344+
content="Sorry, something went while generating your response."
345+
"Please try again or restart the session."
346+
)
225347

226348
async def generate_response(
227349
self,
228350
user_msg: str,
229351
files: list[File] = [],
230352
detailed_error: bool = False,
353+
) -> CodeInterpreterResponse:
354+
print(
355+
"DEPRECATION WARNING: Use agenerate_response for async generation.\n"
356+
"This function will be converted to sync in the future.\n"
357+
"You can use generate_response_sync for now.",
358+
)
359+
return await self.agenerate_response(
360+
user_msg=user_msg,
361+
files=files,
362+
detailed_error=detailed_error,
363+
)
364+
365+
async def agenerate_response(
366+
self,
367+
user_msg: str,
368+
files: list[File] = [],
369+
detailed_error: bool = False,
231370
) -> CodeInterpreterResponse:
232371
"""Generate a Code Interpreter response based on the user's input."""
233372
user_request = UserRequest(content=user_msg, files=files)
234373
try:
235-
await self._input_handler(user_request)
374+
await self._ainput_handler(user_request)
236375
response = await self.agent_executor.arun(input=user_request.content)
237-
return await self._output_handler(response)
376+
return await self._aoutput_handler(response)
238377
except Exception as e:
239378
if self.verbose:
240379
traceback.print_exc()
@@ -249,11 +388,24 @@ async def generate_response(
249388
"Please try again or restart the session."
250389
)
251390

252-
async def is_running(self) -> bool:
391+
def is_running(self) -> bool:
392+
return self.codebox.status() == "running"
393+
394+
async def ais_running(self) -> bool:
253395
return await self.codebox.astatus() == "running"
254396

255-
async def astop(self) -> None:
256-
await self.codebox.astop()
397+
def stop(self) -> SessionStatus:
398+
return SessionStatus.from_codebox_status(self.codebox.stop())
399+
400+
async def astop(self) -> SessionStatus:
401+
return SessionStatus.from_codebox_status(await self.codebox.astop())
402+
403+
def __enter__(self) -> "CodeInterpreterSession":
404+
self.start()
405+
return self
406+
407+
def __exit__(self, exc_type, exc_value, traceback) -> None:
408+
self.stop()
257409

258410
async def __aenter__(self) -> "CodeInterpreterSession":
259411
await self.astart()

0 commit comments

Comments
 (0)