1
1
from contextlib import asynccontextmanager
2
2
from datetime import timedelta
3
3
from typing import Any , AsyncIterator , Dict , Optional , Sequence , Tuple
4
+ import logging
4
5
5
6
from langchain_core .runnables import RunnableConfig
6
7
from acouchbase .cluster import Cluster as ACluster
7
8
from acouchbase .bucket import Bucket as ABucket
8
9
from couchbase .auth import PasswordAuthenticator
9
10
from couchbase .options import ClusterOptions , QueryOptions , UpsertOptions
11
+ from couchbase .exceptions import CollectionAlreadyExistsException
10
12
11
13
from langgraph .checkpoint .base import (
12
14
BaseCheckpointSaver ,
18
20
)
19
21
from .utils import _encode_binary , _decode_binary
20
22
23
+ logger = logging .getLogger (__name__ )
24
+
21
25
class AsyncCouchbaseSaver (BaseCheckpointSaver ):
22
26
"""A checkpoint saver that stores checkpoints in a Couchbase database."""
23
27
@@ -35,9 +39,35 @@ def __init__(
35
39
self .cluster = cluster
36
40
self .bucket_name = bucket_name
37
41
self .scope_name = scope_name
42
+ self .bucket = self .cluster .bucket (bucket_name )
43
+ self .scope = self .bucket .scope (scope_name )
38
44
self .checkpoints_collection_name = checkpoints_collection_name
39
45
self .checkpoint_writes_collection_name = checkpoint_writes_collection_name
40
46
47
+ async def create_collections (self ):
48
+ """Create collections in the Couchbase bucket if they do not exist."""
49
+
50
+ collection_manager = self .bucket .collections ()
51
+ try :
52
+ await collection_manager .create_collection (self .scope_name , self .checkpoints_collection_name )
53
+ except CollectionAlreadyExistsException as _ :
54
+ pass
55
+ except Exception as e :
56
+ logger .exception ("Error creating collections" )
57
+ raise e
58
+ finally :
59
+ self .checkpoints_collection = self .bucket .scope (self .scope_name ).collection (self .checkpoints_collection_name )
60
+
61
+ try :
62
+ await collection_manager .create_collection (self .scope_name , self .checkpoint_writes_collection_name )
63
+ except CollectionAlreadyExistsException as _ :
64
+ pass
65
+ except Exception as e :
66
+ logger .exception ("Error creating collections" )
67
+ raise e
68
+ finally :
69
+ self .checkpoint_writes_collection = self .bucket .scope (self .scope_name ).collection (self .checkpoint_writes_collection_name )
70
+
41
71
@classmethod
42
72
@asynccontextmanager
43
73
async def from_conn_info (
@@ -69,15 +99,25 @@ async def from_conn_info(
69
99
cls .bucket_name = bucket_name
70
100
cls .scope_name = scope_name
71
101
72
- saver = AsyncCouchbaseSaver (cluster , bucket_name , scope_name , checkpoints_collection_name , checkpoint_writes_collection_name )
73
- cls .bucket = cluster .bucket (bucket_name )
74
- await cls .bucket .on_connect ()
102
+ bucket = cluster .bucket (bucket_name )
103
+ await bucket .on_connect ()
104
+
105
+ saver = AsyncCouchbaseSaver (
106
+ cluster ,
107
+ bucket_name ,
108
+ scope_name ,
109
+ checkpoints_collection_name ,
110
+ checkpoint_writes_collection_name ,
111
+ )
112
+
113
+ await saver .create_collections ()
75
114
76
115
yield saver
77
116
finally :
78
117
if cluster :
79
118
await cluster .close ()
80
119
120
+
81
121
@classmethod
82
122
@asynccontextmanager
83
123
async def from_cluster (
@@ -98,9 +138,18 @@ async def from_cluster(
98
138
AsyncCouchbaseSaver: An instance of the AsyncCouchbaseSaver
99
139
"""
100
140
101
- saver = AsyncCouchbaseSaver (cluster , bucket_name , scope_name , checkpoints_collection_name , checkpoint_writes_collection_name )
102
- cls .bucket = cluster .bucket (bucket_name )
103
- await cls .bucket .on_connect ()
141
+ bucket = cluster .bucket (bucket_name )
142
+ await bucket .on_connect ()
143
+
144
+ saver = AsyncCouchbaseSaver (
145
+ cluster ,
146
+ bucket_name ,
147
+ scope_name ,
148
+ checkpoints_collection_name ,
149
+ checkpoint_writes_collection_name ,
150
+ )
151
+
152
+ await saver .create_collections ()
104
153
105
154
yield saver
106
155
@@ -149,7 +198,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
149
198
async for write_doc in serialized_writes_result :
150
199
checkpoint_writes = write_doc .get (self .checkpoint_writes_collection_name , {})
151
200
if "task_id" not in checkpoint_writes :
152
- print ( "Error: 'task_id' is not present in checkpoint_writes" )
201
+ logger . warning ( " 'task_id' is not present in checkpoint_writes" )
153
202
else :
154
203
pending_writes .append (
155
204
(
@@ -294,6 +343,8 @@ async def aput(
294
343
295
344
upsert_key = f"{ thread_id } ::{ checkpoint_ns } ::{ checkpoint_id } "
296
345
346
+ # ensure bucket connected (idempotent)
347
+ await self .bucket .on_connect ()
297
348
collection = self .bucket .scope (self .scope_name ).collection (self .checkpoints_collection_name )
298
349
await collection .upsert (upsert_key , (doc ), UpsertOptions (timeout = timedelta (seconds = 5 )))
299
350
@@ -324,6 +375,7 @@ async def aput_writes(
324
375
checkpoint_ns = config ["configurable" ]["checkpoint_ns" ]
325
376
checkpoint_id = config ["configurable" ]["checkpoint_id" ]
326
377
378
+ await self .bucket .on_connect ()
327
379
collection = self .bucket .scope (self .scope_name ).collection (self .checkpoint_writes_collection_name )
328
380
329
381
for idx , (channel , value ) in enumerate (writes ):
0 commit comments