From e03e5d1a8058832077f059abe6bf3cb6d7d35836 Mon Sep 17 00:00:00 2001 From: CahidArda Date: Mon, 5 May 2025 11:50:48 +0300 Subject: [PATCH 1/3] feat: add middleware --- src/context/auto-executor.ts | 61 +++++- src/context/context.ts | 5 +- src/index.ts | 1 + src/middleware/index.ts | 2 + src/middleware/logging.ts | 22 ++ src/middleware/middleware.test.ts | 338 ++++++++++++++++++++++++++++++ src/middleware/middleware.ts | 43 ++++ src/qstash/submit-steps.ts | 8 + src/serve/index.ts | 16 +- src/serve/options.ts | 1 + src/types.ts | 20 ++ 11 files changed, 511 insertions(+), 6 deletions(-) create mode 100644 src/middleware/index.ts create mode 100644 src/middleware/logging.ts create mode 100644 src/middleware/middleware.test.ts create mode 100644 src/middleware/middleware.ts diff --git a/src/context/auto-executor.ts b/src/context/auto-executor.ts index ae334ea..9bb5ace 100644 --- a/src/context/auto-executor.ts +++ b/src/context/auto-executor.ts @@ -5,6 +5,8 @@ import { type BaseLazyStep } from "./steps"; import type { WorkflowLogger } from "../logger"; import { QstashError } from "@upstash/qstash"; import { submitParallelSteps, submitSingleStep } from "../qstash/submit-steps"; +import { WorkflowMiddleware } from "../middleware"; +import { runMiddlewares } from "../middleware/middleware"; export class AutoExecutor { private context: WorkflowContext; @@ -15,8 +17,9 @@ export class AutoExecutor { private readonly nonPlanStepCount: number; private readonly steps: Step[]; private indexInCurrentList = 0; - private invokeCount: number; - private telemetry?: Telemetry; + private readonly invokeCount: number; + private readonly telemetry?: Telemetry; + private readonly middlewares?: WorkflowMiddleware[]; public stepCount = 0; public planStepCount = 0; @@ -28,13 +31,15 @@ export class AutoExecutor { steps: Step[], telemetry?: Telemetry, invokeCount?: number, - debug?: WorkflowLogger + debug?: WorkflowLogger, + middlewares?: WorkflowMiddleware[] ) { this.context = context; this.steps = steps; this.telemetry = telemetry; this.invokeCount = invokeCount ?? 0; this.debug = debug; + this.middlewares = middlewares; this.nonPlanStepCount = this.steps.filter((step) => !step.targetStep).length; } @@ -133,7 +138,20 @@ export class AutoExecutor { step, stepCount: this.stepCount, }); - return lazyStep.parseOut(step.out); + const parsedOut = lazyStep.parseOut(step.out); + + const isLastMemoized = + this.stepCount + 1 === this.nonPlanStepCount && this.steps.at(-1)!.stepId !== 0; + + if (isLastMemoized) { + runMiddlewares(this.middlewares, "afterExecution", { + workflowRunId: this.context.workflowRunId, + stepName: lazyStep.stepName, + result: parsedOut, + }); + } + + return parsedOut; } const resultStep = await submitSingleStep({ @@ -144,6 +162,7 @@ export class AutoExecutor { concurrency: 1, telemetry: this.telemetry, debug: this.debug, + middlewares: this.middlewares, }); throw new WorkflowAbort(lazyStep.stepName, resultStep); } @@ -232,6 +251,7 @@ export class AutoExecutor { concurrency: parallelSteps.length, telemetry: this.telemetry, debug: this.debug, + middlewares: this.middlewares, }); throw new WorkflowAbort(parallelStep.stepName, resultStep); } catch (error) { @@ -256,6 +276,22 @@ export class AutoExecutor { * This call to the API should be discarded: no operations are to be made. Parallel steps which are still * running will finish and call QStash eventually. */ + + if (this.middlewares) { + const resultStep = this.steps.at(-1)!; + const lazyStep = parallelSteps.find( + (planStep, index) => resultStep.stepId - index === initialStepCount + ); + + if (lazyStep) { + runMiddlewares(this.middlewares, "afterExecution", { + workflowRunId: this.context.workflowRunId, + stepName: lazyStep.stepName, + result: lazyStep.parseOut(resultStep.out), + }); + } + } + throw new WorkflowAbort("discarded parallel"); } case "last": { @@ -271,6 +307,23 @@ export class AutoExecutor { validateParallelSteps(parallelSteps, parallelResultSteps); + if (this.middlewares) { + const isLastMemoized = + this.stepCount + 1 === this.nonPlanStepCount && this.steps.at(-1)!.stepId !== 0; + if (isLastMemoized) { + const resultStep = this.steps.at(-1)!; + const lazyStep = parallelSteps.find( + (planStep, index) => resultStep.stepId - index === initialStepCount + )!; + + runMiddlewares(this.middlewares, "afterExecution", { + workflowRunId: this.context.workflowRunId, + stepName: lazyStep.stepName, + result: lazyStep.parseOut(resultStep.out), + }); + } + } + return parallelResultSteps.map((step, index) => parallelSteps[index].parseOut(step.out) ) as TResults; diff --git a/src/context/context.ts b/src/context/context.ts index a2f8976..91ad9f6 100644 --- a/src/context/context.ts +++ b/src/context/context.ts @@ -28,6 +28,7 @@ import { WorkflowApi } from "./api"; import { WorkflowAgents } from "../agents"; import { FlowControl } from "@upstash/qstash"; import { getNewUrlFromWorkflowId } from "../serve/serve-many"; +import { WorkflowMiddleware } from "../middleware"; /** * Upstash Workflow context @@ -176,6 +177,7 @@ export class WorkflowContext { telemetry, invokeCount, flowControl, + middlewares, }: { qstashClient: WorkflowClient; workflowRunId: string; @@ -190,6 +192,7 @@ export class WorkflowContext { telemetry?: Telemetry; invokeCount?: number; flowControl?: FlowControl; + middlewares?: WorkflowMiddleware[]; }) { this.qstashClient = qstashClient; this.workflowRunId = workflowRunId; @@ -202,7 +205,7 @@ export class WorkflowContext { this.retries = retries ?? DEFAULT_RETRIES; this.flowControl = flowControl; - this.executor = new AutoExecutor(this, this.steps, telemetry, invokeCount, debug); + this.executor = new AutoExecutor(this, this.steps, telemetry, invokeCount, debug, middlewares); } /** diff --git a/src/index.ts b/src/index.ts index ad18d62..4fbd119 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,3 +6,4 @@ export * from "./logger"; export * from "./client"; export { WorkflowError, WorkflowAbort } from "./error"; export { WorkflowTool } from "./agents/adapters"; +export { WorkflowMiddleware, loggingMiddleware } from "./middleware"; diff --git a/src/middleware/index.ts b/src/middleware/index.ts new file mode 100644 index 0000000..413364a --- /dev/null +++ b/src/middleware/index.ts @@ -0,0 +1,2 @@ +export { WorkflowMiddleware } from "./middleware"; +export { loggingMiddleware } from "./logging"; diff --git a/src/middleware/logging.ts b/src/middleware/logging.ts new file mode 100644 index 0000000..fde27c4 --- /dev/null +++ b/src/middleware/logging.ts @@ -0,0 +1,22 @@ +import { WorkflowMiddleware } from "./middleware"; + +export const loggingMiddleware = new WorkflowMiddleware({ + init: () => { + console.log("Logging middleware initialized"); + + return { + afterExecution(params) { + console.log("Step executed:", params); + }, + beforeExecution(params) { + console.log("Step execution started:", params); + }, + runStarted(params) { + console.log("Workflow run started:", params); + }, + runCompleted(params) { + console.log("Workflow run completed:", params); + }, + }; + }, +}); diff --git a/src/middleware/middleware.test.ts b/src/middleware/middleware.test.ts new file mode 100644 index 0000000..e20cbd1 --- /dev/null +++ b/src/middleware/middleware.test.ts @@ -0,0 +1,338 @@ +import { describe, test, expect, jest } from "bun:test"; +import { WorkflowMiddleware } from "./middleware"; +import { Client } from "@upstash/qstash"; +import { getRequest } from "../test-utils"; +import { nanoid } from "../utils"; +import { serve } from "../../platforms/nextjs"; +import { RouteFunction, Step } from "../types"; + +const createLoggingMiddleware = () => { + const accumulator: [string, unknown?][] = []; + const middleware = new WorkflowMiddleware({ + init: () => { + accumulator.push(["init"]); + + return { + afterExecution(params) { + accumulator.push(["afterExecution", params]); + }, + beforeExecution(params) { + accumulator.push(["beforeExecution", params]); + }, + runStarted(params) { + accumulator.push(["runStarted", params]); + }, + runCompleted(params) { + accumulator.push(["runCompleted", params]); + }, + }; + }, + }); + + return { middleware, accumulator }; +}; + +describe("middleware", () => { + test("should not call init in constructor", () => { + const init = jest.fn(); + new WorkflowMiddleware({ init }); + expect(init).not.toHaveBeenCalled(); + }); + + describe("runCallback method", () => { + test("should call init and callbacks", async () => { + const { middleware, accumulator } = createLoggingMiddleware(); + const stepName = `step-${nanoid()}`; + + await middleware.runCallback("runStarted", { workflowRunId: "wfr-id" }); + expect(accumulator).toEqual([["init"], ["runStarted", { workflowRunId: "wfr-id" }]]); + + await middleware.runCallback("beforeExecution", { + workflowRunId: "wfr-id", + stepName: stepName, + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ]); + + await middleware.runCallback("beforeExecution", { + workflowRunId: "wfr-id", + stepName: stepName, + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ]); + + const result = "some-result"; + await middleware.runCallback("afterExecution", { + workflowRunId: "wfr-id", + stepName: stepName, + result, + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["afterExecution", { workflowRunId: "wfr-id", stepName, result }], + ]); + + const finishResult = "finished-result"; + await middleware.runCallback("runCompleted", { + workflowRunId: "wfr-id", + result: finishResult, + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["afterExecution", { workflowRunId: "wfr-id", stepName, result }], + ["runCompleted", { workflowRunId: "wfr-id", result: finishResult }], + ]); + }); + + describe("with context", () => { + const stepOneName = `step-one-${nanoid()}`; + const stepTwoName = `step-two-${nanoid()}`; + const stepThreeName = `step-three-${nanoid()}`; + const parallelRunOne = `parallel-sleep-One-${nanoid()}`; + const parallelRunTwo = `parallel-sleep-Two-${nanoid()}`; + const stepResult = `step-result-${nanoid()}`; + const stepResultOne = `step-result-one-${nanoid()}`; + const stepResultTwo = `step-result-two-${nanoid()}`; + + const incrementalTestSteps: { + step?: Step; + middlewareAccumaltor: ReturnType["accumulator"]; + }[] = [ + { + middlewareAccumaltor: [ + ["init"], + [ + "runStarted", + { + workflowRunId: "wfr-id", + }, + ], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: stepOneName, + }, + ], + ], + }, + { + step: { + stepId: 1, + stepName: stepOneName, + stepType: "SleepFor", + sleepFor: 1, + concurrent: 1, + }, + middlewareAccumaltor: [ + ["init"], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: stepOneName, + result: undefined, + }, + ], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: stepTwoName, + }, + ], + ], + }, + { + step: { + stepId: 2, + stepName: stepTwoName, + stepType: "Run", + out: JSON.stringify(stepResult), + concurrent: 1, + }, + middlewareAccumaltor: [ + ["init"], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: stepTwoName, + result: stepResult, + }, + ], + ], + }, + { + step: { + stepId: 0, + stepName: parallelRunOne, + stepType: "Run", + concurrent: 2, + targetStep: 3, + }, + middlewareAccumaltor: [ + ["init"], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunOne, + }, + ], + ], + }, + { + step: { + stepId: 0, + stepName: parallelRunTwo, + stepType: "Run", + concurrent: 2, + targetStep: 4, + }, + middlewareAccumaltor: [ + ["init"], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunTwo, + }, + ], + ], + }, + { + step: { + stepId: 4, + stepName: parallelRunTwo, + stepType: "Run", + out: JSON.stringify(stepResultTwo), + concurrent: 2, + }, + middlewareAccumaltor: [ + ["init"], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunTwo, + result: stepResultTwo, + }, + ], + ], + }, + { + step: { + stepId: 3, + stepName: parallelRunOne, + stepType: "Run", + out: JSON.stringify(stepResultOne), + concurrent: 2, + }, + middlewareAccumaltor: [ + ["init"], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + result: stepResultOne, + stepName: parallelRunOne, + }, + ], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: stepThreeName, + }, + ], + ], + }, + { + step: { + stepId: 5, + stepName: stepThreeName, + stepType: "SleepFor", + sleepFor: 10, + concurrent: 1, + }, + middlewareAccumaltor: [ + ["init"], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + result: undefined, + stepName: stepThreeName, + }, + ], + [ + "runCompleted", + { + workflowRunId: "wfr-id", + result: undefined, + }, + ], + ], + }, + ]; + + const routeFunction: RouteFunction = async (context) => { + await context.sleep(stepOneName, 1); + await context.run(stepTwoName, () => stepResult); + await Promise.all([ + context.run(parallelRunOne, () => stepResultOne), + context.run(parallelRunTwo, () => stepResultTwo), + ]); + await context.sleep(stepThreeName, 10); + }; + + const qstashClient = new Client({ baseUrl: "https://requestcatcher.com", token: "token" }); + qstashClient.http.request = jest.fn(); + + const runMiddlewareTest = async ( + steps: Step[], + expectedAccumulator: ReturnType["accumulator"] + ) => { + const { middleware, accumulator } = createLoggingMiddleware(); + + const request = getRequest("https://requestcatcher.com", "wfr-id", undefined, steps); + + const { POST: handler } = serve(routeFunction, { + middlewares: [middleware], + url: "https://requestcatcher.com", + receiver: undefined, + qstashClient, + }); + + const response = await handler(request); + expect(response.status).toBe(200); + + expect(accumulator).toEqual(expectedAccumulator); + }; + + incrementalTestSteps.forEach(({ middlewareAccumaltor }, index) => { + const testSteps = incrementalTestSteps + .slice(0, index + 1) + .map(({ step }) => step) + .filter(Boolean) as Step[]; + test(`should call middleware in order case #${index + 1}`, async () => { + await runMiddlewareTest(testSteps, middlewareAccumaltor); + }); + }); + }); + }); +}); diff --git a/src/middleware/middleware.ts b/src/middleware/middleware.ts new file mode 100644 index 0000000..260d0f8 --- /dev/null +++ b/src/middleware/middleware.ts @@ -0,0 +1,43 @@ +import { MiddlewareCallbacks, MiddlewareParameters } from "../types"; + +export class WorkflowMiddleware { + private readonly init: MiddlewareParameters["init"]; + private middlewareCallbacks?: MiddlewareCallbacks; + + constructor(parameters: MiddlewareParameters) { + this.init = parameters.init; + this.middlewareCallbacks = undefined; + } + + async runCallback( + callback: K, + parameters: Parameters>[0] + ): Promise { + await this.ensureInit(); + const cb = this.middlewareCallbacks?.[callback]; + if (cb) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await cb(parameters as any); + } + } + + private async ensureInit() { + if (!this.middlewareCallbacks) { + this.middlewareCallbacks = await this.init(); + } + } +} + +export const runMiddlewares = async ( + middlewares: WorkflowMiddleware[] | undefined, + callback: K, + parameters: Parameters>[0] +) => { + if (!middlewares) { + return; + } + + middlewares.forEach(async (m) => { + await m.runCallback(callback, parameters); + }); +}; diff --git a/src/qstash/submit-steps.ts b/src/qstash/submit-steps.ts index b7c76e2..4c87d57 100644 --- a/src/qstash/submit-steps.ts +++ b/src/qstash/submit-steps.ts @@ -5,6 +5,8 @@ import { Telemetry } from "../types"; import { WorkflowContext } from "../context"; import { BaseLazyStep } from "../context/steps"; import { getHeaders } from "./headers"; +import { WorkflowMiddleware } from "../middleware"; +import { runMiddlewares } from "../middleware/middleware"; export const submitParallelSteps = async ({ context, @@ -76,6 +78,7 @@ export const submitSingleStep = async ({ concurrency, telemetry, debug, + middlewares, }: { context: WorkflowContext; lazyStep: BaseLazyStep; @@ -84,8 +87,13 @@ export const submitSingleStep = async ({ concurrency: number; telemetry?: Telemetry; debug?: WorkflowLogger; + middlewares?: WorkflowMiddleware[]; }) => { const resultStep = await lazyStep.getResultStep(concurrency, stepId); + await runMiddlewares(middlewares, "beforeExecution", { + workflowRunId: context.workflowRunId, + stepName: resultStep.stepName, + }); await debug?.log("INFO", "RUN_SINGLE", { fromRequest: false, step: resultStep, diff --git a/src/serve/index.ts b/src/serve/index.ts index 54d041d..df12058 100644 --- a/src/serve/index.ts +++ b/src/serve/index.ts @@ -3,6 +3,7 @@ import { SDK_TELEMETRY, WORKFLOW_INVOKE_COUNT_HEADER } from "../constants"; import { WorkflowContext } from "../context"; import { formatWorkflowError } from "../error"; import { WorkflowLogger } from "../logger"; +import { runMiddlewares } from "../middleware/middleware"; import { ExclusiveValidationOptions, RouteFunction, @@ -63,6 +64,7 @@ export const serveBase = < disableTelemetry, flowControl, onError, + middlewares, } = processOptions(options); telemetry = disableTelemetry ? undefined : telemetry; const debug = WorkflowLogger.getLogger(verbose); @@ -154,6 +156,7 @@ export const serveBase = < telemetry, invokeCount, flowControl, + middlewares, }); // attempt running routeFunction until the first step @@ -204,8 +207,19 @@ export const serveBase = < invokeCount, }) : await triggerRouteFunction({ - onStep: async () => routeFunction(workflowContext), + onStep: async () => { + if (steps.length === 1) { + await runMiddlewares(middlewares, "runStarted", { + workflowRunId: workflowContext.workflowRunId, + }); + } + await routeFunction(workflowContext); + }, onCleanup: async (result) => { + await runMiddlewares(middlewares, "runCompleted", { + workflowRunId: workflowContext.workflowRunId, + result, + }); await triggerWorkflowDelete(workflowContext, result, debug); }, onCancel: async () => { diff --git a/src/serve/options.ts b/src/serve/options.ts index b016ff0..d5d27bc 100644 --- a/src/serve/options.ts +++ b/src/serve/options.ts @@ -30,6 +30,7 @@ export const processOptions = => { const environment = options?.env ?? (typeof process === "undefined" ? ({} as Record) : process.env); diff --git a/src/types.ts b/src/types.ts index 38e7c65..307c1af 100644 --- a/src/types.ts +++ b/src/types.ts @@ -4,6 +4,7 @@ import type { HTTPMethods } from "@upstash/qstash"; import type { WorkflowContext } from "./context"; import type { WorkflowLogger } from "./logger"; import { z } from "zod"; +import { WorkflowMiddleware } from "./middleware"; /** * Interface for Client with required methods @@ -256,6 +257,10 @@ export type WorkflowServeOptions< * and number of requests per second with the same key. */ flowControl?: FlowControl; + /** + * List of workflow middlewares to use + */ + middlewares?: WorkflowMiddleware[]; } & ValidationOptions; type ValidationOptions = { @@ -519,3 +524,18 @@ export type InvokableWorkflow = { options: WorkflowServeOptions; workflowId?: string; }; + +export type MiddlewareCallbacks = { + beforeExecution?: (params: { workflowRunId: string; stepName: string }) => Promise | void; + afterExecution?: (params: { + workflowRunId: string; + stepName: string; + result: unknown; + }) => Promise | void; + runStarted?: (params: { workflowRunId: string }) => Promise | void; + runCompleted?: (params: { workflowRunId: string; result: unknown }) => Promise | void; +}; + +export type MiddlewareParameters = { + init: () => Promise | MiddlewareCallbacks; +}; From 538c10b2a5655a5077beb4ba7e50428b1410764d Mon Sep 17 00:00:00 2001 From: CahidArda Date: Mon, 5 May 2025 15:11:12 +0300 Subject: [PATCH 2/3] feat: add name and workflowRunId params --- src/middleware/logging.ts | 1 + src/middleware/middleware.test.ts | 3 ++- src/middleware/middleware.ts | 6 +++--- src/types.ts | 3 ++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/middleware/logging.ts b/src/middleware/logging.ts index fde27c4..dbe9669 100644 --- a/src/middleware/logging.ts +++ b/src/middleware/logging.ts @@ -1,6 +1,7 @@ import { WorkflowMiddleware } from "./middleware"; export const loggingMiddleware = new WorkflowMiddleware({ + name: "logging", init: () => { console.log("Logging middleware initialized"); diff --git a/src/middleware/middleware.test.ts b/src/middleware/middleware.test.ts index e20cbd1..5ec1203 100644 --- a/src/middleware/middleware.test.ts +++ b/src/middleware/middleware.test.ts @@ -9,6 +9,7 @@ import { RouteFunction, Step } from "../types"; const createLoggingMiddleware = () => { const accumulator: [string, unknown?][] = []; const middleware = new WorkflowMiddleware({ + name: "test", init: () => { accumulator.push(["init"]); @@ -35,7 +36,7 @@ const createLoggingMiddleware = () => { describe("middleware", () => { test("should not call init in constructor", () => { const init = jest.fn(); - new WorkflowMiddleware({ init }); + new WorkflowMiddleware({ name: "test", init }); expect(init).not.toHaveBeenCalled(); }); diff --git a/src/middleware/middleware.ts b/src/middleware/middleware.ts index 260d0f8..30ae495 100644 --- a/src/middleware/middleware.ts +++ b/src/middleware/middleware.ts @@ -13,7 +13,7 @@ export class WorkflowMiddleware { callback: K, parameters: Parameters>[0] ): Promise { - await this.ensureInit(); + await this.ensureInit(parameters.workflowRunId); const cb = this.middlewareCallbacks?.[callback]; if (cb) { // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -21,9 +21,9 @@ export class WorkflowMiddleware { } } - private async ensureInit() { + private async ensureInit(workflowRunId: string) { if (!this.middlewareCallbacks) { - this.middlewareCallbacks = await this.init(); + this.middlewareCallbacks = await this.init({ workflowRunId }); } } } diff --git a/src/types.ts b/src/types.ts index 307c1af..da3b81d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -537,5 +537,6 @@ export type MiddlewareCallbacks = { }; export type MiddlewareParameters = { - init: () => Promise | MiddlewareCallbacks; + name: string; + init: (params: { workflowRunId: string }) => Promise | MiddlewareCallbacks; }; From ae40015c1c9c281fd2ead3e21126a786a939adc8 Mon Sep 17 00:00:00 2001 From: CahidArda Date: Mon, 5 May 2025 16:29:14 +0300 Subject: [PATCH 3/3] fix: run afterExecution in the same request and add onError --- src/context/auto-executor.ts | 45 ------------ src/middleware/middleware.test.ts | 111 +++++++++++++++++------------- src/qstash/submit-steps.ts | 5 ++ src/serve/index.ts | 21 +++++- src/types.ts | 9 +-- 5 files changed, 90 insertions(+), 101 deletions(-) diff --git a/src/context/auto-executor.ts b/src/context/auto-executor.ts index 9bb5ace..cbea467 100644 --- a/src/context/auto-executor.ts +++ b/src/context/auto-executor.ts @@ -6,7 +6,6 @@ import type { WorkflowLogger } from "../logger"; import { QstashError } from "@upstash/qstash"; import { submitParallelSteps, submitSingleStep } from "../qstash/submit-steps"; import { WorkflowMiddleware } from "../middleware"; -import { runMiddlewares } from "../middleware/middleware"; export class AutoExecutor { private context: WorkflowContext; @@ -140,17 +139,6 @@ export class AutoExecutor { }); const parsedOut = lazyStep.parseOut(step.out); - const isLastMemoized = - this.stepCount + 1 === this.nonPlanStepCount && this.steps.at(-1)!.stepId !== 0; - - if (isLastMemoized) { - runMiddlewares(this.middlewares, "afterExecution", { - workflowRunId: this.context.workflowRunId, - stepName: lazyStep.stepName, - result: parsedOut, - }); - } - return parsedOut; } @@ -276,22 +264,6 @@ export class AutoExecutor { * This call to the API should be discarded: no operations are to be made. Parallel steps which are still * running will finish and call QStash eventually. */ - - if (this.middlewares) { - const resultStep = this.steps.at(-1)!; - const lazyStep = parallelSteps.find( - (planStep, index) => resultStep.stepId - index === initialStepCount - ); - - if (lazyStep) { - runMiddlewares(this.middlewares, "afterExecution", { - workflowRunId: this.context.workflowRunId, - stepName: lazyStep.stepName, - result: lazyStep.parseOut(resultStep.out), - }); - } - } - throw new WorkflowAbort("discarded parallel"); } case "last": { @@ -307,23 +279,6 @@ export class AutoExecutor { validateParallelSteps(parallelSteps, parallelResultSteps); - if (this.middlewares) { - const isLastMemoized = - this.stepCount + 1 === this.nonPlanStepCount && this.steps.at(-1)!.stepId !== 0; - if (isLastMemoized) { - const resultStep = this.steps.at(-1)!; - const lazyStep = parallelSteps.find( - (planStep, index) => resultStep.stepId - index === initialStepCount - )!; - - runMiddlewares(this.middlewares, "afterExecution", { - workflowRunId: this.context.workflowRunId, - stepName: lazyStep.stepName, - result: lazyStep.parseOut(resultStep.out), - }); - } - } - return parallelResultSteps.map((step, index) => parallelSteps[index].parseOut(step.out) ) as TResults; diff --git a/src/middleware/middleware.test.ts b/src/middleware/middleware.test.ts index 5ec1203..f502380 100644 --- a/src/middleware/middleware.test.ts +++ b/src/middleware/middleware.test.ts @@ -26,6 +26,9 @@ const createLoggingMiddleware = () => { runCompleted(params) { accumulator.push(["runCompleted", params]); }, + onError(params) { + accumulator.push(["onError", params]); + }, }; }, }); @@ -69,32 +72,28 @@ describe("middleware", () => { ["beforeExecution", { workflowRunId: "wfr-id", stepName }], ]); - const result = "some-result"; await middleware.runCallback("afterExecution", { workflowRunId: "wfr-id", stepName: stepName, - result, }); expect(accumulator).toEqual([ ["init"], ["runStarted", { workflowRunId: "wfr-id" }], ["beforeExecution", { workflowRunId: "wfr-id", stepName }], ["beforeExecution", { workflowRunId: "wfr-id", stepName }], - ["afterExecution", { workflowRunId: "wfr-id", stepName, result }], + ["afterExecution", { workflowRunId: "wfr-id", stepName }], ]); - const finishResult = "finished-result"; await middleware.runCallback("runCompleted", { workflowRunId: "wfr-id", - result: finishResult, }); expect(accumulator).toEqual([ ["init"], ["runStarted", { workflowRunId: "wfr-id" }], ["beforeExecution", { workflowRunId: "wfr-id", stepName }], ["beforeExecution", { workflowRunId: "wfr-id", stepName }], - ["afterExecution", { workflowRunId: "wfr-id", stepName, result }], - ["runCompleted", { workflowRunId: "wfr-id", result: finishResult }], + ["afterExecution", { workflowRunId: "wfr-id", stepName }], + ["runCompleted", { workflowRunId: "wfr-id" }], ]); }); @@ -128,6 +127,13 @@ describe("middleware", () => { stepName: stepOneName, }, ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: stepOneName, + }, + ], ], }, { @@ -141,15 +147,14 @@ describe("middleware", () => { middlewareAccumaltor: [ ["init"], [ - "afterExecution", + "beforeExecution", { workflowRunId: "wfr-id", - stepName: stepOneName, - result: undefined, + stepName: stepTwoName, }, ], [ - "beforeExecution", + "afterExecution", { workflowRunId: "wfr-id", stepName: stepTwoName, @@ -165,17 +170,7 @@ describe("middleware", () => { out: JSON.stringify(stepResult), concurrent: 1, }, - middlewareAccumaltor: [ - ["init"], - [ - "afterExecution", - { - workflowRunId: "wfr-id", - stepName: stepTwoName, - result: stepResult, - }, - ], - ], + middlewareAccumaltor: [], }, { step: { @@ -194,6 +189,13 @@ describe("middleware", () => { stepName: parallelRunOne, }, ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunOne, + }, + ], ], }, { @@ -213,6 +215,13 @@ describe("middleware", () => { stepName: parallelRunTwo, }, ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunTwo, + }, + ], ], }, { @@ -223,17 +232,7 @@ describe("middleware", () => { out: JSON.stringify(stepResultTwo), concurrent: 2, }, - middlewareAccumaltor: [ - ["init"], - [ - "afterExecution", - { - workflowRunId: "wfr-id", - stepName: parallelRunTwo, - result: stepResultTwo, - }, - ], - ], + middlewareAccumaltor: [], }, { step: { @@ -246,15 +245,14 @@ describe("middleware", () => { middlewareAccumaltor: [ ["init"], [ - "afterExecution", + "beforeExecution", { workflowRunId: "wfr-id", - result: stepResultOne, - stepName: parallelRunOne, + stepName: stepThreeName, }, ], [ - "beforeExecution", + "afterExecution", { workflowRunId: "wfr-id", stepName: stepThreeName, @@ -272,19 +270,10 @@ describe("middleware", () => { }, middlewareAccumaltor: [ ["init"], - [ - "afterExecution", - { - workflowRunId: "wfr-id", - result: undefined, - stepName: stepThreeName, - }, - ], [ "runCompleted", { workflowRunId: "wfr-id", - result: undefined, }, ], ], @@ -306,7 +295,8 @@ describe("middleware", () => { const runMiddlewareTest = async ( steps: Step[], - expectedAccumulator: ReturnType["accumulator"] + expectedAccumulator: ReturnType["accumulator"], + status: number = 200 ) => { const { middleware, accumulator } = createLoggingMiddleware(); @@ -320,7 +310,7 @@ describe("middleware", () => { }); const response = await handler(request); - expect(response.status).toBe(200); + expect(response.status).toBe(status); expect(accumulator).toEqual(expectedAccumulator); }; @@ -334,6 +324,31 @@ describe("middleware", () => { await runMiddlewareTest(testSteps, middlewareAccumaltor); }); }); + + test("with error", async () => { + await runMiddlewareTest( + [ + { + stepId: 1, + stepName: stepOneName + "-error", + stepType: "SleepFor", + sleepFor: 1, + concurrent: 1, + }, + ], + [ + ["init"], + [ + "onError", + { + workflowRunId: "wfr-id", + error: expect.any(Error), + }, + ], + ], + 500 + ); + }); }); }); }); diff --git a/src/qstash/submit-steps.ts b/src/qstash/submit-steps.ts index 4c87d57..bf45106 100644 --- a/src/qstash/submit-steps.ts +++ b/src/qstash/submit-steps.ts @@ -119,6 +119,11 @@ export const submitSingleStep = async ({ steps: [resultStep], }); + await runMiddlewares(middlewares, "afterExecution", { + workflowRunId: context.workflowRunId, + stepName: lazyStep.stepName, + }); + const submitResult = await lazyStep.submitStep({ context, body, diff --git a/src/serve/index.ts b/src/serve/index.ts index df12058..c99e140 100644 --- a/src/serve/index.ts +++ b/src/serve/index.ts @@ -1,5 +1,5 @@ import { makeCancelRequest } from "../client/utils"; -import { SDK_TELEMETRY, WORKFLOW_INVOKE_COUNT_HEADER } from "../constants"; +import { SDK_TELEMETRY, WORKFLOW_ID_HEADER, WORKFLOW_INVOKE_COUNT_HEADER } from "../constants"; import { WorkflowContext } from "../context"; import { formatWorkflowError } from "../error"; import { WorkflowLogger } from "../logger"; @@ -218,7 +218,6 @@ export const serveBase = < onCleanup: async (result) => { await runMiddlewares(middlewares, "runCompleted", { workflowRunId: workflowContext.workflowRunId, - result, }); await triggerWorkflowDelete(workflowContext, result, debug); }, @@ -263,6 +262,24 @@ export const serveBase = < status: 500, }) as TResponse; } + + try { + runMiddlewares(middlewares, "onError", { + workflowRunId: request.headers.get(WORKFLOW_ID_HEADER) ?? "no-workflow-id", + error: error as Error, + }); + } catch (middlewareError) { + const formattedMiddlewareError = formatWorkflowError(middlewareError); + const errorMessage = + `Error while running middleware onError callback: '${formattedMiddlewareError.message}'.` + + `\nOriginal error: '${formattedError.message}'`; + + console.error(errorMessage); + return new Response(errorMessage, { + status: 500, + }) as TResponse; + } + return new Response(JSON.stringify(formattedError), { status: 500, }) as TResponse; diff --git a/src/types.ts b/src/types.ts index da3b81d..65892e5 100644 --- a/src/types.ts +++ b/src/types.ts @@ -527,13 +527,10 @@ export type InvokableWorkflow = { export type MiddlewareCallbacks = { beforeExecution?: (params: { workflowRunId: string; stepName: string }) => Promise | void; - afterExecution?: (params: { - workflowRunId: string; - stepName: string; - result: unknown; - }) => Promise | void; + afterExecution?: (params: { workflowRunId: string; stepName: string }) => Promise | void; runStarted?: (params: { workflowRunId: string }) => Promise | void; - runCompleted?: (params: { workflowRunId: string; result: unknown }) => Promise | void; + runCompleted?: (params: { workflowRunId: string }) => Promise | void; + onError?: (params: { workflowRunId: string; error: Error }) => Promise | void; }; export type MiddlewareParameters = {