Skip to content

Commit c84a217

Browse files
authored
Add unit tests for async SR client and async serdes (#1994)
* Add unit tests for async SR client and async serdes * fix flake8 issues * fix test_delivery_report_serialization * run unasync
1 parent 4cc24f3 commit c84a217

32 files changed

+6223
-133
lines changed

src/confluent_kafka/schema_registry/_async/avro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ async def __deserialize(
558558
latest_schema = await self._get_reader_schema(subject)
559559

560560
if latest_schema is not None:
561-
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
561+
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
562562
reader_schema_raw = latest_schema.schema
563563
reader_schema = await self._get_parsed_schema(latest_schema.schema)
564564
elif self._schema is not None:

src/confluent_kafka/schema_registry/_async/json_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N
598598
writer_schema, writer_ref_registry = None, None
599599

600600
if latest_schema is not None:
601-
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
601+
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
602602
reader_schema_raw = latest_schema.schema
603603
reader_schema, reader_ref_registry = await self._get_parsed_schema(latest_schema.schema)
604604
elif self._schema is not None:
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2024 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
import uuid
19+
from collections import defaultdict
20+
from threading import Lock
21+
from typing import List, Dict, Optional
22+
23+
from .schema_registry_client import AsyncSchemaRegistryClient
24+
from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig
25+
from ..error import SchemaRegistryError
26+
27+
28+
class _SchemaStore(object):
29+
30+
def __init__(self):
31+
self.lock = Lock()
32+
self.max_id = 0
33+
self.schema_id_index = {}
34+
self.schema_guid_index = {}
35+
self.schema_index = {}
36+
self.subject_schemas = defaultdict(set)
37+
38+
def set(self, registered_schema: RegisteredSchema) -> RegisteredSchema:
39+
with self.lock:
40+
self.max_id += 1
41+
rs = RegisteredSchema(
42+
schema_id=self.max_id,
43+
guid=registered_schema.guid,
44+
schema=registered_schema.schema,
45+
subject=registered_schema.subject,
46+
version=registered_schema.version
47+
)
48+
self.schema_id_index[rs.schema_id] = rs
49+
self.schema_guid_index[rs.guid] = rs
50+
self.schema_index[rs.schema] = rs.schema_id
51+
self.subject_schemas[rs.subject].add(rs)
52+
return rs
53+
54+
def get_schema(self, schema_id: int) -> Optional[Schema]:
55+
with self.lock:
56+
rs = self.schema_id_index.get(schema_id, None)
57+
return rs.schema if rs else None
58+
59+
def get_schema_by_guid(self, guid: str) -> Optional[Schema]:
60+
with self.lock:
61+
rs = self.schema_guid_index.get(guid, None)
62+
return rs.schema if rs else None
63+
64+
def get_registered_schema_by_schema(
65+
self,
66+
subject_name: str,
67+
schema: Schema
68+
) -> Optional[RegisteredSchema]:
69+
with self.lock:
70+
if subject_name in self.subject_schemas:
71+
for rs in self.subject_schemas[subject_name]:
72+
if rs.schema == schema:
73+
return rs
74+
return None
75+
76+
def get_version(self, subject_name: str, version: int) -> Optional[RegisteredSchema]:
77+
with self.lock:
78+
if subject_name in self.subject_schemas:
79+
for rs in self.subject_schemas[subject_name]:
80+
if rs.version == version:
81+
return rs
82+
return None
83+
84+
def get_latest_version(self, subject_name: str) -> Optional[RegisteredSchema]:
85+
with self.lock:
86+
if subject_name in self.subject_schemas:
87+
latest_version = 0
88+
latest_schema = None
89+
for rs in self.subject_schemas[subject_name]:
90+
if rs.version > latest_version:
91+
latest_version = rs.version
92+
latest_schema = rs
93+
return latest_schema
94+
return None
95+
96+
def get_latest_with_metadata(
97+
self, subject_name: str,
98+
metadata: Dict[str, str]
99+
) -> Optional[RegisteredSchema]:
100+
with self.lock:
101+
if subject_name in self.subject_schemas:
102+
rs: RegisteredSchema
103+
for rs in self.subject_schemas[subject_name]:
104+
if (rs.schema
105+
and rs.schema.metadata
106+
and rs.schema.metadata.properties
107+
and metadata.items() <= rs.schema.metadata.properties.properties.items()):
108+
return rs
109+
return None
110+
111+
def get_subjects(self) -> List[str]:
112+
with self.lock:
113+
return list(self.subject_schemas.keys())
114+
115+
def get_versions(self, subject_name: str) -> List[int]:
116+
with self.lock:
117+
if subject_name in self.subject_schemas:
118+
return [rs.version for rs in self.subject_schemas[subject_name]]
119+
return []
120+
121+
def remove_by_schema(self, registered_schema: RegisteredSchema):
122+
with self.lock:
123+
subject_name = registered_schema.subject
124+
if subject_name in self.subject_schemas:
125+
self.subject_schemas[subject_name].remove(registered_schema)
126+
127+
def remove_by_subject(self, subject_name: str) -> List[int]:
128+
with self.lock:
129+
versions = []
130+
if subject_name in self.subject_schemas:
131+
for rs in self.subject_schemas[subject_name]:
132+
versions.append(rs.version)
133+
schema_id = self.schema_index.pop(rs.schema, None)
134+
if schema_id is not None:
135+
self.schema_id_index.pop(schema_id, None)
136+
137+
del self.subject_schemas[subject_name]
138+
return versions
139+
140+
def clear(self):
141+
with self.lock:
142+
self.schema_id_index.clear()
143+
self.schema_guid_index.clear()
144+
self.schema_index.clear()
145+
self.subject_schemas.clear()
146+
147+
148+
class AsyncMockSchemaRegistryClient(AsyncSchemaRegistryClient):
149+
150+
def __init__(self, conf: dict):
151+
super().__init__(conf)
152+
self._store = _SchemaStore()
153+
154+
async def register_schema(
155+
self, subject_name: str, schema: 'Schema',
156+
normalize_schemas: bool = False
157+
) -> int:
158+
registered_schema = await self.register_schema_full_response(subject_name, schema, normalize_schemas)
159+
return registered_schema.schema_id
160+
161+
async def register_schema_full_response(
162+
self, subject_name: str, schema: 'Schema',
163+
normalize_schemas: bool = False
164+
) -> 'RegisteredSchema':
165+
registered_schema = self._store.get_registered_schema_by_schema(subject_name, schema)
166+
if registered_schema is not None:
167+
return registered_schema
168+
169+
latest_schema = self._store.get_latest_version(subject_name)
170+
latest_version = 1 if latest_schema is None else latest_schema.version + 1
171+
172+
registered_schema = RegisteredSchema(
173+
schema_id=1,
174+
guid=str(uuid.uuid4()),
175+
schema=schema,
176+
subject=subject_name,
177+
version=latest_version
178+
)
179+
180+
registered_schema = self._store.set(registered_schema)
181+
182+
return registered_schema
183+
184+
async def get_schema(
185+
self, schema_id: int, subject_name: Optional[str] = None,
186+
fmt: Optional[str] = None
187+
) -> 'Schema':
188+
schema = self._store.get_schema(schema_id)
189+
if schema is not None:
190+
return schema
191+
192+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
193+
194+
async def get_schema_by_guid(
195+
self, guid: str, fmt: Optional[str] = None
196+
) -> 'Schema':
197+
schema = self._store.get_schema_by_guid(guid)
198+
if schema is not None:
199+
return schema
200+
201+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
202+
203+
async def lookup_schema(
204+
self, subject_name: str, schema: 'Schema',
205+
normalize_schemas: bool = False, deleted: bool = False
206+
) -> 'RegisteredSchema':
207+
208+
registered_schema = self._store.get_registered_schema_by_schema(subject_name, schema)
209+
if registered_schema is not None:
210+
return registered_schema
211+
212+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
213+
214+
async def get_subjects(self) -> List[str]:
215+
return self._store.get_subjects()
216+
217+
async def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]:
218+
return self._store.remove_by_subject(subject_name)
219+
220+
async def get_latest_version(self, subject_name: str, fmt: Optional[str] = None) -> 'RegisteredSchema':
221+
registered_schema = self._store.get_latest_version(subject_name)
222+
if registered_schema is not None:
223+
return registered_schema
224+
225+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
226+
227+
async def get_latest_with_metadata(
228+
self, subject_name: str, metadata: Dict[str, str],
229+
deleted: bool = False, fmt: Optional[str] = None
230+
) -> 'RegisteredSchema':
231+
registered_schema = self._store.get_latest_with_metadata(subject_name, metadata)
232+
if registered_schema is not None:
233+
return registered_schema
234+
235+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
236+
237+
async def get_version(
238+
self, subject_name: str, version: int,
239+
deleted: bool = False, fmt: Optional[str] = None
240+
) -> 'RegisteredSchema':
241+
registered_schema = self._store.get_version(subject_name, version)
242+
if registered_schema is not None:
243+
return registered_schema
244+
245+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
246+
247+
async def get_versions(self, subject_name: str) -> List[int]:
248+
return self._store.get_versions(subject_name)
249+
250+
async def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int:
251+
registered_schema = self._store.get_version(subject_name, version)
252+
if registered_schema is not None:
253+
self._store.remove_by_schema(registered_schema)
254+
return registered_schema.schema_id
255+
256+
raise SchemaRegistryError(404, 40400, "Schema Not Found")
257+
258+
async def set_config(
259+
self, subject_name: Optional[str] = None, config: 'ServerConfig' = None # noqa F821
260+
) -> 'ServerConfig': # noqa F821
261+
return None
262+
263+
async def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': # noqa F821
264+
return None

src/confluent_kafka/schema_registry/_async/protobuf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ async def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = N
597597
writer_schema = None
598598

599599
if latest_schema is not None:
600-
migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
600+
migrations = await self._get_migrations(subject, writer_schema_raw, latest_schema, None)
601601
reader_schema_raw = latest_schema.schema
602602
fd_proto, pool = await self._get_parsed_schema(latest_schema.schema)
603603
reader_schema = pool.FindFileByName(fd_proto.name)

src/confluent_kafka/schema_registry/_async/schema_registry_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,8 +1158,8 @@ def clear_caches(self):
11581158

11591159
@staticmethod
11601160
def new_client(conf: dict) -> 'AsyncSchemaRegistryClient':
1161-
from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient
1161+
from .mock_schema_registry_client import AsyncMockSchemaRegistryClient
11621162
url = conf.get("url")
11631163
if url.startswith("mock://"):
1164-
return MockSchemaRegistryClient(conf)
1164+
return AsyncMockSchemaRegistryClient(conf)
11651165
return AsyncSchemaRegistryClient(conf)

src/confluent_kafka/schema_registry/_async/serde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ async def _get_migrations(
216216
self, subject: str, source_info: Schema,
217217
target: RegisteredSchema, fmt: Optional[str]
218218
) -> List[Migration]:
219-
source = self._registry.lookup_schema(subject, source_info, False, True)
219+
source = await self._registry.lookup_schema(subject, source_info, False, True)
220220
migrations = []
221221
if source.version < target.version:
222222
migration_mode = RuleMode.UPGRADE

src/confluent_kafka/schema_registry/mock_schema_registry_client.py renamed to src/confluent_kafka/schema_registry/_sync/mock_schema_registry_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from threading import Lock
2121
from typing import List, Dict, Optional
2222

23-
from . import SchemaRegistryClient, RegisteredSchema, Schema
24-
from .error import SchemaRegistryError
23+
from .schema_registry_client import SchemaRegistryClient
24+
from ..common.schema_registry_client import RegisteredSchema, Schema, ServerConfig
25+
from ..error import SchemaRegistryError
2526

2627

2728
class _SchemaStore(object):

src/confluent_kafka/schema_registry/_sync/schema_registry_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1158,7 +1158,7 @@ def clear_caches(self):
11581158

11591159
@staticmethod
11601160
def new_client(conf: dict) -> 'SchemaRegistryClient':
1161-
from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient
1161+
from .mock_schema_registry_client import MockSchemaRegistryClient
11621162
url = conf.get("url")
11631163
if url.startswith("mock://"):
11641164
return MockSchemaRegistryClient(conf)

tests/integration/schema_registry/_async/test_avro_serializers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,8 @@ async def test_delivery_report_serialization(kafka_cluster, load_file, avsc, dat
278278
producer = kafka_cluster.async_producer(value_serializer=value_serializer)
279279

280280
async def assert_cb(err, msg):
281-
actual = value_deserializer(msg.value(),
282-
SerializationContext(topic, MessageField.VALUE, msg.headers()))
281+
actual = await value_deserializer(
282+
msg.value(), SerializationContext(topic, MessageField.VALUE, msg.headers()))
283283

284284
if record_type == "record":
285285
assert [v == actual[k] for k, v in data.items()]

tests/integration/schema_registry/_sync/test_avro_serializers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,8 +278,8 @@ def test_delivery_report_serialization(kafka_cluster, load_file, avsc, data, rec
278278
producer = kafka_cluster.producer(value_serializer=value_serializer)
279279

280280
def assert_cb(err, msg):
281-
actual = value_deserializer(msg.value(),
282-
SerializationContext(topic, MessageField.VALUE, msg.headers()))
281+
actual = value_deserializer(
282+
msg.value(), SerializationContext(topic, MessageField.VALUE, msg.headers()))
283283

284284
if record_type == "record":
285285
assert [v == actual[k] for k, v in data.items()]

0 commit comments

Comments
 (0)