22
22
from langchain .tools import BaseTool , StructuredTool
23
23
24
24
from codeinterpreterapi .agents import OpenAIFunctionsAgent
25
- from codeinterpreterapi .chains import get_file_modifications , remove_download_link
25
+ from codeinterpreterapi .chains import (
26
+ aget_file_modifications ,
27
+ aremove_download_link ,
28
+ get_file_modifications ,
29
+ remove_download_link ,
30
+ )
26
31
from codeinterpreterapi .config import settings
32
+ from codeinterpreterapi .parser import CodeAgentOutputParser , CodeChatAgentOutputParser
27
33
from codeinterpreterapi .prompts import code_interpreter_system_message
28
34
from codeinterpreterapi .schema import (
29
35
CodeInput ,
30
36
CodeInterpreterResponse ,
31
37
File ,
38
+ SessionStatus ,
32
39
UserRequest ,
33
40
)
34
- from codeinterpreterapi .utils import (
35
- CodeAgentOutputParser ,
36
- CodeCallbackHandler ,
37
- CodeChatAgentOutputParser ,
38
- )
39
41
40
42
41
43
class CodeInterpreterSession :
@@ -45,19 +47,20 @@ def __init__(
45
47
additional_tools : list [BaseTool ] = [],
46
48
** kwargs ,
47
49
) -> None :
48
- self .codebox = CodeBox ()
50
+ self .codebox = CodeBox (** kwargs )
49
51
self .verbose = kwargs .get ("verbose" , settings .VERBOSE )
50
52
self .tools : list [BaseTool ] = self ._tools (additional_tools )
51
53
self .llm : BaseLanguageModel = llm or self ._choose_llm (** kwargs )
52
54
self .agent_executor : AgentExecutor = self ._agent_executor ()
53
55
self .input_files : list [File ] = []
54
56
self .output_files : list [File ] = []
57
+ self .code_log : list [tuple [str , str ]] = []
55
58
56
- def start (self ) -> None :
57
- self .codebox .start ()
59
+ def start (self ) -> SessionStatus :
60
+ return SessionStatus . from_codebox_status ( self .codebox .start () )
58
61
59
- async def astart (self ) -> None :
60
- await self .codebox .astart ()
62
+ async def astart (self ) -> SessionStatus :
63
+ return SessionStatus . from_codebox_status ( await self .codebox .astart () )
61
64
62
65
def _tools (self , additional_tools : list [BaseTool ]) -> list [BaseTool ]:
63
66
return additional_tools + [
@@ -128,7 +131,6 @@ def _choose_agent(self) -> BaseSingleActionAgent:
128
131
def _agent_executor (self ) -> AgentExecutor :
129
132
return AgentExecutor .from_agent_and_tools (
130
133
agent = self ._choose_agent (),
131
- callbacks = [CodeCallbackHandler (self )],
132
134
max_iterations = 9 ,
133
135
tools = self .tools ,
134
136
verbose = self .verbose ,
@@ -137,18 +139,67 @@ def _agent_executor(self) -> AgentExecutor:
137
139
),
138
140
)
139
141
140
- async def show_code (self , code : str ) -> None :
142
+ def show_code (self , code : str ) -> None :
143
+ if self .verbose :
144
+ print (code )
145
+
146
+ async def ashow_code (self , code : str ) -> None :
141
147
"""Callback function to show code to the user."""
142
148
if self .verbose :
143
149
print (code )
144
150
145
151
def _run_handler (self , code : str ):
146
- raise NotImplementedError ("Use arun_handler for now." )
152
+ """Run code in container and send the output to the user"""
153
+ self .show_code (code )
154
+ output : CodeBoxOutput = self .codebox .run (code )
155
+ self .code_log .append ((code , output .content ))
156
+
157
+ if not isinstance (output .content , str ):
158
+ raise TypeError ("Expected output.content to be a string." )
159
+
160
+ if output .type == "image/png" :
161
+ filename = f"image-{ uuid .uuid4 ()} .png"
162
+ file_buffer = BytesIO (base64 .b64decode (output .content ))
163
+ file_buffer .name = filename
164
+ self .output_files .append (File (name = filename , content = file_buffer .read ()))
165
+ return f"Image { filename } got send to the user."
166
+
167
+ elif output .type == "error" :
168
+ if "ModuleNotFoundError" in output .content :
169
+ if package := re .search (
170
+ r"ModuleNotFoundError: No module named '(.*)'" , output .content
171
+ ):
172
+ self .codebox .install (package .group (1 ))
173
+ return (
174
+ f"{ package .group (1 )} was missing but "
175
+ "got installed now. Please try again."
176
+ )
177
+ else :
178
+ # TODO: preanalyze error to optimize next code generation
179
+ pass
180
+ if self .verbose :
181
+ print ("Error:" , output .content )
182
+
183
+ elif modifications := get_file_modifications (code , self .llm ):
184
+ for filename in modifications :
185
+ if filename in [file .name for file in self .input_files ]:
186
+ continue
187
+ fileb = self .codebox .download (filename )
188
+ if not fileb .content :
189
+ continue
190
+ file_buffer = BytesIO (fileb .content )
191
+ file_buffer .name = filename
192
+ self .output_files .append (
193
+ File (name = filename , content = file_buffer .read ())
194
+ )
195
+
196
+ return output .content
147
197
148
198
async def _arun_handler (self , code : str ):
149
199
"""Run code in container and send the output to the user"""
150
- print ( "Running code in container..." , code )
200
+ await self . ashow_code ( code )
151
201
output : CodeBoxOutput = await self .codebox .arun (code )
202
+ self .code_log .append ((code , output .content ))
152
203
153
204
if not isinstance (output .content , str ):
154
205
raise TypeError ("Expected output.content to be a string." )
@@ -176,7 +227,7 @@ async def _arun_handler(self, code: str):
176
227
if self .verbose :
177
228
print ("Error:" , output .content )
178
229
179
- elif modifications := await get_file_modifications (code , self .llm ):
230
+ elif modifications := await aget_file_modifications (code , self .llm ):
180
231
for filename in modifications :
181
232
if filename in [file .name for file in self .input_files ]:
182
233
continue
@@ -191,7 +242,22 @@ async def _arun_handler(self, code: str):
191
242
192
243
return output .content
193
244
194
- async def _input_handler (self , request : UserRequest ):
245
+ def _input_handler (self , request : UserRequest ) -> None :
246
+ """Callback function to handle user input."""
247
+ if not request .files :
248
+ return
249
+ if not request .content :
250
+ request .content = (
251
+ "I uploaded, just text me back and confirm that you got the file(s)."
252
+ )
253
+ request .content += "\n **The user uploaded the following files: **\n "
254
+ for file in request .files :
255
+ self .input_files .append (file )
256
+ request .content += f"[Attachment: { file .name } ]\n "
257
+ self .codebox .upload (file .name , file .content )
258
+ request .content += "**File(s) are now available in the cwd. **\n "
259
+
260
+ async def _ainput_handler (self , request : UserRequest ):
195
261
# TODO: variables as context to the agent
196
262
# TODO: current files as context to the agent
197
263
if not request .files :
@@ -207,7 +273,7 @@ async def _input_handler(self, request: UserRequest):
207
273
await self .codebox .aupload (file .name , file .content )
208
274
request .content += "**File(s) are now available in the cwd. **\n "
209
275
210
- async def _output_handler (self , final_response : str ) -> CodeInterpreterResponse :
276
+ def _output_handler (self , final_response : str ) -> CodeInterpreterResponse :
211
277
"""Embed images in the response"""
212
278
for file in self .output_files :
213
279
if str (file .name ) in final_response :
@@ -216,25 +282,98 @@ async def _output_handler(self, final_response: str) -> CodeInterpreterResponse:
216
282
217
283
if self .output_files and re .search (r"\n\[.*\]\(.*\)" , final_response ):
218
284
try :
219
- final_response = await remove_download_link (final_response , self .llm )
285
+ final_response = remove_download_link (final_response , self .llm )
220
286
except Exception as e :
221
287
if self .verbose :
222
288
print ("Error while removing download links:" , e )
223
289
224
- return CodeInterpreterResponse (content = final_response , files = self .output_files )
290
+ output_files = self .output_files
291
+ code_log = self .code_log
292
+ self .output_files = []
293
+ self .code_log = []
294
+
295
+ return CodeInterpreterResponse (
296
+ content = final_response , files = output_files , code_log = code_log
297
+ )
298
+
299
+ async def _aoutput_handler (self , final_response : str ) -> CodeInterpreterResponse :
300
+ """Embed images in the response"""
301
+ for file in self .output_files :
302
+ if str (file .name ) in final_response :
303
+ # rm  from the response
304
+ final_response = re .sub (r"\n\n!\[.*\]\(.*\)" , "" , final_response )
305
+
306
+ if self .output_files and re .search (r"\n\[.*\]\(.*\)" , final_response ):
307
+ try :
308
+ final_response = await aremove_download_link (final_response , self .llm )
309
+ except Exception as e :
310
+ if self .verbose :
311
+ print ("Error while removing download links:" , e )
312
+
313
+ output_files = self .output_files
314
+ code_log = self .code_log
315
+ self .output_files = []
316
+ self .code_log = []
317
+
318
+ return CodeInterpreterResponse (
319
+ content = final_response , files = output_files , code_log = code_log
320
+ )
321
+
322
+ def generate_response_sync (
323
+ self ,
324
+ user_msg : str ,
325
+ files : list [File ] = [],
326
+ detailed_error : bool = False ,
327
+ ) -> CodeInterpreterResponse :
328
+ """Generate a Code Interpreter response based on the user's input."""
329
+ user_request = UserRequest (content = user_msg , files = files )
330
+ try :
331
+ self ._input_handler (user_request )
332
+ response = self .agent_executor .run (input = user_request .content )
333
+ return self ._output_handler (response )
334
+ except Exception as e :
335
+ if self .verbose :
336
+ traceback .print_exc ()
337
+ if detailed_error :
338
+ return CodeInterpreterResponse (
339
+ content = "Error in CodeInterpreterSession: "
340
+ f"{ e .__class__ .__name__ } - { e } "
341
+ )
342
+ else :
343
+ return CodeInterpreterResponse (
344
+ content = "Sorry, something went while generating your response."
345
+ "Please try again or restart the session."
346
+ )
225
347
226
348
async def generate_response (
227
349
self ,
228
350
user_msg : str ,
229
351
files : list [File ] = [],
230
352
detailed_error : bool = False ,
353
+ ) -> CodeInterpreterResponse :
354
+ print (
355
+ "DEPRECATION WARNING: Use agenerate_response for async generation.\n "
356
+ "This function will be converted to sync in the future.\n "
357
+ "You can use generate_response_sync for now." ,
358
+ )
359
+ return await self .agenerate_response (
360
+ user_msg = user_msg ,
361
+ files = files ,
362
+ detailed_error = detailed_error ,
363
+ )
364
+
365
+ async def agenerate_response (
366
+ self ,
367
+ user_msg : str ,
368
+ files : list [File ] = [],
369
+ detailed_error : bool = False ,
231
370
) -> CodeInterpreterResponse :
232
371
"""Generate a Code Interpreter response based on the user's input."""
233
372
user_request = UserRequest (content = user_msg , files = files )
234
373
try :
235
- await self ._input_handler (user_request )
374
+ await self ._ainput_handler (user_request )
236
375
response = await self .agent_executor .arun (input = user_request .content )
237
- return await self ._output_handler (response )
376
+ return await self ._aoutput_handler (response )
238
377
except Exception as e :
239
378
if self .verbose :
240
379
traceback .print_exc ()
@@ -249,11 +388,24 @@ async def generate_response(
249
388
"Please try again or restart the session."
250
389
)
251
390
252
- async def is_running (self ) -> bool :
391
+ def is_running (self ) -> bool :
392
+ return self .codebox .status () == "running"
393
+
394
+ async def ais_running (self ) -> bool :
253
395
return await self .codebox .astatus () == "running"
254
396
255
- async def astop (self ) -> None :
256
- await self .codebox .astop ()
397
+ def stop (self ) -> SessionStatus :
398
+ return SessionStatus .from_codebox_status (self .codebox .stop ())
399
+
400
+ async def astop (self ) -> SessionStatus :
401
+ return SessionStatus .from_codebox_status (await self .codebox .astop ())
402
+
403
+ def __enter__ (self ) -> "CodeInterpreterSession" :
404
+ self .start ()
405
+ return self
406
+
407
+ def __exit__ (self , exc_type , exc_value , traceback ) -> None :
408
+ self .stop ()
257
409
258
410
async def __aenter__ (self ) -> "CodeInterpreterSession" :
259
411
await self .astart ()
0 commit comments