diff --git a/core/src/main/kotlin/org/evomaster/core/Main.kt b/core/src/main/kotlin/org/evomaster/core/Main.kt index 5a7e9c545a..12fccc34e8 100644 --- a/core/src/main/kotlin/org/evomaster/core/Main.kt +++ b/core/src/main/kotlin/org/evomaster/core/Main.kt @@ -261,7 +261,7 @@ class Main { resetExternalServiceHandler(injector) - resetHTTPCallbackVerifier(injector) + stopHTTPCallbackVerifier(injector) val statistics = injector.getInstance(Statistics::class.java) val data = statistics.getData(solution) @@ -1019,9 +1019,9 @@ class Main { externalServiceHandler.reset() } - private fun resetHTTPCallbackVerifier(injector: Injector) { + private fun stopHTTPCallbackVerifier(injector: Injector) { val httpCallbackVerifier = injector.getInstance(HttpCallbackVerifier::class.java) - httpCallbackVerifier.reset() + httpCallbackVerifier.stop() } } } diff --git a/core/src/main/kotlin/org/evomaster/core/problem/rest/service/fitness/AbstractRestFitness.kt b/core/src/main/kotlin/org/evomaster/core/problem/rest/service/fitness/AbstractRestFitness.kt index e9d7ffa46c..2a73d50138 100644 --- a/core/src/main/kotlin/org/evomaster/core/problem/rest/service/fitness/AbstractRestFitness.kt +++ b/core/src/main/kotlin/org/evomaster/core/problem/rest/service/fitness/AbstractRestFitness.kt @@ -733,10 +733,12 @@ abstract class AbstractRestFitness : HttpWsFitness() { } // FIXME: Code never reach this when we recompute the fitness under SSRFAnalyser - // So the faults never get marked. - // ResourceRestFitness get invoked during the recompute + // When the execution reach this during recomputing fitness, [HttpCallbackVerifier] + // WireMock seems to be [null]. + // Due to that the method will never return true if any calls made. if (config.security && config.ssrf) { if (ssrfAnalyser.anyCallsMadeToHTTPVerifier(a)) { + // Code reach this point during the search, which is unnecessary during search rcr.setVulnerableForSSRF(true) } } @@ -1205,8 +1207,8 @@ abstract class AbstractRestFitness : HttpWsFitness() { idMapper.getFaultDescriptiveId(DefinedFaultCategory.SSRF, it.getName()) ) fv.updateTarget(scenarioId, 1.0, it.positionAmongMainActions()) - val paramName = ssrfAnalyser.getVulnerableParameterName(it) + val paramName = ssrfAnalyser.getVulnerableParameterName(it) ar.addFault(DetectedFault(DefinedFaultCategory.SSRF, it.getName(), paramName)) } } diff --git a/core/src/main/kotlin/org/evomaster/core/problem/security/service/HttpCallbackVerifier.kt b/core/src/main/kotlin/org/evomaster/core/problem/security/service/HttpCallbackVerifier.kt index 789dff2378..f3bac25b30 100644 --- a/core/src/main/kotlin/org/evomaster/core/problem/security/service/HttpCallbackVerifier.kt +++ b/core/src/main/kotlin/org/evomaster/core/problem/security/service/HttpCallbackVerifier.kt @@ -42,18 +42,24 @@ class HttpCallbackVerifier { @PreDestroy fun destroy() { - resetHTTPVerifier() + stop() } - fun initWireMockServer() { + fun prepare() { + if (isActive) { + return + } + try { val config = WireMockConfiguration() .extensions(ResponseTemplateTransformer(false)) .port(config.httpCallbackVerifierPort) - wireMockServer = WireMockServer(config) - wireMockServer!!.start() - wireMockServer!!.stubFor(getDefaultStub()) + val wm = WireMockServer(config) + wm.start() + wm.stubFor(getDefaultStub()) + + wireMockServer = wm } catch (e: Exception) { throw RuntimeException( e.message + @@ -63,10 +69,9 @@ class HttpCallbackVerifier { } fun isCallbackURL(value: String): Boolean { - // Regex pattern looks for URL contains [HTTP_CALLBACK_VERIFIER] address and [HTTPCallbackVerifier] - // port, along with the path /sink/ and UUID as token generated to make the callback URL unique. + // Regex pattern looks for URL contains the pattern generated by the [HTTPCallbackVerifier]. val pattern = - """^http:\/\/localhost:${config.httpCallbackVerifierPort}\/sink\/.{36}""".toRegex() + """^http:\/\/localhost:${config.httpCallbackVerifierPort}\/EM_SSRF_\d+$""".toRegex() return pattern.matches(value) } @@ -75,7 +80,9 @@ class HttpCallbackVerifier { * Method generates a unique callback link to be used as payload for SSRF. */ fun generateCallbackLink(name: String): String { - val ssrfPath = "/sink/${counter++}" + // FIXME: sink/EM_0 <- slash get replaced with a comma at some point, which fails + // the verification based on the metadata + val ssrfPath = "/EM_SSRF_${counter++}" wireMockServer!!.stubFor( WireMock.any(WireMock.urlEqualTo(ssrfPath)) @@ -86,7 +93,6 @@ class HttpCallbackVerifier { .withStatus(200) .withBody("OK") ) - ) val link = "http://localhost:${wireMockServer!!.port()}$ssrfPath" @@ -98,12 +104,12 @@ class HttpCallbackVerifier { /** * @param name represents the Action name - * * During stub creation, stubs are tagged with Action name in the metadata. */ fun verify(name: String): Boolean { if (isActive) { - wireMockServer!!.allServeEvents + wireMockServer!! + .allServeEvents .filter { event -> event.wasMatched } .forEach { e -> val matched = e.stubMapping.metadata @@ -116,14 +122,14 @@ class HttpCallbackVerifier { return false } - fun resetHTTPVerifier() { + fun reset() { wireMockServer?.resetAll() wireMockServer?.stubFor(getDefaultStub()) actionCallbackLinkMapping.clear() counter = 0 } - fun reset() { + fun stop() { counter = 0 wireMockServer?.stop() wireMockServer = null diff --git a/core/src/main/kotlin/org/evomaster/core/problem/security/service/SSRFAnalyser.kt b/core/src/main/kotlin/org/evomaster/core/problem/security/service/SSRFAnalyser.kt index 36b4702c8e..31b203ce81 100644 --- a/core/src/main/kotlin/org/evomaster/core/problem/security/service/SSRFAnalyser.kt +++ b/core/src/main/kotlin/org/evomaster/core/problem/security/service/SSRFAnalyser.kt @@ -15,11 +15,7 @@ import org.evomaster.core.problem.security.SSRFUtil import org.evomaster.core.search.EvaluatedIndividual import org.evomaster.core.search.Solution import org.evomaster.core.search.gene.Gene -import org.evomaster.core.search.gene.ObjectGene -import org.evomaster.core.search.gene.wrapper.ChoiceGene -import org.evomaster.core.search.gene.wrapper.CustomMutationRateGene -import org.evomaster.core.search.gene.wrapper.OptionalGene -import org.evomaster.core.search.gene.string.StringGene +import org.evomaster.core.search.gene.utils.GeneUtils import org.evomaster.core.search.service.Archive import org.evomaster.core.search.service.FitnessFunction import org.slf4j.Logger @@ -63,6 +59,10 @@ class SSRFAnalyser { */ private lateinit var individualsInSolution: List> + private val urlRegexPattern: Regex = Regex("/url|source|remote|target/ig") + + private val potentialUrlParamNames: List = listOf("url", "source", "target", "datasource") + companion object { private val log: Logger = LoggerFactory.getLogger(SSRFAnalyser::class.java) } @@ -70,6 +70,7 @@ class SSRFAnalyser { @PostConstruct fun init() { log.debug("Initializing {}", SSRFAnalyser::class.simpleName) + loadURLParamNamesFromFile() } @PreDestroy @@ -79,9 +80,8 @@ class SSRFAnalyser { } } - fun apply(): Solution { - LoggingUtil.Companion.getInfoLogger().info("Applying {}", SSRFAnalyser::class.simpleName) + LoggingUtil.getInfoLogger().info("Applying {}", SSRFAnalyser::class.simpleName) // extract individuals from the archive val individuals = this.archive.extractSolution().individuals @@ -89,7 +89,7 @@ class SSRFAnalyser { individualsInSolution = RestIndividualSelectorUtils.findIndividuals( individuals, - statusCodes = listOf(200, 201) + statusCodes = listOf(200, 201, 204) ) if (individualsInSolution.isEmpty()) { @@ -100,25 +100,23 @@ class SSRFAnalyser { // The below steps are generic, for future extensions can be // accommodated easily under these common steps. + if (httpCallbackVerifier.isActive) { + // Reset before execution + httpCallbackVerifier.reset() + } else { + httpCallbackVerifier.prepare() + } + // Classify endpoints with potential vulnerability classes classify() - if (actionVulnerabilityMapping.isNotEmpty()) { - if (httpCallbackVerifier.isActive) { - // Reset before execution - httpCallbackVerifier.resetHTTPVerifier() - } else { - httpCallbackVerifier.initWireMockServer() - } - } - // execute analyse() // TODO: This is for development, remove it later val individualsAfterExecution = RestIndividualSelectorUtils.findIndividuals( this.archive.extractSolution().individuals, - statusCodes = listOf(200, 201) + statusCodes = listOf(200, 201, 204) ) log.debug("Total individuals after vulnerability analysis: {}", individualsAfterExecution.size) @@ -135,14 +133,25 @@ class SSRFAnalyser { should check the content of rcr result */ - val hasCallbackURL = action.parameters.any { param -> - val genes = getStringGenesFromParam(param.seeGenes()) - genes.any { gene -> + val hasCallBackURL = GeneUtils + .getAllStringFields(action.parameters) + .any { gene -> httpCallbackVerifier.isCallbackURL(gene.getValueAsRawString()) } + + if (hasCallBackURL) { + // FIXME: When the code reaches this point during SSRF phase + // WireMock is null, due to that this will return false. + // Which will not add the fault category. + // However, I can see the WireMock get initiated even before + // reaching this point. + // I suspected something to do with the dependency injection. + // I tried moving the WireMock inside this class, still the same. + val x = httpCallbackVerifier.verify(action.getName()) + return x } - return hasCallbackURL && httpCallbackVerifier.verify(action.getName()) + return false } /** @@ -159,13 +168,36 @@ class SSRFAnalyser { // Are we going mark potential vulnerability classes as one time // job or going to evaluate each time (which is costly). - when (config.vulnerableInputClassificationStrategy) { - EMConfig.VulnerableInputClassificationStrategy.MANUAL -> { - manualClassifier() - } + individualsInSolution.forEach { evaluatedIndividual -> + evaluatedIndividual.evaluatedMainActions().forEach { a -> + val action = a.action + if (action is RestCallAction) { + val actionFaultMapping = ActionFaultMapping(action.getName()) + val inputFaultMapping: MutableMap = + extractBodyParameters(action.parameters) + + inputFaultMapping.forEach { (paramName, paramMapping) -> + val answer = when (config.vulnerableInputClassificationStrategy) { + EMConfig.VulnerableInputClassificationStrategy.MANUAL -> { + manualClassifier(paramName, paramMapping.description) + } + + EMConfig.VulnerableInputClassificationStrategy.LLM -> { + llmClassifier(paramName, paramMapping.description) + } + } - EMConfig.VulnerableInputClassificationStrategy.LLM -> { - llmClassifier() + if (answer) { + paramMapping.addSecurityFaultCategory(DefinedFaultCategory.SSRF) + actionFaultMapping.addSecurityFaultCategory(DefinedFaultCategory.SSRF) + actionFaultMapping.isVulnerable = true + } + } + + // Assign the param mapping + actionFaultMapping.params = inputFaultMapping + actionVulnerabilityMapping[action.getName()] = actionFaultMapping + } } } } @@ -174,8 +206,11 @@ class SSRFAnalyser { if (actionVulnerabilityMapping.containsKey(action.getName())) { val mapping = actionVulnerabilityMapping[action.getName()] if (mapping != null) { - val param = mapping.params.filter { it.value.securityFaults.contains( - DefinedFaultCategory.SSRF) } + val param = mapping.params.filter { + it.value.securityFaults.contains( + DefinedFaultCategory.SSRF + ) + } return param.keys.first() } } @@ -184,59 +219,47 @@ class SSRFAnalyser { } /** - * TODO: Classify based on manual - * TODO: Need to rename the word manual to something meaningful later + * A private method to identify parameter is a potentially holds URL value + * using a Regex based approach. */ - private fun manualClassifier() { - // TODO: Can use the extracted CSV to map the parameter name - // to the vulnerability class. + private fun manualClassifier(name: String, description: String? = null): Boolean { + if (potentialUrlParamNames.contains(name.lowercase())) { + return true + } + + if (name.matches(urlRegexPattern)) { + return true + } + if (description != null) { + if (description.matches(urlRegexPattern)) { + return true + } + } + return false } /** - * Private method to classify parameters using a large language model. + * Private method to identify parameter is a potentially holds URL value, + * using a large language model. */ - private fun llmClassifier() { - // For now, we consider only the individuals selected from [Archive] - // TODO: This can be isolated to classify at the beginning of the search - individualsInSolution.forEach { evaluatedIndividual -> - evaluatedIndividual.evaluatedMainActions().forEach { a -> - val action = a.action - if (action is RestCallAction && !actionVulnerabilityMapping.containsKey(action.getName())) { - val actionFaultMapping = ActionFaultMapping(action.getName()) - val inputFaultMapping: MutableMap = - extractBodyParameters(action.parameters) - - inputFaultMapping.forEach { paramName, paramMapping -> - val answer = if (!paramMapping.description.isNullOrBlank()) { - languageModelConnector.query( - SSRFUtil.Companion.getPromptWithNameAndDescription( - paramMapping.name, - paramMapping.description - ) - ) - } else { - languageModelConnector.query( - SSRFUtil.Companion.getPromptWithNameOnly( - paramMapping.name - ) - ) - } - - if (answer != null && answer.answer == SSRFUtil.Companion.SSRF_PROMPT_ANSWER_FOR_POSSIBILITY) { - paramMapping.addSecurityFaultCategory(DefinedFaultCategory.SSRF) - actionFaultMapping.addSecurityFaultCategory(DefinedFaultCategory.SSRF) - actionFaultMapping.isVulnerable = true - } - } - - // Assign the param mapping - actionFaultMapping.params = inputFaultMapping - - actionVulnerabilityMapping[action.getName()] = actionFaultMapping - } - } + private fun llmClassifier(name: String, description: String? = null): Boolean { + val answer = if (!description.isNullOrBlank()) { + languageModelConnector.query( + SSRFUtil.getPromptWithNameAndDescription( + name, + description + ) + ) + } else { + languageModelConnector.query( + SSRFUtil.getPromptWithNameOnly( + name + ) + ) } + + return answer != null && answer.answer == SSRFUtil.SSRF_PROMPT_ANSWER_FOR_POSSIBILITY } /** @@ -247,15 +270,13 @@ class SSRFAnalyser { ): MutableMap { val output = mutableMapOf() - parameters.forEach { param -> - val genes = getStringGenesFromParam(param.seeGenes()) + val genes = GeneUtils.getAllStringFields(parameters) - genes.forEach { gene -> - output[gene.name] = InputFaultMapping( - gene.name, - gene.description, - ) - } + genes.forEach { gene -> + output[gene.name] = InputFaultMapping( + gene.name, + gene.description, + ) } return output @@ -344,37 +365,7 @@ class SSRFAnalyser { } } - private fun getStringGenesFromParam(genes: List): List { - val output = mutableListOf() - - genes.forEach { gene -> - when (gene) { - is StringGene -> { - output.add(gene) - } - - is OptionalGene -> { - output.addAll(getStringGenesFromParam(gene.getViewOfChildren())) - } - - is ObjectGene -> { - output.addAll(getStringGenesFromParam(gene.getViewOfChildren())) - } - - is ChoiceGene<*> -> { - output.addAll(getStringGenesFromParam(gene.getViewOfChildren())) - } - - is CustomMutationRateGene<*> -> { - output.addAll(getStringGenesFromParam(gene.getViewOfChildren())) - } - - else -> { - // Do nothing - } - } - } - - return output + private fun loadURLParamNamesFromFile() { + // TODO } } diff --git a/core/src/main/kotlin/org/evomaster/core/search/gene/network/InetGene.kt b/core/src/main/kotlin/org/evomaster/core/search/gene/network/InetGene.kt index 7fc6c76487..e771842e98 100644 --- a/core/src/main/kotlin/org/evomaster/core/search/gene/network/InetGene.kt +++ b/core/src/main/kotlin/org/evomaster/core/search/gene/network/InetGene.kt @@ -64,7 +64,6 @@ class InetGene( } } - @Deprecated("Do not call directly outside this package. Call setFromStringValue") /** * Set value from a string of [InetAddress]. diff --git a/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/base/SSRFBaseEMTest.kt b/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/base/SSRFBaseEMTest.kt index f378889d04..bc25be536b 100644 --- a/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/base/SSRFBaseEMTest.kt +++ b/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/base/SSRFBaseEMTest.kt @@ -1,9 +1,11 @@ package org.evomaster.e2etests.spring.openapi.v3.security.ssrf.base import com.foo.rest.examples.spring.openapi.v3.security.ssrf.base.SSRFBaseController +import org.evomaster.core.EMConfig import org.evomaster.core.problem.rest.data.HttpVerb import org.evomaster.e2etests.spring.openapi.v3.SpringTestBase import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test @@ -14,16 +16,18 @@ class SSRFBaseEMTest : SpringTestBase() { @BeforeAll @JvmStatic fun init() { - initClass(SSRFBaseController()) + val config = EMConfig() + config.instrumentMR_NET = false + initClass(SSRFBaseController(), config) } } - @Disabled("WIP") + @Disabled @Test fun testSSRFEM() { runTestHandlingFlakyAndCompilation( "SSRFBaseEMTest", - 300, + 200, ) { args: MutableList -> // If mocking enabled, it'll spin new services each time when there is a valid URL. @@ -33,12 +37,12 @@ class SSRFBaseEMTest : SpringTestBase() { setOption(args, "ssrf", "true") setOption(args, "vulnerableInputClassificationStrategy", "MANUAL") - setOption(args, "languageModelConnector", "true") + setOption(args, "languageModelConnector", "false") setOption(args, "schemaOracles", "false") val solution = initAndRun(args) - Assertions.assertTrue(solution.individuals.isNotEmpty()) + assertTrue(solution.individuals.isNotEmpty()) assertHasAtLeastOne(solution, HttpVerb.POST, 200, "/api/fetch/data", "OK") assertHasAtLeastOne(solution, HttpVerb.POST, 200, "/api/fetch/image", "OK") diff --git a/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/header/SSRFHeaderEMTest.kt b/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/header/SSRFHeaderEMTest.kt index 33df9b8c95..ee3f77b6f9 100644 --- a/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/header/SSRFHeaderEMTest.kt +++ b/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/header/SSRFHeaderEMTest.kt @@ -1,6 +1,7 @@ package org.evomaster.e2etests.spring.openapi.v3.security.ssrf.header import com.foo.rest.examples.spring.openapi.v3.security.ssrf.header.SSRFHeaderController +import org.evomaster.core.EMConfig import org.evomaster.core.problem.rest.data.HttpVerb import org.evomaster.e2etests.spring.openapi.v3.SpringTestBase import org.junit.jupiter.api.Assertions @@ -14,7 +15,9 @@ class SSRFHeaderEMTest: SpringTestBase() { @BeforeAll @JvmStatic fun init() { - initClass(SSRFHeaderController()) + val config = EMConfig() + config.instrumentMR_NET = false + initClass(SSRFHeaderController(), config) } } @@ -23,7 +26,7 @@ class SSRFHeaderEMTest: SpringTestBase() { fun testSSRFHeader() { runTestHandlingFlakyAndCompilation( "SSRFEMTest", - 300, + 100, ) { args: MutableList -> // If mocking enabled, it'll spin new services each time when there is a valid URL. diff --git a/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/query/SSRFQueryEMTest.kt b/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/query/SSRFQueryEMTest.kt index 26c3d84a91..88ba76403d 100644 --- a/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/query/SSRFQueryEMTest.kt +++ b/e2e-tests/spring-rest-openapi-v3/src/test/kotlin/org/evomaster/e2etests/spring/openapi/v3/security/ssrf/query/SSRFQueryEMTest.kt @@ -1,6 +1,7 @@ package org.evomaster.e2etests.spring.openapi.v3.security.ssrf.query import com.foo.rest.examples.spring.openapi.v3.security.ssrf.query.SSRFQueryController +import org.evomaster.core.EMConfig import org.evomaster.core.problem.rest.data.HttpVerb import org.evomaster.e2etests.spring.openapi.v3.SpringTestBase import org.junit.jupiter.api.Assertions @@ -14,7 +15,9 @@ class SSRFQueryEMTest: SpringTestBase() { @BeforeAll @JvmStatic fun init() { - initClass(SSRFQueryController()) + val config = EMConfig() + config.instrumentMR_NET = false + initClass(SSRFQueryController(), config) } } @@ -23,7 +26,7 @@ class SSRFQueryEMTest: SpringTestBase() { fun testSSRFQuery() { runTestHandlingFlakyAndCompilation( "SSRFQueryEMTest", - 300, + 80, ) { args: MutableList -> // If mocking enabled, it'll spin new services each time when there is a valid URL.