Skip to content

Commit 5fb1cb5

Browse files
committed
allow narrowing of inherited types
1 parent c8aea2f commit 5fb1cb5

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

schema_salad/avro/schema.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ class RecordSchema(NamedSchema):
477477
def make_field_objects(field_data: List[PropsType], names: Names) -> List[Field]:
478478
"""We're going to need to make message parameters too."""
479479
field_objects = [] # type: List[Field]
480-
field_names = [] # type: List[str]
480+
parsed_fields: Dict[str, PropsType] = {}
481481
for field in field_data:
482482
if hasattr(field, "get") and callable(field.get):
483483
atype = field.get("type")
@@ -504,10 +504,19 @@ def make_field_objects(field_data: List[PropsType], names: Names) -> List[Field]
504504
atype, name, has_default, default, order, names, doc, other_props
505505
)
506506
# make sure field name has not been used yet
507-
if new_field.name in field_names:
508-
fail_msg = f"Field name {new_field.name} already in use."
509-
raise SchemaParseException(fail_msg)
510-
field_names.append(new_field.name)
507+
if new_field.name in parsed_fields:
508+
old_field = parsed_fields[new_field.name]
509+
if "inherited_from" not in old_field:
510+
raise SchemaParseException(
511+
f"Field name {new_field.name} already in use."
512+
)
513+
if not is_subtype(old_field["type"], field["type"]):
514+
raise SchemaParseException(
515+
f"Field name {new_field.name} already in use with "
516+
"incompatible type. "
517+
f"{field['type']} vs {old_field['type']}."
518+
)
519+
parsed_fields[new_field.name] = field
511520
else:
512521
raise SchemaParseException(f"Not a valid field: {field}")
513522
field_objects.append(new_field)
@@ -655,3 +664,29 @@ def make_avsc_object(json_data: JsonDataType, names: Optional[Names] = None) ->
655664
# not for us!
656665
fail_msg = f"Could not make an Avro Schema object from {json_data}."
657666
raise SchemaParseException(fail_msg)
667+
668+
669+
def is_subtype(existing: PropType, new: PropType) -> bool:
670+
"""Checks if a new type specification is compatible with an existing type spec."""
671+
if existing == new:
672+
return True
673+
if isinstance(existing, list) and (new in existing):
674+
return True
675+
if (
676+
isinstance(existing, dict)
677+
and "type" in existing
678+
and existing["type"] == "array"
679+
and isinstance(new, dict)
680+
and "type" in new
681+
and new["type"] == "array"
682+
):
683+
return is_subtype(existing["items"], new["items"])
684+
if isinstance(existing, list) and isinstance(new, list):
685+
missing = False
686+
for _type in new:
687+
if _type not in existing and (
688+
not is_subtype(existing, cast(PropType, _type))
689+
):
690+
missing = True
691+
return not missing
692+
return False

schema_salad/tests/test_subtypes.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Confirm subtypes."""
2+
from schema_salad.avro import schema
3+
4+
import pytest
5+
6+
types = [
7+
(["int", "float", "double"], "int", True),
8+
(["int", "float", "double"], ["int"], True),
9+
(["int", "float", "double"], ["int", "float"], True),
10+
(["int", "float", "double"], ["int", "float", "File"], False),
11+
({"type": "array", "items": ["int", "float", "double"]}, ["int", "float"], False),
12+
(
13+
{"type": "array", "items": ["int", "float", "double"]},
14+
{"type": "array", "items": ["int", "float"]},
15+
True,
16+
),
17+
]
18+
19+
20+
@pytest.mark.parametrize("old,new,result", types)
21+
def test_subtypes(old: schema.PropType, new: schema.PropType, result: bool) -> None:
22+
"""Test is_subtype() function."""
23+
assert schema.is_subtype(old, new) == result

0 commit comments

Comments
 (0)