Skip to content

Commit

Permalink
Merge pull request #64 from prashantgupta24/OneofField-feat-request
Browse files Browse the repository at this point in the history
✨support oneofs with a single field
  • Loading branch information
gabe-l-hart authored Oct 12, 2023
2 parents 69bbf5a + e05bf50 commit 8d04378
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 0 deletions.
16 changes: 16 additions & 0 deletions py_to_proto/dataclass_to_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,22 @@ def get_oneof_fields(
)
log.debug3("Using default oneof field name: %s", oneof_field_name)
oneof_fields.append((oneof_field_name, arg))

# here it's not a union, but it's still annotated.
# Special case in which we only have one field in the Union
# but we still want to create a one-of in case OneofField is present
# see https://github.com/IBM/py-to-proto/issues/63
elif get_origin(field_def.type) is Annotated and any(
type(arg) is OneofField for arg in get_args(field_def.type)
):
# it can only be 1 arg, hence no need to iterate through the args
oneof_field_name = self._get_unique_annotation(field_def.type, OneofField)
assert (
len(oneof_field_name) > 0
), "Got OneofField annotation without any name?"

log.debug3("Using oneof field name: %s", oneof_field_name)
oneof_fields.append((oneof_field_name, field_def.type))
return oneof_fields

def get_oneof_name(self, field_def: dataclasses.Field) -> str:
Expand Down
109 changes: 109 additions & 0 deletions tests/test_dataclass_to_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,115 @@ class Baz:
assert strseq_values_fld.label == strseq_values_fld.LABEL_REPEATED


def test_dataclass_to_proto_union_one_of_field(temp_dpool):
"""Make sure that a dataclasss with a Union of a single primitive field
annotated with OneOfField works correctly"""

@dataclass
class Foo:
foo: Union[
Annotated[bool, OneofField("foo_bool"), FieldNumber(10)], # type: ignore
]

desc = dataclass_to_proto("foo.bar", Foo, descriptor_pool=temp_dpool)
assert len(desc.oneofs) == 1
oneof_desc = desc.oneofs_by_name["foo"]
foobool_fld = desc.fields_by_name["foo_bool"]
assert foobool_fld.type == foobool_fld.TYPE_BOOL
assert foobool_fld.containing_oneof is oneof_desc
assert foobool_fld.number == 10


def test_dataclass_to_proto_one_of_field(temp_dpool):
"""Make sure that a dataclasss with a single primitive field
annotated with OneOfField works correctly"""

@dataclass
class Foo:
foo: Annotated[bool, OneofField("foo_bool"), FieldNumber(10)]

desc = dataclass_to_proto("foo.bar", Foo, descriptor_pool=temp_dpool)
assert len(desc.oneofs) == 1
oneof_desc = desc.oneofs_by_name["foo"]
foobool_fld = desc.fields_by_name["foo_bool"]
assert foobool_fld.type == foobool_fld.TYPE_BOOL
assert foobool_fld.containing_oneof is oneof_desc
assert foobool_fld.number == 10


def test_dataclass_to_proto_two_one_of_fields(temp_dpool):
"""Make sure that a dataclasss with two separate single primitive fields
annotated with OneOfField works correctly"""

@dataclass
class Foo:
foo: Annotated[bool, OneofField("foo_bool"), FieldNumber(10)]
bar: Annotated[bool, OneofField("bar_bool"), FieldNumber(20)]

desc = dataclass_to_proto("foo.bar", Foo, descriptor_pool=temp_dpool)
assert len(desc.oneofs) == 2

oneof_desc = desc.oneofs_by_name["foo"]
foobool_fld = desc.fields_by_name["foo_bool"]
assert foobool_fld.type == foobool_fld.TYPE_BOOL
assert foobool_fld.containing_oneof is oneof_desc
assert foobool_fld.number == 10

oneof_desc = desc.oneofs_by_name["bar"]
barbool_fld = desc.fields_by_name["bar_bool"]
assert barbool_fld.type == barbool_fld.TYPE_BOOL
assert barbool_fld.containing_oneof is oneof_desc
assert barbool_fld.number == 20


def test_dataclass_to_proto_no_oneof_field_name_throws(temp_dpool):
"""Make sure that a dataclasss with a single primitive field
annotated with OneOfField but without a name throws"""

@dataclass
class Foo:
foo: Union[
Annotated[bool, OneofField(), FieldNumber(10)],
] # type: ignore

with pytest.raises(
AssertionError, match="Got OneofField annotation without any name"
):
dataclass_to_proto("foo.bar", Foo, descriptor_pool=temp_dpool)


def test_dataclass_to_proto_oneof_len_one_no_annotated(temp_dpool):
"""Make sure that a dataclasss with a single primitive field
annotated without a OneOfField creates no one-ofs"""

@dataclass
class Foo:
foo: Union[
Annotated[bool, FieldNumber(10)],
] # type: ignore

desc = dataclass_to_proto("foo.bar", Foo, descriptor_pool=temp_dpool)
# since no OneofField in annotated, no oneof got created
assert len(desc.oneofs) == 0
foo_fld = desc.fields_by_name["foo"]
assert foo_fld.type == foo_fld.TYPE_BOOL
assert foo_fld.number == 10


def test_dataclass_to_proto_oneof_union_single_field(temp_dpool):
"""Make sure that a dataclasss with a union of a single field creates no one-ofs"""

@dataclass
class Foo:
foo: Union[bool] # type: ignore

desc = dataclass_to_proto("foo.bar", Foo, descriptor_pool=temp_dpool)
# A union with a single field is treated as a single field
assert len(desc.oneofs) == 0
foo_fld = desc.fields_by_name["foo"]
assert foo_fld.type == foo_fld.TYPE_BOOL


def test_dataclass_to_proto_oneof_primitives(temp_dpool):
"""Make sure that a oneof with primitive fields works correctly"""

Expand Down

0 comments on commit 8d04378

Please sign in to comment.