1
- from dataclasses import dataclass
1
+ from dataclasses import dataclass , field
2
2
3
3
import datasets
4
4
import duckdb
9
9
10
10
@dataclass
11
11
class AnalystAgentDeps :
12
- output : dict [str , pd .DataFrame ]
12
+ output : dict [str , pd .DataFrame ] = field ( default_factory = dict )
13
13
14
14
def store (self , value : pd .DataFrame ) -> str :
15
15
"""Store the output in deps and return the reference such as Out[1] to be used by the LLM."""
@@ -84,7 +84,7 @@ def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str
84
84
dataset: reference string to the DataFrame
85
85
sql: the query to be executed using DuckDB
86
86
"""
87
- data = ctx .deps .output [ dataset ]
87
+ data = ctx .deps .get ( dataset )
88
88
result = duckdb .query_df (df = data , virtual_table_name = 'dataset' , sql_query = sql )
89
89
# pass the result as ref (because DuckDB SQL can select many rows, creating another huge dataframe)
90
90
ref = ctx .deps .store (result .df ()) # pyright: ignore[reportUnknownMemberType]
@@ -93,13 +93,13 @@ def run_duckdb(ctx: RunContext[AnalystAgentDeps], dataset: str, sql: str) -> str
93
93
94
94
@analyst_agent .tool
95
95
def display (ctx : RunContext [AnalystAgentDeps ], name : str ) -> str :
96
- """Display at most 5 rows of the dataframe ."""
96
+ """Display at most 5 rows of the dataframe."""
97
97
dataset = ctx .deps .get (name )
98
98
return dataset .head ().to_string () # pyright: ignore[reportUnknownMemberType]
99
99
100
100
101
101
if __name__ == '__main__' :
102
- deps = AnalystAgentDeps (output = {} )
102
+ deps = AnalystAgentDeps ()
103
103
result = analyst_agent .run_sync (
104
104
user_prompt = 'Count how many negative comments are there in the dataset `cornell-movie-review-data/rotten_tomatoes`' ,
105
105
deps = deps ,
0 commit comments