Skip to content

Commit d178c85

Browse files
committed
Fix #47: Can serialize Union with Literal
1 parent 572c79c commit d178c85

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

databind/src/databind/json/converters.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -763,13 +763,19 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) ->
763763
def convert(self, ctx: Context) -> t.Any:
764764
datatype = ctx.datatype
765765
union: t.Optional[Union]
766+
literal_types: list[TypeHint] = []
767+
766768
if isinstance(datatype, UnionTypeHint):
767769
if datatype.has_none_type():
768770
raise NotImplementedError("unable to handle Union type with None in it")
769-
if not all(isinstance(a, ClassTypeHint) for a in datatype):
770-
raise NotImplementedError(f"members of plain Union must be concrete types: {datatype}")
771-
members = {t.cast(ClassTypeHint, a).type.__name__: a for a in datatype}
772-
if len(members) != len(datatype):
771+
772+
literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)]
773+
non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)]
774+
if not all(isinstance(a, ClassTypeHint) for a in non_literal_types):
775+
raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}")
776+
777+
members = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types}
778+
if len(members) != len(non_literal_types):
773779
raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}")
774780
union = Union(members, Union.BEST_MATCH)
775781
elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)):
@@ -788,6 +794,11 @@ def convert(self, ctx: Context) -> t.Any:
788794
return ctx.spawn(ctx.value, member_type, None).convert()
789795
except ConversionError as exc:
790796
errors.append((exc.origin, exc))
797+
for literal_type in literal_types:
798+
try:
799+
return ctx.spawn(ctx.value, literal_type, None).convert()
800+
except ConversionError as exc:
801+
errors.append((exc.origin, exc))
791802
raise ConversionError(
792803
self,
793804
ctx,

databind/src/databind/json/tests/converters_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,16 @@ def of(cls, v: str) -> "MyCls":
713713
mapper = make_mapper([JsonConverterSupport()])
714714
assert mapper.serialize(MyCls(), MyCls) == "MyCls"
715715
assert mapper.deserialize("MyCls", MyCls) == MyCls()
716+
717+
718+
def test_union_literal():
719+
mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()])
720+
721+
IntType = int | t.Literal["hi", "bye"]
722+
StrType = str | t.Literal["hi", "bye"]
723+
724+
assert mapper.serialize("hi", IntType) == "hi"
725+
assert mapper.serialize(2, IntType) == 2
726+
727+
assert mapper.serialize("bye", StrType) == "bye"
728+
assert mapper.serialize("other", StrType) == "other"

0 commit comments

Comments
 (0)