Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package bar.examples.it.spring.aiclassification.basic

import org.springframework.boot.SpringApplication
import org.springframework.boot.autoconfigure.SpringBootApplication
import org.springframework.boot.autoconfigure.security.servlet.SecurityAutoConfiguration
import org.springframework.http.ResponseEntity
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController
import javax.ws.rs.QueryParam

@SpringBootApplication(exclude = [SecurityAutoConfiguration::class])
@RequestMapping(path = ["/api/basic"])
@RestController
open class BasicApplication {

companion object {
@JvmStatic
fun main(args: Array<String>) {
SpringApplication.run(BasicApplication::class.java, *args)
}
}

enum class Alphabet {
A,
B,
C,
D
}

@GetMapping
open fun getData(
@RequestParam("x", required = false) x: Alphabet?,
@RequestParam("y", required = false) y: Int?,
@RequestParam("z", required = false) z: Boolean?,
): ResponseEntity<String> {

// No dependency, just constraint on a single variable
if (y == null) {
return ResponseEntity.status(400).build()
}

return ResponseEntity.ok().body("OK")
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package bar.examples.it.spring.aiconstraint.numeric
package bar.examples.it.spring.aiclassification.multitype

import io.swagger.v3.oas.annotations.Parameter
import org.springframework.boot.SpringApplication
Expand All @@ -10,12 +10,12 @@ import org.springframework.web.bind.annotation.*
@SpringBootApplication(exclude = [SecurityAutoConfiguration::class])
@RequestMapping(path = ["/petShopApi"])
@RestController
open class AICMultiTypeApplication {
open class MultiTypeApplication {

companion object {
@JvmStatic
fun main(args: Array<String>) {
SpringApplication.run(AICMultiTypeApplication::class.java, *args)
SpringApplication.run(MultiTypeApplication::class.java, *args)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package bar.examples.it.spring.aiclassification.basic

import bar.examples.it.spring.SpringController

class BasicController : SpringController(BasicApplication::class.java)

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package bar.examples.it.spring.aiclassification.multitype

import bar.examples.it.spring.SpringController

class MultiTypeController : SpringController(MultiTypeApplication::class.java)

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.evomaster.core.problem.rest.aiconstraint.numeric
package org.evomaster.core.problem.rest.aiclassification

import bar.examples.it.spring.aiconstraint.numeric.AICMultiTypeController
import bar.examples.it.spring.aiclassification.basic.BasicController
import bar.examples.it.spring.aiclassification.multitype.MultiTypeController
import org.evomaster.core.problem.enterprise.SampleType
import org.evomaster.core.problem.rest.IntegrationTestRestBase
import org.evomaster.core.problem.rest.data.RestCallAction
Expand All @@ -13,13 +14,15 @@ import org.evomaster.core.problem.rest.builder.RestActionBuilderV3
import org.evomaster.core.problem.rest.schema.RestSchema
import org.evomaster.core.EMConfig
import org.evomaster.core.problem.rest.classifier.GLMOnlineClassifier
import org.evomaster.core.problem.rest.classifier.InputEncoderUtils
import org.evomaster.core.problem.rest.schema.OpenApiAccess
import org.evomaster.core.problem.rest.service.sampler.AbstractRestSampler
import org.evomaster.core.search.action.Action
import org.evomaster.core.search.service.Randomness
import java.net.HttpURLConnection
import java.net.URL
import javax.ws.rs.core.MediaType
import kotlin.collections.iterator
import kotlin.math.abs


Expand All @@ -28,7 +31,8 @@ class AIGLMCheck : IntegrationTestRestBase() {
companion object {
@JvmStatic
fun init() {
initClass(AICMultiTypeController())
initClass(BasicController())
// initClass(MultiTypeController())
}

@JvmStatic
Expand Down Expand Up @@ -66,7 +70,7 @@ class AIGLMCheck : IntegrationTestRestBase() {

val body = connection.inputStream.bufferedReader().use { it.readText() }
result.setBody(body)
result.setBodyType(MediaType.APPLICATION_JSON_TYPE) // or guess based on Content-Type header
result.setBodyType(MediaType.APPLICATION_JSON_TYPE) // or guess based on the Content-Type header

} catch (e: Exception) {
result.setTimedout(true)
Expand Down Expand Up @@ -95,32 +99,26 @@ class AIGLMCheck : IntegrationTestRestBase() {
val actionList = actionCluster.values.filterIsInstance<RestCallAction>()

val endpointToDimension = mutableMapOf<String, Int?>()
val endpointToCorrectPrediction = mutableMapOf<String, Int>()
val endpointToTotalExecution = mutableMapOf<String, Int>()
val endpointToAccuracy = mutableMapOf<String, Double>()
for (action in actionList) {
val name = action.getName()

val hasUnsupportedGene = action.parameters.any { p ->
val g = p.gene
val g = p.primaryGene().getLeafGene()
g !is IntegerGene && g !is DoubleGene && g !is BooleanGene && g !is EnumGene<*>
}

val dimension = if (hasUnsupportedGene) {
null
} else {
action.parameters.count { p ->
val g = p.gene
val g = p.primaryGene().getLeafGene()
g is IntegerGene || g is DoubleGene || g is BooleanGene || g is EnumGene<*>
}
}

println("Endpoint: $name, dimension: $dimension")
endpointToDimension[name] = dimension

endpointToCorrectPrediction[name] = 0
endpointToTotalExecution[name] = 0
endpointToAccuracy[name] = 0.0
}

/**
Expand All @@ -133,7 +131,7 @@ class AIGLMCheck : IntegrationTestRestBase() {
endpointToClassifier[name] = null
}else{
val model = GLMOnlineClassifier()
model.setDimension(dimension)
model.setup(dimension=dimension, warmup = 10)
endpointToClassifier[name] = model
}
}
Expand All @@ -155,18 +153,18 @@ class AIGLMCheck : IntegrationTestRestBase() {
val name = sampledAction.getName()
val classifier = endpointToClassifier[name]
val dimension = endpointToDimension[name]
val geneValues = sampledAction.parameters.map { it.gene.getValueAsRawString() }
var cp = endpointToCorrectPrediction[name]!!
var tot = endpointToTotalExecution[name]!!
var ac = endpointToAccuracy[name]!!
val geneValues = sampledAction.parameters.map { it.primaryGene().getValueAsRawString() }

println("*************************************************")
println("Path : $name")
println("Classifier : ${if (classifier == null) "null" else "GLM"}")
println("Dimension : $dimension")
println("Input Genes : ${geneValues.joinToString(", ")}")
println("Actual Genes : ${geneValues.size}")
println("cp, tot, ac : $cp, $tot, $ac")
println("Genes Size : ${geneValues.size}")
println("Correct Predictions: ${classifier?.performance?.correctPrediction}")
println("Total Requests : ${classifier?.performance?.totalSentRequests}")
println("Accuracy : ${classifier?.performance?.accuracy()}")


// executeRestCallAction is replaced with createIndividual to avoid override error
// val individual = createIndividual(listOf(sampledAction), SampleType.RANDOM)
Expand All @@ -180,51 +178,41 @@ class AIGLMCheck : IntegrationTestRestBase() {
continue
}
// Warmup cold classifiers by at least n request
val n = 2
val isCold = tot <= n
val isCold = classifier.performance.totalSentRequests<classifier.warmup
if (isCold) {
println("Warmup by at least $n request")
println("Warmup by at least ${classifier.warmup} request")
val result = executeRestCallAction(action, "$baseUrlOfSut")
classifier.updateModel(action, result)
tot += 1
endpointToTotalExecution[name] = tot
continue
}

// Classification
println("Classifying!")
val rawEncodedFeatures = InputEncoderUtils.encode(sampledAction).rawEncodedFeatures
println("Raw encoded features are : ${rawEncodedFeatures.joinToString(", ")}")
val classification = classifier.classify(action)
val p200 = classification.probabilities[200]!!
val p400 = classification.probabilities[400]!!
require(p200 in 0.0..1.0 && p400 in 0.0..1.0 && abs((p200 + p400) - 1.0) < 1e-6) {
"Probabilities must be in [0,1] and sum to 1"
}

// Prediction
val prediction: Int = if (p200 > p400) 200 else 400
val predictionOfStatusCode = classification.prediction()
println("Prediction is : $predictionOfStatusCode")

// Probabilistic decision-making based on Bernoulli(prob = aci)
val sendOrNot: Boolean
if (prediction == 200) {
if (predictionOfStatusCode != 400) {
sendOrNot = true
}else{
sendOrNot = if(Math.random() > ac) true else false
sendOrNot = if(Math.random() > classifier.performance.accuracy()) true else false
}

// Execute the request and update
if (sendOrNot) {
val result = executeRestCallAction(action,"$baseUrlOfSut")
println("Response : ${result.getStatusCode()}")

if (result.getStatusCode()==prediction) {
cp = cp + 1
}
tot = tot + 1
ac = cp.toDouble() / tot

endpointToCorrectPrediction[name] = cp
endpointToTotalExecution[name] = tot
endpointToAccuracy[name] = ac
println("Updating the classifier!")
classifier.updateModel(action, result)

println("Updating the classifier!")
classifier.updateModel(action, result)
}

Expand Down
Loading
Loading