prompt_evaluation.py
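"""Evaluate candidate prompts: for each prompt, generate code for a set of test tasks,
scan the generated code with Bandit, and aggregate the findings into a single score."""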
from dataclasses import dataclass
from typing import List, Optional, Tuple
import logging
from pathlib import Path
from SAST_integration.bandit_scan import BanditScan
from prompt_scoring.bandit_score import PromptScoring
from query_preparation.preparation import CodingTaskTemplate
from code_generation.gemini_generated import CodeGenerator
from config import config
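# Expected attributes on `config` (project-specific; descriptions inferred from usage below):
#   config.test_set_file           - file containing one coding task per line
#   config.test_prompts_file       - file containing one prompt per line
#   config.evaluation_results_file - file to which evaluation results are appended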
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

@dataclass
class EvaluationResult:
    prompt_id: int
    prompt: str
    score: int
    code_error: bool
    error_count: int

class PromptEvaluator:
    """Generates code for each task template, scans it with Bandit, and scores the prompt."""

    def __init__(self):
        self.bandit_scan = BanditScan()        # runs Bandit and post-processes its output
        self.scoring = PromptScoring()         # converts Bandit findings into a numeric score
        self.code_generator = CodeGenerator()  # generates code for a query and writes it to disk
        self.template = CodingTaskTemplate()   # expands a prompt plus test tasks into task queries
    def evaluate_single_task(self, prompt_id: int, task_template: str,
                             template_number: int) -> Tuple[Optional[bool], Optional[int]]:
        """Evaluate a single task template and return (code_error, score)."""
        query_id = f"manual_{prompt_id}_{template_number}"

        # Generate code
        code = self.code_generator.generate_code(task_template, query_id)
        if not code:
            logger.warning(f"Code generation failed for {query_id}")
            return None, None

        # Write code to file
        code_file_path = self.code_generator.write_code_to_file(
            query_id, task_template, code)
        if not code_file_path:
            logger.warning("Invalid code file path")
            return None, None

        # Perform security scan
        scan_output = self.bandit_scan.run_sast(
            filepath=code_file_path, query_id=query_id)
        if not scan_output:
            logger.warning(f"Invalid scan output for file {code_file_path}")
            return None, None

        # Code that Bandit could not parse gets a fixed penalty score of 10
        if len(scan_output["errors"]) != 0:
            return True, 10

        # Calculate score from the processed Bandit findings
        processed_output = self.bandit_scan.process_scan_output(
            query_id=prompt_id,
            prompt=task_template,
            bandit_output=scan_output
        )
        return False, self.scoring.bandit_score(prompt_id, processed_output)

    def calculate_score(self, prompt_id: int, prompt: str, test_set: List[str]) -> EvaluationResult:
        """Calculate the aggregate score for a prompt based on tasks in the test set."""
        prompt_score = 0
        code_error = True
        code_error_count = 0

        # Generate task templates
        test_task_queries = self.template.pre_template(prompt, test_set)

        # Evaluate each template
        for query_number, task_query in enumerate(test_task_queries, 1):
            code_error_status, score = self.evaluate_single_task(
                prompt_id, task_query, query_number)
            if score is not None:
                prompt_score += score
                if code_error_status:
                    code_error_count += 1
                else:
                    code_error = False
            else:
                logger.warning(
                    f"Prompt score is invalid for prompt: {prompt_id}")
        # Apply penalties
        if code_error_count > len(test_set) // 2:  # More than half of the tasks contain errors
            prompt_score += 100
        if code_error:  # All tasks had syntactic errors
            prompt_score += 200

        return EvaluationResult(
            prompt_id=prompt_id,
            prompt=prompt,
            score=prompt_score,
            code_error=code_error,
            error_count=code_error_count
        )

def main():
    """Main execution function."""
    evaluator = PromptEvaluator()

    # Read test tasks
    try:
        with open(config.test_set_file, "r") as f:
            test_tasks = f.readlines()
    except FileNotFoundError:
        logger.error(f"Test set file not found: {config.test_set_file}")
        return

    # Read prompts to evaluate
    try:
        with open(config.test_prompts_file, "r") as f:
            prompts_to_evaluate = f.readlines()
    except FileNotFoundError:
        logger.error(f"Test prompts file not found: {config.test_prompts_file}")
        return
    # Ensure output directory exists
    Path(config.evaluation_results_file).parent.mkdir(
        parents=True, exist_ok=True)
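    # Each evaluated prompt appends one line to the results file, e.g.:
    #   Prompt: <prompt text>, Score: 37   (score value illustrative)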
    # Evaluate prompts and write results
    with open(config.evaluation_results_file, "a+") as f:
        for idx, prompt in enumerate(prompts_to_evaluate, 1):
            result = evaluator.calculate_score(idx, prompt.strip(), test_tasks)
            f.write(f"Prompt: {result.prompt}, Score: {result.score}\n")
            logger.info(f"Evaluated prompt {idx}/{len(prompts_to_evaluate)}")

if __name__ == "__main__":
    main()