Skip to content

Commit ceb1854

Browse files
committed
feat(mastra): support shared state for local agents)
Signed-off-by: Tyler Slaton <[email protected]>
1 parent 6264c58 commit ceb1854

File tree

3 files changed

+2737
-206
lines changed

3 files changed

+2737
-206
lines changed

typescript-sdk/integrations/mastra/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
},
3232
"peerDependencies": {
3333
"@copilotkit/runtime": "^1.8.13",
34-
"@mastra/core": "^0.10.1",
34+
"@mastra/core": "^0.10.6",
3535
"zod": "^3.0.0"
3636
},
3737
"devDependencies": {

typescript-sdk/integrations/mastra/src/index.ts

Lines changed: 196 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import type {
77
RunAgentInput,
88
RunFinishedEvent,
99
RunStartedEvent,
10+
StateSnapshotEvent,
1011
TextMessageChunkEvent,
1112
ToolCall,
1213
ToolCallArgsEvent,
@@ -22,9 +23,9 @@ import {
2223
ExperimentalEmptyAdapter,
2324
} from "@copilotkit/runtime";
2425
import { processDataStream } from "@ai-sdk/ui-utils";
25-
import type { CoreMessage, Mastra } from "@mastra/core";
26+
import type { CoreMessage, Mastra, StorageThreadType } from "@mastra/core";
2627
import { registerApiRoute } from "@mastra/core/server";
27-
import type { Agent as LocalMastraAgent } from "@mastra/core/agent";
28+
import { Agent as LocalMastraAgent } from "@mastra/core/agent";
2829
import type { Context } from "hono";
2930
import { RuntimeContext } from "@mastra/core/runtime-context";
3031
import { randomUUID } from "crypto";
@@ -185,88 +186,148 @@ export class MastraAgent extends AbstractAgent {
185186
finalMessages.push(assistantMessage);
186187

187188
return new Observable<BaseEvent>((subscriber) => {
188-
subscriber.next({
189-
type: EventType.RUN_STARTED,
190-
threadId: input.threadId,
191-
runId: input.runId,
192-
} as RunStartedEvent);
193-
194-
this.streamMastraAgent(input, {
195-
onTextPart: (text) => {
196-
assistantMessage.content += text;
197-
const event: TextMessageChunkEvent = {
198-
type: EventType.TEXT_MESSAGE_CHUNK,
199-
role: "assistant",
200-
messageId,
201-
delta: text,
202-
};
203-
subscriber.next(event);
204-
},
205-
onToolCallPart: (streamPart) => {
206-
let toolCall: ToolCall = {
207-
id: streamPart.toolCallId,
208-
type: "function",
209-
function: {
210-
name: streamPart.toolName,
211-
arguments: JSON.stringify(streamPart.args),
189+
const run = async () => {
190+
subscriber.next({
191+
type: EventType.RUN_STARTED,
192+
threadId: input.threadId,
193+
runId: input.runId,
194+
} as RunStartedEvent);
195+
196+
// Handle local agent memory management (from Mastra implementation)
197+
if ('metrics' in this.agent) {
198+
const memory = this.agent.getMemory();
199+
200+
if (memory && input.state && Object.keys(input.state || {}).length > 0) {
201+
let thread: StorageThreadType | null = await memory.getThreadById({ threadId: input.threadId });
202+
203+
if (!thread) {
204+
thread = {
205+
id: input.threadId,
206+
title: '',
207+
metadata: {},
208+
resourceId: this.resourceId as string,
209+
createdAt: new Date(),
210+
updatedAt: new Date(),
211+
};
212+
}
213+
214+
if (thread.resourceId && thread.resourceId !== this.resourceId) {
215+
throw new Error(
216+
`Thread with id ${input.threadId} resourceId does not match the current resourceId ${this.resourceId}`,
217+
);
218+
}
219+
220+
const { messages, ...rest } = input.state;
221+
const workingMemory = JSON.stringify(rest);
222+
223+
// Update thread metadata with new working memory
224+
await memory.saveThread({
225+
thread: {
226+
...thread,
227+
metadata: {
228+
...thread.metadata,
229+
workingMemory,
230+
},
231+
},
232+
});
233+
}
234+
}
235+
236+
try {
237+
await this.streamMastraAgent(input, {
238+
onTextPart: (text) => {
239+
assistantMessage.content += text;
240+
const event: TextMessageChunkEvent = {
241+
type: EventType.TEXT_MESSAGE_CHUNK,
242+
role: "assistant",
243+
messageId,
244+
delta: text,
245+
};
246+
subscriber.next(event);
247+
},
248+
onToolCallPart: (streamPart) => {
249+
let toolCall: ToolCall = {
250+
id: streamPart.toolCallId,
251+
type: "function",
252+
function: {
253+
name: streamPart.toolName,
254+
arguments: JSON.stringify(streamPart.args),
255+
},
256+
};
257+
assistantMessage.toolCalls!.push(toolCall);
258+
259+
const startEvent: ToolCallStartEvent = {
260+
type: EventType.TOOL_CALL_START,
261+
parentMessageId: messageId,
262+
toolCallId: streamPart.toolCallId,
263+
toolCallName: streamPart.toolName,
264+
};
265+
subscriber.next(startEvent);
266+
267+
const argsEvent: ToolCallArgsEvent = {
268+
type: EventType.TOOL_CALL_ARGS,
269+
toolCallId: streamPart.toolCallId,
270+
delta: JSON.stringify(streamPart.args),
271+
};
272+
subscriber.next(argsEvent);
273+
274+
const endEvent: ToolCallEndEvent = {
275+
type: EventType.TOOL_CALL_END,
276+
toolCallId: streamPart.toolCallId,
277+
};
278+
subscriber.next(endEvent);
212279
},
213-
};
214-
assistantMessage.toolCalls!.push(toolCall);
215-
216-
const startEvent: ToolCallStartEvent = {
217-
type: EventType.TOOL_CALL_START,
218-
parentMessageId: messageId,
219-
toolCallId: streamPart.toolCallId,
220-
toolCallName: streamPart.toolName,
221-
};
222-
subscriber.next(startEvent);
223-
224-
const argsEvent: ToolCallArgsEvent = {
225-
type: EventType.TOOL_CALL_ARGS,
226-
toolCallId: streamPart.toolCallId,
227-
delta: JSON.stringify(streamPart.args),
228-
};
229-
subscriber.next(argsEvent);
230-
231-
const endEvent: ToolCallEndEvent = {
232-
type: EventType.TOOL_CALL_END,
233-
toolCallId: streamPart.toolCallId,
234-
};
235-
subscriber.next(endEvent);
236-
},
237-
onToolResultPart(streamPart) {
238-
const toolMessage: ToolMessage = {
239-
role: "tool",
240-
id: randomUUID(),
241-
toolCallId: streamPart.toolCallId,
242-
content: JSON.stringify(streamPart.result),
243-
};
244-
finalMessages.push(toolMessage);
245-
},
246-
onFinishMessagePart: () => {
247-
// Emit message snapshot
248-
const event: MessagesSnapshotEvent = {
249-
type: EventType.MESSAGES_SNAPSHOT,
250-
messages: finalMessages,
251-
};
252-
subscriber.next(event);
253-
254-
// Emit run finished event
255-
subscriber.next({
256-
type: EventType.RUN_FINISHED,
257-
threadId: input.threadId,
258-
runId: input.runId,
259-
} as RunFinishedEvent);
260-
261-
// Complete the observable
262-
subscriber.complete();
263-
},
264-
onError: (error) => {
265-
console.error("error", error);
266-
// Handle error
280+
onToolResultPart(streamPart) {
281+
const toolMessage: ToolMessage = {
282+
role: "tool",
283+
id: randomUUID(),
284+
toolCallId: streamPart.toolCallId,
285+
content: JSON.stringify(streamPart.result),
286+
};
287+
finalMessages.push(toolMessage);
288+
},
289+
onFinishMessagePart: async () => {
290+
// Emit message snapshot
291+
const event: MessagesSnapshotEvent = {
292+
type: EventType.MESSAGES_SNAPSHOT,
293+
messages: finalMessages,
294+
};
295+
subscriber.next(event);
296+
297+
if ('metrics' in this.agent){
298+
const memory = this.agent.getMemory();
299+
if (memory) {
300+
const workingMemory = await memory.getWorkingMemory({ threadId: input.threadId, format: 'json' });
301+
subscriber.next({
302+
type: EventType.STATE_SNAPSHOT,
303+
snapshot: workingMemory,
304+
} as StateSnapshotEvent);
305+
}
306+
}
307+
308+
// Emit run finished event
309+
subscriber.next({
310+
type: EventType.RUN_FINISHED,
311+
threadId: input.threadId,
312+
runId: input.runId,
313+
} as RunFinishedEvent);
314+
315+
// Complete the observable
316+
subscriber.complete();
317+
},
318+
onError: (error) => {
319+
console.error("error", error);
320+
// Handle error
321+
subscriber.error(error);
322+
},
323+
});
324+
} catch (error) {
325+
console.error("Stream error:", error);
267326
subscriber.error(error);
268-
},
269-
});
327+
}
328+
};
329+
330+
run();
270331

271332
return () => {};
272333
});
@@ -278,7 +339,7 @@ export class MastraAgent extends AbstractAgent {
278339
* @param options - The options for the mastra agent.
279340
* @returns The stream of the mastra agent.
280341
*/
281-
private streamMastraAgent(
342+
private async streamMastraAgent(
282343
{ threadId, runId, messages, tools }: RunAgentInput,
283344
{
284345
onTextPart,
@@ -287,7 +348,7 @@ export class MastraAgent extends AbstractAgent {
287348
onToolResultPart,
288349
onError,
289350
}: MastraAgentStreamOptions,
290-
) {
351+
): Promise<void> {
291352
const clientTools = tools.reduce(
292353
(acc, tool) => {
293354
acc[tool.name as string] = {
@@ -310,48 +371,74 @@ export class MastraAgent extends AbstractAgent {
310371
}
311372

312373
if (isLocalMastraAgent(this.agent)) {
313-
// in process agent
314-
return this.agent
315-
.stream(convertedMessages, {
374+
// Local agent - use the agent's stream method directly
375+
try {
376+
const response = await this.agent.stream(convertedMessages, {
316377
threadId,
317378
resourceId,
318379
runId,
319380
clientTools,
320381
runtimeContext,
321-
})
322-
.then((response) => {
323-
return processDataStream({
324-
stream: (response as any).toDataStreamResponse().body!,
325-
onTextPart,
326-
onToolCallPart,
327-
onToolResultPart,
328-
onFinishMessagePart,
329-
});
330-
})
331-
.catch((error) => {
332-
onError?.(error);
333382
});
383+
384+
// For local agents, the response should already be a stream
385+
// Process it using the agent's built-in streaming mechanism
386+
if (response && typeof response === 'object') {
387+
// If the response has a toDataStreamResponse method, use it
388+
if ('toDataStreamResponse' in response && typeof response.toDataStreamResponse === 'function') {
389+
const dataStreamResponse = response.toDataStreamResponse();
390+
if (dataStreamResponse && dataStreamResponse.body) {
391+
await processDataStream({
392+
stream: dataStreamResponse.body,
393+
onTextPart,
394+
onToolCallPart,
395+
onToolResultPart,
396+
onFinishMessagePart,
397+
});
398+
} else {
399+
throw new Error('Invalid data stream response from local agent');
400+
}
401+
} else {
402+
// If it's already a readable stream, process it directly
403+
await processDataStream({
404+
stream: response as any,
405+
onTextPart,
406+
onToolCallPart,
407+
onToolResultPart,
408+
onFinishMessagePart,
409+
});
410+
}
411+
} else {
412+
throw new Error('Invalid response from local agent');
413+
}
414+
} catch (error) {
415+
onError?.(error as Error);
416+
}
334417
} else {
335-
// remote agent
336-
return this.agent
337-
.stream({
418+
// Remote agent - use the remote agent's stream method
419+
try {
420+
const response = await this.agent.stream({
338421
threadId,
339422
resourceId,
340423
runId,
341424
messages: convertedMessages,
342425
clientTools,
343-
})
344-
.then((response) => {
345-
return response.processDataStream({
426+
});
427+
428+
// Remote agents should have a processDataStream method
429+
if (response && typeof response.processDataStream === 'function') {
430+
await response.processDataStream({
346431
onTextPart,
347432
onToolCallPart,
348433
onToolResultPart,
349434
onFinishMessagePart,
350435
});
351-
})
352-
.catch((error) => {
353-
onError?.(error);
354-
});
436+
} else {
437+
throw new Error('Invalid response from remote agent');
438+
}
439+
} catch (error) {
440+
onError?.(error as Error);
441+
}
355442
}
356443
}
357444
}

0 commit comments

Comments
 (0)