Skip to content

Commit eaad923

Browse files
Merge pull request #4 from Couchbase-Ecosystem/DA-1012-langgraph-checkpointer
Create Collection support in Async Class
2 parents b8c2b91 + 09a756c commit eaad923

File tree

4 files changed

+122
-53
lines changed

4 files changed

+122
-53
lines changed

README.md

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -68,58 +68,72 @@ with CouchbaseSaver.from_conn_info(
6868
bucket_name=os.getenv("CB_BUCKET") or "test",
6969
scope_name=os.getenv("CB_SCOPE") or "langgraph",
7070
) as checkpointer:
71-
# Create the agent with checkpointing
72-
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
73-
74-
# Configure with a unique thread ID
75-
config = {"configurable": {"thread_id": "1"}}
76-
77-
# Run the agent
78-
res = graph.invoke({"messages": [("human", "what's the weather in sf")]}, config)
79-
80-
# Retrieve checkpoints
81-
latest_checkpoint = checkpointer.get(config)
82-
latest_checkpoint_tuple = checkpointer.get_tuple(config)
83-
checkpoint_tuples = list(checkpointer.list(config))
84-
85-
print(latest_checkpoint)
86-
print(latest_checkpoint_tuple)
87-
print(checkpoint_tuples)
71+
# Create the agent with checkpointing
72+
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
73+
74+
# Configure with a unique thread ID
75+
config = {"configurable": {"thread_id": "1"}}
76+
77+
# Run the agent
78+
res = graph.invoke({"messages": [("human", "what's the weather in sf")]}, config)
79+
80+
# Retrieve checkpoints
81+
latest_checkpoint = checkpointer.get(config)
82+
latest_checkpoint_tuple = checkpointer.get_tuple(config)
83+
checkpoint_tuples = list(checkpointer.list(config))
84+
85+
print(latest_checkpoint)
86+
print(latest_checkpoint_tuple)
87+
print(checkpoint_tuples)
8888
```
8989

9090
### Asynchronous Usage
9191

9292
```python
9393
import os
94+
from acouchbase.cluster import Cluster as ACluster
95+
from couchbase.auth import PasswordAuthenticator
96+
from couchbase.options import ClusterOptions
9497
from langgraph_checkpointer_couchbase import AsyncCouchbaseSaver
9598
from langgraph.graph import create_react_agent
9699

