Skip to content

Commit 8ee7fb4

Browse files
authored
Merge pull request #1265 from WebFuzzing/ai-multi-endpoint-classifier
Ai multi endpoint classifier
2 parents e2d269c + 1e9efd1 commit 8ee7fb4

File tree

7 files changed

+299
-301
lines changed

7 files changed

+299
-301
lines changed

core-it/src/main/kotlin/bar/examples/it/spring/aiconstraint/numeric/AICMultiTypeApplication.kt

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import org.springframework.http.ResponseEntity
88
import org.springframework.web.bind.annotation.*
99

1010
@SpringBootApplication(exclude = [SecurityAutoConfiguration::class])
11-
@RequestMapping(path = ["/api"])
11+
@RequestMapping(path = ["/petShopApi"])
1212
@RestController
1313
open class AICMultiTypeApplication {
1414

@@ -32,7 +32,7 @@ open class AICMultiTypeApplication {
3232
FEMALE
3333
}
3434

35-
@GetMapping("/petShop")
35+
@GetMapping("/petInfo")
3636
open fun getString(
3737

3838
@RequestParam("category", required = true)
@@ -91,4 +91,30 @@ open class AICMultiTypeApplication {
9191
)
9292
}
9393

94+
@GetMapping("/ownerInfo")
95+
open fun getOwnerInfo(
96+
@RequestParam("id", required = true)
97+
@Parameter(required = true, description = "Owner's id")
98+
id: Int,
99+
100+
@RequestParam("age", required = true)
101+
@Parameter(required = true, description = "Owner's age")
102+
age: Int
103+
104+
): ResponseEntity<String> {
105+
106+
if (id <= 0) {
107+
return ResponseEntity.status(400).body("Owner id must be a positive number.")
108+
}
109+
if (age <= 0) {
110+
return ResponseEntity.status(400).body("Owner age must be a positive number.")
111+
}
112+
113+
// Response
114+
return ResponseEntity.status(200).body(
115+
"Owner Name: $id, Age: $age"
116+
)
117+
118+
}
119+
94120
}

core-it/src/test/kotlin/org/evomaster/core/problem/rest/aiconstraint/numeric/AICMultiTypeCheck.kt

Lines changed: 0 additions & 106 deletions
This file was deleted.

core-it/src/test/kotlin/org/evomaster/core/problem/rest/aiconstraint/numeric/AIGLMCheck.kt

Lines changed: 87 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@ import org.evomaster.core.problem.rest.builder.RestActionBuilderV3
1414
import org.evomaster.core.problem.rest.schema.RestSchema
1515
import org.evomaster.core.EMConfig
1616
import org.evomaster.core.problem.rest.classifier.GLMOnlineClassifier
17+
import org.evomaster.core.problem.rest.data.RestPath
1718
import org.evomaster.core.problem.rest.schema.OpenApiAccess
19+
import org.evomaster.core.problem.rest.service.sampler.AbstractRestSampler
1820
import org.evomaster.core.search.action.Action
1921
import org.evomaster.core.search.service.Randomness
22+
import java.net.HttpURLConnection
23+
import java.net.URL
24+
import javax.ws.rs.core.MediaType
2025

2126

2227
class AIGLMCheck : IntegrationTestRestBase() {
@@ -45,100 +50,126 @@ class AIGLMCheck : IntegrationTestRestBase() {
4550
)
4651
}
4752

