diff --git a/src/context/auto-executor.ts b/src/context/auto-executor.ts index ae334ea..cbea467 100644 --- a/src/context/auto-executor.ts +++ b/src/context/auto-executor.ts @@ -5,6 +5,7 @@ import { type BaseLazyStep } from "./steps"; import type { WorkflowLogger } from "../logger"; import { QstashError } from "@upstash/qstash"; import { submitParallelSteps, submitSingleStep } from "../qstash/submit-steps"; +import { WorkflowMiddleware } from "../middleware"; export class AutoExecutor { private context: WorkflowContext; @@ -15,8 +16,9 @@ export class AutoExecutor { private readonly nonPlanStepCount: number; private readonly steps: Step[]; private indexInCurrentList = 0; - private invokeCount: number; - private telemetry?: Telemetry; + private readonly invokeCount: number; + private readonly telemetry?: Telemetry; + private readonly middlewares?: WorkflowMiddleware[]; public stepCount = 0; public planStepCount = 0; @@ -28,13 +30,15 @@ export class AutoExecutor { steps: Step[], telemetry?: Telemetry, invokeCount?: number, - debug?: WorkflowLogger + debug?: WorkflowLogger, + middlewares?: WorkflowMiddleware[] ) { this.context = context; this.steps = steps; this.telemetry = telemetry; this.invokeCount = invokeCount ?? 0; this.debug = debug; + this.middlewares = middlewares; this.nonPlanStepCount = this.steps.filter((step) => !step.targetStep).length; } @@ -133,7 +137,9 @@ export class AutoExecutor { step, stepCount: this.stepCount, }); - return lazyStep.parseOut(step.out); + const parsedOut = lazyStep.parseOut(step.out); + + return parsedOut; } const resultStep = await submitSingleStep({ @@ -144,6 +150,7 @@ export class AutoExecutor { concurrency: 1, telemetry: this.telemetry, debug: this.debug, + middlewares: this.middlewares, }); throw new WorkflowAbort(lazyStep.stepName, resultStep); } @@ -232,6 +239,7 @@ export class AutoExecutor { concurrency: parallelSteps.length, telemetry: this.telemetry, debug: this.debug, + middlewares: this.middlewares, }); throw new WorkflowAbort(parallelStep.stepName, resultStep); } catch (error) { diff --git a/src/context/context.ts b/src/context/context.ts index a2f8976..91ad9f6 100644 --- a/src/context/context.ts +++ b/src/context/context.ts @@ -28,6 +28,7 @@ import { WorkflowApi } from "./api"; import { WorkflowAgents } from "../agents"; import { FlowControl } from "@upstash/qstash"; import { getNewUrlFromWorkflowId } from "../serve/serve-many"; +import { WorkflowMiddleware } from "../middleware"; /** * Upstash Workflow context @@ -176,6 +177,7 @@ export class WorkflowContext { telemetry, invokeCount, flowControl, + middlewares, }: { qstashClient: WorkflowClient; workflowRunId: string; @@ -190,6 +192,7 @@ export class WorkflowContext { telemetry?: Telemetry; invokeCount?: number; flowControl?: FlowControl; + middlewares?: WorkflowMiddleware[]; }) { this.qstashClient = qstashClient; this.workflowRunId = workflowRunId; @@ -202,7 +205,7 @@ export class WorkflowContext { this.retries = retries ?? DEFAULT_RETRIES; this.flowControl = flowControl; - this.executor = new AutoExecutor(this, this.steps, telemetry, invokeCount, debug); + this.executor = new AutoExecutor(this, this.steps, telemetry, invokeCount, debug, middlewares); } /** diff --git a/src/index.ts b/src/index.ts index ad18d62..4fbd119 100644 --- a/src/index.ts +++ b/src/index.ts @@ -6,3 +6,4 @@ export * from "./logger"; export * from "./client"; export { WorkflowError, WorkflowAbort } from "./error"; export { WorkflowTool } from "./agents/adapters"; +export { WorkflowMiddleware, loggingMiddleware } from "./middleware"; diff --git a/src/middleware/index.ts b/src/middleware/index.ts new file mode 100644 index 0000000..413364a --- /dev/null +++ b/src/middleware/index.ts @@ -0,0 +1,2 @@ +export { WorkflowMiddleware } from "./middleware"; +export { loggingMiddleware } from "./logging"; diff --git a/src/middleware/logging.ts b/src/middleware/logging.ts new file mode 100644 index 0000000..dbe9669 --- /dev/null +++ b/src/middleware/logging.ts @@ -0,0 +1,23 @@ +import { WorkflowMiddleware } from "./middleware"; + +export const loggingMiddleware = new WorkflowMiddleware({ + name: "logging", + init: () => { + console.log("Logging middleware initialized"); + + return { + afterExecution(params) { + console.log("Step executed:", params); + }, + beforeExecution(params) { + console.log("Step execution started:", params); + }, + runStarted(params) { + console.log("Workflow run started:", params); + }, + runCompleted(params) { + console.log("Workflow run completed:", params); + }, + }; + }, +}); diff --git a/src/middleware/middleware.test.ts b/src/middleware/middleware.test.ts new file mode 100644 index 0000000..f502380 --- /dev/null +++ b/src/middleware/middleware.test.ts @@ -0,0 +1,354 @@ +import { describe, test, expect, jest } from "bun:test"; +import { WorkflowMiddleware } from "./middleware"; +import { Client } from "@upstash/qstash"; +import { getRequest } from "../test-utils"; +import { nanoid } from "../utils"; +import { serve } from "../../platforms/nextjs"; +import { RouteFunction, Step } from "../types"; + +const createLoggingMiddleware = () => { + const accumulator: [string, unknown?][] = []; + const middleware = new WorkflowMiddleware({ + name: "test", + init: () => { + accumulator.push(["init"]); + + return { + afterExecution(params) { + accumulator.push(["afterExecution", params]); + }, + beforeExecution(params) { + accumulator.push(["beforeExecution", params]); + }, + runStarted(params) { + accumulator.push(["runStarted", params]); + }, + runCompleted(params) { + accumulator.push(["runCompleted", params]); + }, + onError(params) { + accumulator.push(["onError", params]); + }, + }; + }, + }); + + return { middleware, accumulator }; +}; + +describe("middleware", () => { + test("should not call init in constructor", () => { + const init = jest.fn(); + new WorkflowMiddleware({ name: "test", init }); + expect(init).not.toHaveBeenCalled(); + }); + + describe("runCallback method", () => { + test("should call init and callbacks", async () => { + const { middleware, accumulator } = createLoggingMiddleware(); + const stepName = `step-${nanoid()}`; + + await middleware.runCallback("runStarted", { workflowRunId: "wfr-id" }); + expect(accumulator).toEqual([["init"], ["runStarted", { workflowRunId: "wfr-id" }]]); + + await middleware.runCallback("beforeExecution", { + workflowRunId: "wfr-id", + stepName: stepName, + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ]); + + await middleware.runCallback("beforeExecution", { + workflowRunId: "wfr-id", + stepName: stepName, + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ]); + + await middleware.runCallback("afterExecution", { + workflowRunId: "wfr-id", + stepName: stepName, + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["afterExecution", { workflowRunId: "wfr-id", stepName }], + ]); + + await middleware.runCallback("runCompleted", { + workflowRunId: "wfr-id", + }); + expect(accumulator).toEqual([ + ["init"], + ["runStarted", { workflowRunId: "wfr-id" }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["beforeExecution", { workflowRunId: "wfr-id", stepName }], + ["afterExecution", { workflowRunId: "wfr-id", stepName }], + ["runCompleted", { workflowRunId: "wfr-id" }], + ]); + }); + + describe("with context", () => { + const stepOneName = `step-one-${nanoid()}`; + const stepTwoName = `step-two-${nanoid()}`; + const stepThreeName = `step-three-${nanoid()}`; + const parallelRunOne = `parallel-sleep-One-${nanoid()}`; + const parallelRunTwo = `parallel-sleep-Two-${nanoid()}`; + const stepResult = `step-result-${nanoid()}`; + const stepResultOne = `step-result-one-${nanoid()}`; + const stepResultTwo = `step-result-two-${nanoid()}`; + + const incrementalTestSteps: { + step?: Step; + middlewareAccumaltor: ReturnType["accumulator"]; + }[] = [ + { + middlewareAccumaltor: [ + ["init"], + [ + "runStarted", + { + workflowRunId: "wfr-id", + }, + ], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: stepOneName, + }, + ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: stepOneName, + }, + ], + ], + }, + { + step: { + stepId: 1, + stepName: stepOneName, + stepType: "SleepFor", + sleepFor: 1, + concurrent: 1, + }, + middlewareAccumaltor: [ + ["init"], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: stepTwoName, + }, + ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: stepTwoName, + }, + ], + ], + }, + { + step: { + stepId: 2, + stepName: stepTwoName, + stepType: "Run", + out: JSON.stringify(stepResult), + concurrent: 1, + }, + middlewareAccumaltor: [], + }, + { + step: { + stepId: 0, + stepName: parallelRunOne, + stepType: "Run", + concurrent: 2, + targetStep: 3, + }, + middlewareAccumaltor: [ + ["init"], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunOne, + }, + ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunOne, + }, + ], + ], + }, + { + step: { + stepId: 0, + stepName: parallelRunTwo, + stepType: "Run", + concurrent: 2, + targetStep: 4, + }, + middlewareAccumaltor: [ + ["init"], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunTwo, + }, + ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: parallelRunTwo, + }, + ], + ], + }, + { + step: { + stepId: 4, + stepName: parallelRunTwo, + stepType: "Run", + out: JSON.stringify(stepResultTwo), + concurrent: 2, + }, + middlewareAccumaltor: [], + }, + { + step: { + stepId: 3, + stepName: parallelRunOne, + stepType: "Run", + out: JSON.stringify(stepResultOne), + concurrent: 2, + }, + middlewareAccumaltor: [ + ["init"], + [ + "beforeExecution", + { + workflowRunId: "wfr-id", + stepName: stepThreeName, + }, + ], + [ + "afterExecution", + { + workflowRunId: "wfr-id", + stepName: stepThreeName, + }, + ], + ], + }, + { + step: { + stepId: 5, + stepName: stepThreeName, + stepType: "SleepFor", + sleepFor: 10, + concurrent: 1, + }, + middlewareAccumaltor: [ + ["init"], + [ + "runCompleted", + { + workflowRunId: "wfr-id", + }, + ], + ], + }, + ]; + + const routeFunction: RouteFunction = async (context) => { + await context.sleep(stepOneName, 1); + await context.run(stepTwoName, () => stepResult); + await Promise.all([ + context.run(parallelRunOne, () => stepResultOne), + context.run(parallelRunTwo, () => stepResultTwo), + ]); + await context.sleep(stepThreeName, 10); + }; + + const qstashClient = new Client({ baseUrl: "https://requestcatcher.com", token: "token" }); + qstashClient.http.request = jest.fn(); + + const runMiddlewareTest = async ( + steps: Step[], + expectedAccumulator: ReturnType["accumulator"], + status: number = 200 + ) => { + const { middleware, accumulator } = createLoggingMiddleware(); + + const request = getRequest("https://requestcatcher.com", "wfr-id", undefined, steps); + + const { POST: handler } = serve(routeFunction, { + middlewares: [middleware], + url: "https://requestcatcher.com", + receiver: undefined, + qstashClient, + }); + + const response = await handler(request); + expect(response.status).toBe(status); + + expect(accumulator).toEqual(expectedAccumulator); + }; + + incrementalTestSteps.forEach(({ middlewareAccumaltor }, index) => { + const testSteps = incrementalTestSteps + .slice(0, index + 1) + .map(({ step }) => step) + .filter(Boolean) as Step[]; + test(`should call middleware in order case #${index + 1}`, async () => { + await runMiddlewareTest(testSteps, middlewareAccumaltor); + }); + }); + + test("with error", async () => { + await runMiddlewareTest( + [ + { + stepId: 1, + stepName: stepOneName + "-error", + stepType: "SleepFor", + sleepFor: 1, + concurrent: 1, + }, + ], + [ + ["init"], + [ + "onError", + { + workflowRunId: "wfr-id", + error: expect.any(Error), + }, + ], + ], + 500 + ); + }); + }); + }); +}); diff --git a/src/middleware/middleware.ts b/src/middleware/middleware.ts new file mode 100644 index 0000000..30ae495 --- /dev/null +++ b/src/middleware/middleware.ts @@ -0,0 +1,43 @@ +import { MiddlewareCallbacks, MiddlewareParameters } from "../types"; + +export class WorkflowMiddleware { + private readonly init: MiddlewareParameters["init"]; + private middlewareCallbacks?: MiddlewareCallbacks; + + constructor(parameters: MiddlewareParameters) { + this.init = parameters.init; + this.middlewareCallbacks = undefined; + } + + async runCallback( + callback: K, + parameters: Parameters>[0] + ): Promise { + await this.ensureInit(parameters.workflowRunId); + const cb = this.middlewareCallbacks?.[callback]; + if (cb) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + await cb(parameters as any); + } + } + + private async ensureInit(workflowRunId: string) { + if (!this.middlewareCallbacks) { + this.middlewareCallbacks = await this.init({ workflowRunId }); + } + } +} + +export const runMiddlewares = async ( + middlewares: WorkflowMiddleware[] | undefined, + callback: K, + parameters: Parameters>[0] +) => { + if (!middlewares) { + return; + } + + middlewares.forEach(async (m) => { + await m.runCallback(callback, parameters); + }); +}; diff --git a/src/qstash/submit-steps.ts b/src/qstash/submit-steps.ts index b7c76e2..bf45106 100644 --- a/src/qstash/submit-steps.ts +++ b/src/qstash/submit-steps.ts @@ -5,6 +5,8 @@ import { Telemetry } from "../types"; import { WorkflowContext } from "../context"; import { BaseLazyStep } from "../context/steps"; import { getHeaders } from "./headers"; +import { WorkflowMiddleware } from "../middleware"; +import { runMiddlewares } from "../middleware/middleware"; export const submitParallelSteps = async ({ context, @@ -76,6 +78,7 @@ export const submitSingleStep = async ({ concurrency, telemetry, debug, + middlewares, }: { context: WorkflowContext; lazyStep: BaseLazyStep; @@ -84,8 +87,13 @@ export const submitSingleStep = async ({ concurrency: number; telemetry?: Telemetry; debug?: WorkflowLogger; + middlewares?: WorkflowMiddleware[]; }) => { const resultStep = await lazyStep.getResultStep(concurrency, stepId); + await runMiddlewares(middlewares, "beforeExecution", { + workflowRunId: context.workflowRunId, + stepName: resultStep.stepName, + }); await debug?.log("INFO", "RUN_SINGLE", { fromRequest: false, step: resultStep, @@ -111,6 +119,11 @@ export const submitSingleStep = async ({ steps: [resultStep], }); + await runMiddlewares(middlewares, "afterExecution", { + workflowRunId: context.workflowRunId, + stepName: lazyStep.stepName, + }); + const submitResult = await lazyStep.submitStep({ context, body, diff --git a/src/serve/index.ts b/src/serve/index.ts index 54d041d..c99e140 100644 --- a/src/serve/index.ts +++ b/src/serve/index.ts @@ -1,8 +1,9 @@ import { makeCancelRequest } from "../client/utils"; -import { SDK_TELEMETRY, WORKFLOW_INVOKE_COUNT_HEADER } from "../constants"; +import { SDK_TELEMETRY, WORKFLOW_ID_HEADER, WORKFLOW_INVOKE_COUNT_HEADER } from "../constants"; import { WorkflowContext } from "../context"; import { formatWorkflowError } from "../error"; import { WorkflowLogger } from "../logger"; +import { runMiddlewares } from "../middleware/middleware"; import { ExclusiveValidationOptions, RouteFunction, @@ -63,6 +64,7 @@ export const serveBase = < disableTelemetry, flowControl, onError, + middlewares, } = processOptions(options); telemetry = disableTelemetry ? undefined : telemetry; const debug = WorkflowLogger.getLogger(verbose); @@ -154,6 +156,7 @@ export const serveBase = < telemetry, invokeCount, flowControl, + middlewares, }); // attempt running routeFunction until the first step @@ -204,8 +207,18 @@ export const serveBase = < invokeCount, }) : await triggerRouteFunction({ - onStep: async () => routeFunction(workflowContext), + onStep: async () => { + if (steps.length === 1) { + await runMiddlewares(middlewares, "runStarted", { + workflowRunId: workflowContext.workflowRunId, + }); + } + await routeFunction(workflowContext); + }, onCleanup: async (result) => { + await runMiddlewares(middlewares, "runCompleted", { + workflowRunId: workflowContext.workflowRunId, + }); await triggerWorkflowDelete(workflowContext, result, debug); }, onCancel: async () => { @@ -249,6 +262,24 @@ export const serveBase = < status: 500, }) as TResponse; } + + try { + runMiddlewares(middlewares, "onError", { + workflowRunId: request.headers.get(WORKFLOW_ID_HEADER) ?? "no-workflow-id", + error: error as Error, + }); + } catch (middlewareError) { + const formattedMiddlewareError = formatWorkflowError(middlewareError); + const errorMessage = + `Error while running middleware onError callback: '${formattedMiddlewareError.message}'.` + + `\nOriginal error: '${formattedError.message}'`; + + console.error(errorMessage); + return new Response(errorMessage, { + status: 500, + }) as TResponse; + } + return new Response(JSON.stringify(formattedError), { status: 500, }) as TResponse; diff --git a/src/serve/options.ts b/src/serve/options.ts index b016ff0..d5d27bc 100644 --- a/src/serve/options.ts +++ b/src/serve/options.ts @@ -30,6 +30,7 @@ export const processOptions = => { const environment = options?.env ?? (typeof process === "undefined" ? ({} as Record) : process.env); diff --git a/src/types.ts b/src/types.ts index 38e7c65..65892e5 100644 --- a/src/types.ts +++ b/src/types.ts @@ -4,6 +4,7 @@ import type { HTTPMethods } from "@upstash/qstash"; import type { WorkflowContext } from "./context"; import type { WorkflowLogger } from "./logger"; import { z } from "zod"; +import { WorkflowMiddleware } from "./middleware"; /** * Interface for Client with required methods @@ -256,6 +257,10 @@ export type WorkflowServeOptions< * and number of requests per second with the same key. */ flowControl?: FlowControl; + /** + * List of workflow middlewares to use + */ + middlewares?: WorkflowMiddleware[]; } & ValidationOptions; type ValidationOptions = { @@ -519,3 +524,16 @@ export type InvokableWorkflow = { options: WorkflowServeOptions; workflowId?: string; }; + +export type MiddlewareCallbacks = { + beforeExecution?: (params: { workflowRunId: string; stepName: string }) => Promise | void; + afterExecution?: (params: { workflowRunId: string; stepName: string }) => Promise | void; + runStarted?: (params: { workflowRunId: string }) => Promise | void; + runCompleted?: (params: { workflowRunId: string }) => Promise | void; + onError?: (params: { workflowRunId: string; error: Error }) => Promise | void; +}; + +export type MiddlewareParameters = { + name: string; + init: (params: { workflowRunId: string }) => Promise | MiddlewareCallbacks; +};