97-
async with AsyncCouchbaseSaver.from_conn_info(
98-
cb_conn_str=os.getenv("CB_CLUSTER") or "couchbase://localhost",
99-
cb_username=os.getenv("CB_USERNAME") or "Administrator",
100-
cb_password=os.getenv("CB_PASSWORD") or "password",
101-
bucket_name=os.getenv("CB_BUCKET") or "test",
102-
scope_name=os.getenv("CB_SCOPE") or "langgraph",
100+
auth = PasswordAuthenticator(
101+
os.getenv("CB_USERNAME") or "Administrator",
102+
os.getenv("CB_PASSWORD") or "password",
103+
)
104+
options = ClusterOptions(auth)
105+
cluster = await ACluster.connect(os.getenv("CB_CLUSTER") or "couchbase://localhost", options)
106+
107+
bucket_name = os.getenv("CB_BUCKET") or "test"
108+
scope_name = os.getenv("CB_SCOPE") or "langgraph"
109+
110+
async with AsyncCouchbaseSaver.from_cluster(
111+
cluster=cluster,
112+
bucket_name=bucket_name,
113+
scope_name=scope_name,
103114
) as checkpointer:
104-
# Create the agent with checkpointing
105-
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
106-
107-
# Configure with a unique thread ID
108-
config = {"configurable": {"thread_id": "2"}}
109-
110-
# Run the agent asynchronously
111-
res = await graph.ainvoke(
112-
{"messages": [("human", "what's the weather in nyc")]}, config
113-
)
114-
115-
# Retrieve checkpoints asynchronously
116-
latest_checkpoint = await checkpointer.aget(config)
117-
latest_checkpoint_tuple = await checkpointer.aget_tuple(config)
118-
checkpoint_tuples = [c async for c in checkpointer.alist(config)]
119-
120-
print(latest_checkpoint)
121-
print(latest_checkpoint_tuple)
122-
print(checkpoint_tuples)
115+
# Create the agent with checkpointing
116+
graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)
117+
118+
# Configure with a unique thread ID
119+
config = {"configurable": {"thread_id": "2"}}
120+
121+
# Run the agent asynchronously
122+
res = await graph.ainvoke(
123+
{"messages": [("human", "what's the weather in nyc")]}, config
124+
)
125+
126+
# Retrieve checkpoints asynchronously
127+
latest_checkpoint = await checkpointer.aget(config)
128+
latest_checkpoint_tuple = await checkpointer.aget_tuple(config)
129+
checkpoint_tuples = [c async for c in checkpointer.alist(config)]
130+
131+
print(latest_checkpoint)
132+
print(latest_checkpoint_tuple)
133+
print(checkpoint_tuples)
134+
135+
# Close the cluster when done
136+
await cluster.close()
123137
```
124138

125139
## Configuration Options

langgraph_checkpointer_couchbase/async_cb_saver.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from contextlib import asynccontextmanager
22
from datetime import timedelta
33
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple
4+
import logging
45

56
from langchain_core.runnables import RunnableConfig
67
from acouchbase.cluster import Cluster as ACluster
78
from acouchbase.bucket import Bucket as ABucket
89
from couchbase.auth import PasswordAuthenticator
910
from couchbase.options import ClusterOptions, QueryOptions, UpsertOptions
11+
from couchbase.exceptions import CollectionAlreadyExistsException
1012

1113
from langgraph.checkpoint.base import (
1214
BaseCheckpointSaver,
@@ -18,6 +20,8 @@
1820
)
1921
from .utils import _encode_binary, _decode_binary
2022

23+
logger = logging.getLogger(__name__)
24+
2125
class AsyncCouchbaseSaver(BaseCheckpointSaver):
2226
"""A checkpoint saver that stores checkpoints in a Couchbase database."""
2327

@@ -35,9 +39,35 @@ def __init__(
3539
self.cluster = cluster
3640
self.bucket_name = bucket_name
3741
self.scope_name = scope_name
42+
self.bucket = self.cluster.bucket(bucket_name)
43+
self.scope = self.bucket.scope(scope_name)
3844
self.checkpoints_collection_name = checkpoints_collection_name
3945
self.checkpoint_writes_collection_name = checkpoint_writes_collection_name
4046

47+
async def create_collections(self):
48+
"""Create collections in the Couchbase bucket if they do not exist."""
49+
50+
collection_manager = self.bucket.collections()
51+
try:
52+
await collection_manager.create_collection(self.scope_name, self.checkpoints_collection_name)
53+
except CollectionAlreadyExistsException as _:
54+
pass
55+
except Exception as e:
56+
logger.exception("Error creating collections")
57+
raise e
58+
finally:
59+
self.checkpoints_collection = self.bucket.scope(self.scope_name).collection(self.checkpoints_collection_name)
60+
61+
try:
62+
await collection_manager.create_collection(self.scope_name, self.checkpoint_writes_collection_name)
63+
except CollectionAlreadyExistsException as _:
64+
pass
65+
except Exception as e:
66+
logger.exception("Error creating collections")
67+
raise e
68+
finally:
69+
self.checkpoint_writes_collection = self.bucket.scope(self.scope_name).collection(self.checkpoint_writes_collection_name)
70+
4171
@classmethod
4272
@asynccontextmanager
4373
async def from_conn_info(
@@ -69,15 +99,25 @@ async def from_conn_info(
6999
cls.bucket_name = bucket_name
70100
cls.scope_name = scope_name
71101

72-
saver = AsyncCouchbaseSaver(cluster, bucket_name, scope_name, checkpoints_collection_name, checkpoint_writes_collection_name)
73-
cls.bucket = cluster.bucket(bucket_name)
74-
await cls.bucket.on_connect()
102+
bucket = cluster.bucket(bucket_name)
103+
await bucket.on_connect()
104+
105+
saver = AsyncCouchbaseSaver(
106+
cluster,
107+
bucket_name,
108+
scope_name,
109+
checkpoints_collection_name,
110+
checkpoint_writes_collection_name,
111+
)
112+
113+
await saver.create_collections()
75114

76115
yield saver
77116
finally:
78117
if cluster:
79118
await cluster.close()
80119

120+
81121
@classmethod
82122
@asynccontextmanager
83123
async def from_cluster(
@@ -98,9 +138,18 @@ async def from_cluster(
98138
AsyncCouchbaseSaver: An instance of the AsyncCouchbaseSaver
99139
"""
100140

101-
saver = AsyncCouchbaseSaver(cluster, bucket_name, scope_name, checkpoints_collection_name, checkpoint_writes_collection_name)
102-
cls.bucket = cluster.bucket(bucket_name)
103-
await cls.bucket.on_connect()
141+
bucket = cluster.bucket(bucket_name)
142+
await bucket.on_connect()
143+
144+
saver = AsyncCouchbaseSaver(
145+
cluster,
146+
bucket_name,
147+
scope_name,
148+
checkpoints_collection_name,
149+
checkpoint_writes_collection_name,
150+
)
151+
152+
await saver.create_collections()
104153

105154
yield saver
106155

@@ -149,7 +198,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
149198
async for write_doc in serialized_writes_result:
150199
checkpoint_writes = write_doc.get(self.checkpoint_writes_collection_name, {})
151200
if "task_id" not in checkpoint_writes:
152-
print("Error: 'task_id' is not present in checkpoint_writes")
201+
logger.warning("'task_id' is not present in checkpoint_writes")
153202
else:
154203
pending_writes.append(
155204
(
@@ -294,6 +343,8 @@ async def aput(
294343

295344
upsert_key = f"{thread_id}::{checkpoint_ns}::{checkpoint_id}"
296345

346+
# ensure bucket connected (idempotent)
347+
await self.bucket.on_connect()
297348
collection = self.bucket.scope(self.scope_name).collection(self.checkpoints_collection_name)
298349
await collection.upsert(upsert_key, (doc), UpsertOptions(timeout=timedelta(seconds=5)))
299350

@@ -324,6 +375,7 @@ async def aput_writes(
324375
checkpoint_ns = config["configurable"]["checkpoint_ns"]
325376
checkpoint_id = config["configurable"]["checkpoint_id"]
326377

378+
await self.bucket.on_connect()
327379
collection = self.bucket.scope(self.scope_name).collection(self.checkpoint_writes_collection_name)
328380

329381
for idx, (channel, value) in enumerate(writes):

langgraph_checkpointer_couchbase/couchbase_saver.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from contextlib import contextmanager
22
from datetime import timedelta
33
from typing import Any, Dict, Iterator, Optional, Sequence, Tuple
4+
import logging
45

56
from langchain_core.runnables import RunnableConfig
67
from couchbase.cluster import Cluster
@@ -19,6 +20,8 @@
1920
)
2021

2122
from .utils import _encode_binary, _decode_binary
23+
24+
logger = logging.getLogger(__name__)
2225
class CouchbaseSaver(BaseCheckpointSaver):
2326
"""A checkpoint saver that stores checkpoints in a Couchbase database.
2427
@@ -118,7 +121,7 @@ def create_collections(self):
118121
except CollectionAlreadyExistsException as _:
119122
pass
120123
except Exception as e:
121-
print(f"Error creating collections: {e}")
124+
logger.exception("Error creating collections")
122125
raise e
123126
finally:
124127
self.checkpoints_collection = self.bucket.scope(self.scope_name).collection(self.checkpoints_collection_name)
@@ -128,7 +131,7 @@ def create_collections(self):
128131
except CollectionAlreadyExistsException as _:
129132
pass
130133
except Exception as e:
131-
print(f"Error creating collections: {e}")
134+
logger.exception("Error creating collections")
132135
raise e
133136
finally:
134137
self.checkpoint_writes_collection = self.bucket.scope(self.scope_name).collection(self.checkpoint_writes_collection_name)
@@ -181,7 +184,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
181184
for write_doc in serialized_writes_result:
182185
checkpoint_writes = write_doc.get(self.checkpoint_writes_collection_name, {})
183186
if "task_id" not in checkpoint_writes:
184-
print("Error: 'task_id' is not present in checkpoint_writes")
187+
logger.warning("'task_id' is not present in checkpoint_writes")
185188
else:
186189
# Decode and deserialize value data
187190
value_data = _decode_binary(checkpoint_writes["value"])

tests/agent_e2e_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from langchain_core.tools import tool
55
from langchain_openai import ChatOpenAI
66
from langgraph.prebuilt import create_react_agent
7-
from langgraph_checkpoint_couchbase import CouchbaseSaver, AsyncCouchbaseSaver
7+
from langgraph_checkpointer_couchbase import CouchbaseSaver, AsyncCouchbaseSaver
88
from dotenv import load_dotenv
99
import os
1010
load_dotenv()

0 commit comments

Comments
 (0)