@@ -7,6 +7,7 @@ import type {
7
7
RunAgentInput ,
8
8
RunFinishedEvent ,
9
9
RunStartedEvent ,
10
+ StateSnapshotEvent ,
10
11
TextMessageChunkEvent ,
11
12
ToolCall ,
12
13
ToolCallArgsEvent ,
@@ -22,9 +23,9 @@ import {
22
23
ExperimentalEmptyAdapter ,
23
24
} from "@copilotkit/runtime" ;
24
25
import { processDataStream } from "@ai-sdk/ui-utils" ;
25
- import type { CoreMessage , Mastra } from "@mastra/core" ;
26
+ import type { CoreMessage , Mastra , StorageThreadType } from "@mastra/core" ;
26
27
import { registerApiRoute } from "@mastra/core/server" ;
27
- import type { Agent as LocalMastraAgent } from "@mastra/core/agent" ;
28
+ import { Agent as LocalMastraAgent } from "@mastra/core/agent" ;
28
29
import type { Context } from "hono" ;
29
30
import { RuntimeContext } from "@mastra/core/runtime-context" ;
30
31
import { randomUUID } from "crypto" ;
@@ -185,88 +186,148 @@ export class MastraAgent extends AbstractAgent {
185
186
finalMessages . push ( assistantMessage ) ;
186
187
187
188
return new Observable < BaseEvent > ( ( subscriber ) => {
188
- subscriber . next ( {
189
- type : EventType . RUN_STARTED ,
190
- threadId : input . threadId ,
191
- runId : input . runId ,
192
- } as RunStartedEvent ) ;
193
-
194
- this . streamMastraAgent ( input , {
195
- onTextPart : ( text ) => {
196
- assistantMessage . content += text ;
197
- const event : TextMessageChunkEvent = {
198
- type : EventType . TEXT_MESSAGE_CHUNK ,
199
- role : "assistant" ,
200
- messageId,
201
- delta : text ,
202
- } ;
203
- subscriber . next ( event ) ;
204
- } ,
205
- onToolCallPart : ( streamPart ) => {
206
- let toolCall : ToolCall = {
207
- id : streamPart . toolCallId ,
208
- type : "function" ,
209
- function : {
210
- name : streamPart . toolName ,
211
- arguments : JSON . stringify ( streamPart . args ) ,
189
+ const run = async ( ) => {
190
+ subscriber . next ( {
191
+ type : EventType . RUN_STARTED ,
192
+ threadId : input . threadId ,
193
+ runId : input . runId ,
194
+ } as RunStartedEvent ) ;
195
+
196
+ // Handle local agent memory management (from Mastra implementation)
197
+ if ( 'metrics' in this . agent ) {
198
+ const memory = this . agent . getMemory ( ) ;
199
+
200
+ if ( memory && input . state && Object . keys ( input . state || { } ) . length > 0 ) {
201
+ let thread : StorageThreadType | null = await memory . getThreadById ( { threadId : input . threadId } ) ;
202
+
203
+ if ( ! thread ) {
204
+ thread = {
205
+ id : input . threadId ,
206
+ title : '' ,
207
+ metadata : { } ,
208
+ resourceId : this . resourceId as string ,
209
+ createdAt : new Date ( ) ,
210
+ updatedAt : new Date ( ) ,
211
+ } ;
212
+ }
213
+
214
+ if ( thread . resourceId && thread . resourceId !== this . resourceId ) {
215
+ throw new Error (
216
+ `Thread with id ${ input . threadId } resourceId does not match the current resourceId ${ this . resourceId } ` ,
217
+ ) ;
218
+ }
219
+
220
+ const { messages, ...rest } = input . state ;
221
+ const workingMemory = JSON . stringify ( rest ) ;
222
+
223
+ // Update thread metadata with new working memory
224
+ await memory . saveThread ( {
225
+ thread : {
226
+ ...thread ,
227
+ metadata : {
228
+ ...thread . metadata ,
229
+ workingMemory,
230
+ } ,
231
+ } ,
232
+ } ) ;
233
+ }
234
+ }
235
+
236
+ try {
237
+ await this . streamMastraAgent ( input , {
238
+ onTextPart : ( text ) => {
239
+ assistantMessage . content += text ;
240
+ const event : TextMessageChunkEvent = {
241
+ type : EventType . TEXT_MESSAGE_CHUNK ,
242
+ role : "assistant" ,
243
+ messageId,
244
+ delta : text ,
245
+ } ;
246
+ subscriber . next ( event ) ;
247
+ } ,
248
+ onToolCallPart : ( streamPart ) => {
249
+ let toolCall : ToolCall = {
250
+ id : streamPart . toolCallId ,
251
+ type : "function" ,
252
+ function : {
253
+ name : streamPart . toolName ,
254
+ arguments : JSON . stringify ( streamPart . args ) ,
255
+ } ,
256
+ } ;
257
+ assistantMessage . toolCalls ! . push ( toolCall ) ;
258
+
259
+ const startEvent : ToolCallStartEvent = {
260
+ type : EventType . TOOL_CALL_START ,
261
+ parentMessageId : messageId ,
262
+ toolCallId : streamPart . toolCallId ,
263
+ toolCallName : streamPart . toolName ,
264
+ } ;
265
+ subscriber . next ( startEvent ) ;
266
+
267
+ const argsEvent : ToolCallArgsEvent = {
268
+ type : EventType . TOOL_CALL_ARGS ,
269
+ toolCallId : streamPart . toolCallId ,
270
+ delta : JSON . stringify ( streamPart . args ) ,
271
+ } ;
272
+ subscriber . next ( argsEvent ) ;
273
+
274
+ const endEvent : ToolCallEndEvent = {
275
+ type : EventType . TOOL_CALL_END ,
276
+ toolCallId : streamPart . toolCallId ,
277
+ } ;
278
+ subscriber . next ( endEvent ) ;
212
279
} ,
213
- } ;
214
- assistantMessage . toolCalls ! . push ( toolCall ) ;
215
-
216
- const startEvent : ToolCallStartEvent = {
217
- type : EventType . TOOL_CALL_START ,
218
- parentMessageId : messageId ,
219
- toolCallId : streamPart . toolCallId ,
220
- toolCallName : streamPart . toolName ,
221
- } ;
222
- subscriber . next ( startEvent ) ;
223
-
224
- const argsEvent : ToolCallArgsEvent = {
225
- type : EventType . TOOL_CALL_ARGS ,
226
- toolCallId : streamPart . toolCallId ,
227
- delta : JSON . stringify ( streamPart . args ) ,
228
- } ;
229
- subscriber . next ( argsEvent ) ;
230
-
231
- const endEvent : ToolCallEndEvent = {
232
- type : EventType . TOOL_CALL_END ,
233
- toolCallId : streamPart . toolCallId ,
234
- } ;
235
- subscriber . next ( endEvent ) ;
236
- } ,
237
- onToolResultPart ( streamPart ) {
238
- const toolMessage : ToolMessage = {
239
- role : "tool" ,
240
- id : randomUUID ( ) ,
241
- toolCallId : streamPart . toolCallId ,
242
- content : JSON . stringify ( streamPart . result ) ,
243
- } ;
244
- finalMessages . push ( toolMessage ) ;
245
- } ,
246
- onFinishMessagePart : ( ) => {
247
- // Emit message snapshot
248
- const event : MessagesSnapshotEvent = {
249
- type : EventType . MESSAGES_SNAPSHOT ,
250
- messages : finalMessages ,
251
- } ;
252
- subscriber . next ( event ) ;
253
-
254
- // Emit run finished event
255
- subscriber . next ( {
256
- type : EventType . RUN_FINISHED ,
257
- threadId : input . threadId ,
258
- runId : input . runId ,
259
- } as RunFinishedEvent ) ;
260
-
261
- // Complete the observable
262
- subscriber . complete ( ) ;
263
- } ,
264
- onError : ( error ) => {
265
- console . error ( "error" , error ) ;
266
- // Handle error
280
+ onToolResultPart ( streamPart ) {
281
+ const toolMessage : ToolMessage = {
282
+ role : "tool" ,
283
+ id : randomUUID ( ) ,
284
+ toolCallId : streamPart . toolCallId ,
285
+ content : JSON . stringify ( streamPart . result ) ,
286
+ } ;
287
+ finalMessages . push ( toolMessage ) ;
288
+ } ,
289
+ onFinishMessagePart : async ( ) => {
290
+ // Emit message snapshot
291
+ const event : MessagesSnapshotEvent = {
292
+ type : EventType . MESSAGES_SNAPSHOT ,
293
+ messages : finalMessages ,
294
+ } ;
295
+ subscriber . next ( event ) ;
296
+
297
+ if ( 'metrics' in this . agent ) {
298
+ const memory = this . agent . getMemory ( ) ;
299
+ if ( memory ) {
300
+ const workingMemory = await memory . getWorkingMemory ( { threadId : input . threadId , format : 'json' } ) ;
301
+ subscriber . next ( {
302
+ type : EventType . STATE_SNAPSHOT ,
303
+ snapshot : workingMemory ,
304
+ } as StateSnapshotEvent ) ;
305
+ }
306
+ }
307
+
308
+ // Emit run finished event
309
+ subscriber . next ( {
310
+ type : EventType . RUN_FINISHED ,
311
+ threadId : input . threadId ,
312
+ runId : input . runId ,
313
+ } as RunFinishedEvent ) ;
314
+
315
+ // Complete the observable
316
+ subscriber . complete ( ) ;
317
+ } ,
318
+ onError : ( error ) => {
319
+ console . error ( "error" , error ) ;
320
+ // Handle error
321
+ subscriber . error ( error ) ;
322
+ } ,
323
+ } ) ;
324
+ } catch ( error ) {
325
+ console . error ( "Stream error:" , error ) ;
267
326
subscriber . error ( error ) ;
268
- } ,
269
- } ) ;
327
+ }
328
+ } ;
329
+
330
+ run ( ) ;
270
331
271
332
return ( ) => { } ;
272
333
} ) ;
@@ -278,7 +339,7 @@ export class MastraAgent extends AbstractAgent {
278
339
* @param options - The options for the mastra agent.
279
340
* @returns The stream of the mastra agent.
280
341
*/
281
- private streamMastraAgent (
342
+ private async streamMastraAgent (
282
343
{ threadId, runId, messages, tools } : RunAgentInput ,
283
344
{
284
345
onTextPart,
@@ -287,7 +348,7 @@ export class MastraAgent extends AbstractAgent {
287
348
onToolResultPart,
288
349
onError,
289
350
} : MastraAgentStreamOptions ,
290
- ) {
351
+ ) : Promise < void > {
291
352
const clientTools = tools . reduce (
292
353
( acc , tool ) => {
293
354
acc [ tool . name as string ] = {
@@ -310,48 +371,74 @@ export class MastraAgent extends AbstractAgent {
310
371
}
311
372
312
373
if ( isLocalMastraAgent ( this . agent ) ) {
313
- // in process agent
314
- return this . agent
315
- . stream ( convertedMessages , {
374
+ // Local agent - use the agent's stream method directly
375
+ try {
376
+ const response = await this . agent . stream ( convertedMessages , {
316
377
threadId,
317
378
resourceId,
318
379
runId,
319
380
clientTools,
320
381
runtimeContext,
321
- } )
322
- . then ( ( response ) => {
323
- return processDataStream ( {
324
- stream : ( response as any ) . toDataStreamResponse ( ) . body ! ,
325
- onTextPart,
326
- onToolCallPart,
327
- onToolResultPart,
328
- onFinishMessagePart,
329
- } ) ;
330
- } )
331
- . catch ( ( error ) => {
332
- onError ?.( error ) ;
333
382
} ) ;
383
+
384
+ // For local agents, the response should already be a stream
385
+ // Process it using the agent's built-in streaming mechanism
386
+ if ( response && typeof response === 'object' ) {
387
+ // If the response has a toDataStreamResponse method, use it
388
+ if ( 'toDataStreamResponse' in response && typeof response . toDataStreamResponse === 'function' ) {
389
+ const dataStreamResponse = response . toDataStreamResponse ( ) ;
390
+ if ( dataStreamResponse && dataStreamResponse . body ) {
391
+ await processDataStream ( {
392
+ stream : dataStreamResponse . body ,
393
+ onTextPart,
394
+ onToolCallPart,
395
+ onToolResultPart,
396
+ onFinishMessagePart,
397
+ } ) ;
398
+ } else {
399
+ throw new Error ( 'Invalid data stream response from local agent' ) ;
400
+ }
401
+ } else {
402
+ // If it's already a readable stream, process it directly
403
+ await processDataStream ( {
404
+ stream : response as any ,
405
+ onTextPart,
406
+ onToolCallPart,
407
+ onToolResultPart,
408
+ onFinishMessagePart,
409
+ } ) ;
410
+ }
411
+ } else {
412
+ throw new Error ( 'Invalid response from local agent' ) ;
413
+ }
414
+ } catch ( error ) {
415
+ onError ?.( error as Error ) ;
416
+ }
334
417
} else {
335
- // remote agent
336
- return this . agent
337
- . stream ( {
418
+ // Remote agent - use the remote agent's stream method
419
+ try {
420
+ const response = await this . agent . stream ( {
338
421
threadId,
339
422
resourceId,
340
423
runId,
341
424
messages : convertedMessages ,
342
425
clientTools,
343
- } )
344
- . then ( ( response ) => {
345
- return response . processDataStream ( {
426
+ } ) ;
427
+
428
+ // Remote agents should have a processDataStream method
429
+ if ( response && typeof response . processDataStream === 'function' ) {
430
+ await response . processDataStream ( {
346
431
onTextPart,
347
432
onToolCallPart,
348
433
onToolResultPart,
349
434
onFinishMessagePart,
350
435
} ) ;
351
- } )
352
- . catch ( ( error ) => {
353
- onError ?.( error ) ;
354
- } ) ;
436
+ } else {
437
+ throw new Error ( 'Invalid response from remote agent' ) ;
438
+ }
439
+ } catch ( error ) {
440
+ onError ?.( error as Error ) ;
441
+ }
355
442
}
356
443
}
357
444
}
0 commit comments