48-
fun runClassifierExample() {
53+
fun executeRestCallAction(action: RestCallAction, baseUrlOfSut: String): RestCallResult {
54+
val fullUrl = "$baseUrlOfSut${action.resolvedPath()}"
55+
val url = URL(fullUrl)
56+
val connection = url.openConnection() as HttpURLConnection
57+
58+
connection.requestMethod = action.verb.name
59+
connection.connectTimeout = 5000
60+
connection.readTimeout = 5000
61+
62+
val result = RestCallResult(action.getLocalId())
63+
64+
try {
65+
val status = connection.responseCode
66+
result.setStatusCode(status)
67+
68+
val body = connection.inputStream.bufferedReader().use { it.readText() }
69+
result.setBody(body)
70+
result.setBodyType(MediaType.APPLICATION_JSON_TYPE) // or guess based on Content-Type header
71+
72+
} catch (e: Exception) {
73+
result.setTimedout(true)
74+
result.setBody("ERROR: ${e.message}")
75+
}
4976

50-
/**
51-
* Generate a random RestCallAction using EvoMaster Randomness
52-
*/
53-
// Fetch and parse OpenAPI schema based on the schema location
77+
return result
78+
}
79+
80+
fun runClassifierExample() {
5481
val schema = OpenApiAccess.getOpenAPIFromLocation("$baseUrlOfSut/v3/api-docs")
55-
// Wrap schema into RestSchema
5682
val restSchema = RestSchema(schema)
57-
// Configuration for gene generation
83+
5884
val config = EMConfig().apply {
85+
aiModelForResponseClassification = EMConfig.AIResponseClassifierModel.GLM
5986
enableSchemaConstraintHandling = true
6087
allowInvalidData = false
6188
probRestDefault = 0.0
6289
probRestExamples = 0.0
6390
}
91+
92+
6493
val options = RestActionBuilderV3.Options(config)
65-
// actionCluster contains provides possible actions
6694
val actionCluster = mutableMapOf<String, Action>()
67-
// Generate RestCallAction
6895
RestActionBuilderV3.addActionsFromSwagger(restSchema, actionCluster, options = options)
69-
// Sample one random RestCallAction
70-
val random = Randomness()
96+
7197
val actionList = actionCluster.values.filterIsInstance<RestCallAction>()
72-
val template = random.choose(actionList)
73-
val sampledAction = template.copy() as RestCallAction
74-
sampledAction.doInitialize(random)
75-
76-
// Calculate the input dimension of the classifier
77-
var dimension:Int = 0
78-
for (gene in sampledAction.seeTopGenes()) {
79-
when (gene) {
80-
is IntegerGene, is DoubleGene, is BooleanGene, is EnumGene<*> -> {
81-
dimension++
82-
}
98+
99+
val pathToDimension = mutableMapOf<RestPath, Int>()
100+
for (action in actionList) {
101+
val path = action.path
102+
if (pathToDimension.containsKey(path)) continue
103+
104+
val dimension = action.parameters.count { p ->
105+
val g = p.gene
106+
g is IntegerGene || g is DoubleGene || g is BooleanGene || g is EnumGene<*>
83107
}
108+
pathToDimension[path] = dimension
84109
}
85-
require(dimension == 6)
86110

87-
// Create a glm classifier
88-
val classifier = injector.getInstance(AIResponseClassifier::class.java)
89-
//classifier.initModel(dimension) //FIXME
111+
val pathToClassifier = mutableMapOf<RestPath, GLMOnlineClassifier>()
112+
for ((path, dimension) in pathToDimension) {
113+
val model = GLMOnlineClassifier()
114+
model.setDimension(dimension)
115+
pathToClassifier[path] = model
116+
}
117+
118+
println("Classifiers initialized with their dimensions:")
119+
for ((path, expected) in pathToDimension) {
120+
val classifier = pathToClassifier[path]!!
121+
println("$path -> expected: $expected, actualDim: ${classifier.getDimension()}")
122+
}
90123

91-
// Use reflection to access the private delegate
92-
val delegateField = classifier::class.java.getDeclaredField("delegate")
93-
delegateField.isAccessible = true
94-
val glm = delegateField.get(classifier) as? GLMOnlineClassifier
95124

96-
var time =1
125+
val random = Randomness()
126+
val sampler = injector.getInstance(AbstractRestSampler::class.java)
127+
var time = 1
97128
val timeLimit = 20
98129
while (time <= timeLimit) {
99130
val template = random.choose(actionList)
100131
val sampledAction = template.copy() as RestCallAction
101132
sampledAction.doInitialize(random)
102133

134+
val path = sampledAction.path
135+
val dimension = pathToDimension[path] ?: error("No dimension for path: $path")
136+
val classifier = pathToClassifier[path] ?: error("Expected classifier for path: $path")
103137
val geneValues = sampledAction.parameters.map { it.gene.getValueAsRawString() }
104-
println("**********************************************")
105-
println("Time: $time")
106-
println("Genes: [${geneValues.joinToString(", ")}]")
107138

108-
// createIndividual send the request and evaluate
109-
val individual = createIndividual(listOf(sampledAction), SampleType.RANDOM)
110-
val evaluatedAction = individual.evaluatedMainActions()[0]
111-
val action = evaluatedAction.action as RestCallAction
112-
val result = evaluatedAction.result as RestCallResult
139+
println("*************************************************")
140+
println("Time : $time")
141+
println("Path : $path")
142+
println("Input Genes : ${geneValues.joinToString(", ")}")
143+
println("Input dim : ${classifier.getDimension()}")
144+
println("Expected Dim : $dimension")
145+
println("Actual Genes : ${geneValues.size}")
113146

114-
// update the model
147+
// //executeRestCallAction is replaced with createIndividual to avoid override error
148+
// val individual = createIndividual(listOf(sampledAction), SampleType.RANDOM)
149+
val individual = sampler.createIndividual(SampleType.RANDOM, listOf(sampledAction).toMutableList())
150+
val action = individual.seeMainExecutableActions()[0]
151+
val result = executeRestCallAction(action,"$baseUrlOfSut")
152+
println("Response:\n${result.getStatusCode()}")
153+
154+
155+
// Update and classify
115156
classifier.updateModel(action, result)
157+
val classification = classifier.classify(action)
116158

117-
// classify an action
118-
val c = classifier.classify(action)
119-
// the classification provides two values as the probability of getting 400 and 200
120-
require(c.probabilities.values.all { it in 0.0..1.0 }) {
159+
println("Probabilities: ${classification.probabilities}")
160+
require(classification.probabilities.values.all { it in 0.0..1.0 }) {
121161
"All probabilities must be in [0,1]"
122162
}
123163

124-
if (glm != null) {
125-
val weightsAndBias = glm.getModelParams()
126-
println(
127-
"""
128-
Weights and Bias = $weightsAndBias
129-
""".trimIndent()
130-
)
164+
if (classifier != null) {
165+
val weightsAndBias = classifier.getModelParams()
166+
println("Weights and Bias = $weightsAndBias")
131167
println("**********************************************")
132168
println("**********************************************")
133169
} else {
134170
println("The classifier is not a GLMOnlineClassifier")
135171
}
136-
137172
time++
138173
}
139-
140-
141174
}
142-
143175
}
144-

0 commit comments

Comments
 (0)