Skip to content

Commit 88d3da4

Browse files
authored
Merge pull request #2049 from Morikko/support-dataclass-transform
Support dataclass transform
2 parents 86c3a02 + 15a7513 commit 88d3da4

File tree

5 files changed

+906
-84
lines changed

5 files changed

+906
-84
lines changed

conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,14 @@ def jedi_path():
156156
return os.path.dirname(__file__)
157157

158158

159+
@pytest.fixture()
160+
def skip_pre_python311(environment):
161+
if environment.version_info < (3, 11):
162+
# This if is just needed to avoid that tests ever skip way more than
163+
# they should for all Python versions.
164+
pytest.skip()
165+
166+
159167
@pytest.fixture()
160168
def skip_pre_python38(environment):
161169
if environment.version_info < (3, 8):

jedi/inference/value/klass.py

Lines changed: 307 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
====================================== ========================================
3737
3838
"""
39+
from __future__ import annotations
40+
41+
from typing import List, Optional, Tuple
42+
3943
from jedi import debug
4044
from jedi.parser_utils import get_cached_parent_scope, expr_is_dotted, \
4145
function_is_property
@@ -47,11 +51,15 @@
4751
from jedi.inference.names import TreeNameDefinition, ValueName
4852
from jedi.inference.arguments import unpack_arglist, ValuesArguments
4953
from jedi.inference.base_value import ValueSet, iterator_to_value_set, \
50-
NO_VALUES
54+
NO_VALUES, ValueWrapper
5155
from jedi.inference.context import ClassContext
52-
from jedi.inference.value.function import FunctionAndClassBase
56+
from jedi.inference.value.function import FunctionAndClassBase, FunctionMixin
57+
from jedi.inference.value.decorator import Decoratee
5358
from jedi.inference.gradual.generics import LazyGenericManager, TupleGenericManager
5459
from jedi.plugins import plugin_manager
60+
from inspect import Parameter
61+
from jedi.inference.names import BaseTreeParamName
62+
from jedi.inference.signature import AbstractSignature
5563

5664

5765
class ClassName(TreeNameDefinition):
@@ -129,6 +137,65 @@ def _filter(self, names):
129137
return [name for name in names if self._access_possible(name)]
130138

131139

140+
def init_param_value(arg_nodes) -> Optional[bool]:
141+
"""
142+
Returns:
143+
144+
- ``True`` if ``@dataclass(init=True)``
145+
- ``False`` if ``@dataclass(init=False)``
146+
- ``None`` if not specified ``@dataclass()``
147+
"""
148+
for arg_node in arg_nodes:
149+
if (
150+
arg_node.type == "argument"
151+
and arg_node.children[0].value == "init"
152+
):
153+
if arg_node.children[2].value == "False":
154+
return False
155+
elif arg_node.children[2].value == "True":
156+
return True
157+
158+
return None
159+
160+
161+
def get_dataclass_param_names(cls) -> List[DataclassParamName]:
162+
"""
163+
``cls`` is a :class:`ClassMixin`. The type is only documented as mypy would
164+
complain that some fields are missing.
165+
166+
.. code:: python
167+
168+
@dataclass
169+
class A:
170+
a: int
171+
b: str = "toto"
172+
173+
For the previous example, the param names would be ``a`` and ``b``.
174+
"""
175+
param_names = []
176+
filter_ = cls.as_context().get_global_filter()
177+
for name in sorted(filter_.values(), key=lambda name: name.start_pos):
178+
d = name.tree_name.get_definition()
179+
annassign = d.children[1]
180+
if d.type == 'expr_stmt' and annassign.type == 'annassign':
181+
node = annassign.children[1]
182+
if node.type == "atom_expr" and node.children[0].value == "ClassVar":
183+
continue
184+
185+
if len(annassign.children) < 4:
186+
default = None
187+
else:
188+
default = annassign.children[3]
189+
190+
param_names.append(DataclassParamName(
191+
parent_context=cls.parent_context,
192+
tree_name=name.tree_name,
193+
annotation_node=annassign.children[1],
194+
default_node=default,
195+
))
196+
return param_names
197+
198+
132199
class ClassMixin:
133200
def is_class(self):
134201
return True
@@ -221,6 +288,73 @@ def get_filters(self, origin_scope=None, is_instance=False,
221288
assert x is not None
222289
yield x
223290

291+
def _has_dataclass_transform_metaclasses(self) -> Tuple[bool, Optional[bool]]:
292+
for meta in self.get_metaclasses(): # type: ignore[attr-defined]
293+
if (
294+
isinstance(meta, Decoratee)
295+
# Internal leakage :|
296+
and isinstance(meta._wrapped_value, DataclassTransformer)
297+
):
298+
return True, meta._wrapped_value.init_mode_from_new()
299+
300+
return False, None
301+
302+
def _get_dataclass_transform_signatures(self) -> List[DataclassSignature]:
303+
"""
304+
Returns: A non-empty list if the class has dataclass semantics else an
305+
empty list.
306+
307+
The dataclass-like semantics will be assumed for any class that directly
308+
or indirectly derives from the decorated class or uses the decorated
309+
class as a metaclass.
310+
"""
311+
param_names = []
312+
is_dataclass_transform = False
313+
default_init_mode: Optional[bool] = None
314+
for cls in reversed(list(self.py__mro__())):
315+
if not is_dataclass_transform:
316+
317+
# If dataclass_transform is applied to a class, dataclass-like semantics
318+
# will be assumed for any class that directly or indirectly derives from
319+
# the decorated class or uses the decorated class as a metaclass.
320+
if (
321+
isinstance(cls, DataclassTransformer)
322+
and cls.init_mode_from_init_subclass
323+
):
324+
is_dataclass_transform = True
325+
default_init_mode = cls.init_mode_from_init_subclass
326+
327+
elif (
328+
# Some object like CompiledValues would not be compatible
329+
isinstance(cls, ClassMixin)
330+
):
331+
is_dataclass_transform, default_init_mode = (
332+
cls._has_dataclass_transform_metaclasses()
333+
)
334+
335+
# Attributes on the decorated class and its base classes are not
336+
# considered to be fields.
337+
if is_dataclass_transform:
338+
continue
339+
340+
# All inherited classes behave like dataclass semantics
341+
if (
342+
is_dataclass_transform
343+
and isinstance(cls, ClassValue)
344+
and (
345+
cls.init_param_mode()
346+
or (cls.init_param_mode() is None and default_init_mode)
347+
)
348+
):
349+
param_names.extend(
350+
get_dataclass_param_names(cls)
351+
)
352+
353+
if is_dataclass_transform:
354+
return [DataclassSignature(cls, param_names)]
355+
else:
356+
return []
357+
224358
def get_signatures(self):
225359
# Since calling staticmethod without a function is illegal, the Jedi
226360
# plugin doesn't return anything. Therefore call directly and get what
@@ -232,7 +366,12 @@ def get_signatures(self):
232366
return sigs
233367
args = ValuesArguments([])
234368
init_funcs = self.py__call__(args).py__getattribute__('__init__')
235-
return [sig.bind(self) for sig in init_funcs.get_signatures()]
369+
370+
dataclass_sigs = self._get_dataclass_transform_signatures()
371+
if dataclass_sigs:
372+
return dataclass_sigs
373+
else:
374+
return [sig.bind(self) for sig in init_funcs.get_signatures()]
236375

237376
def _as_context(self):
238377
return ClassContext(self)
@@ -319,6 +458,158 @@ def iter(iterable: Iterable[_T]) -> Iterator[_T]: ...
319458
return ValueSet({self})
320459

321460

461+
class DataclassParamName(BaseTreeParamName):
462+
"""
463+
Represent a field declaration on a class with dataclass semantics.
464+
"""
465+
466+
def __init__(self, parent_context, tree_name, annotation_node, default_node):
467+
super().__init__(parent_context, tree_name)
468+
self.annotation_node = annotation_node
469+
self.default_node = default_node
470+
471+
def get_kind(self):
472+
return Parameter.POSITIONAL_OR_KEYWORD
473+
474+
def infer(self):
475+
if self.annotation_node is None:
476+
return NO_VALUES
477+
else:
478+
return self.parent_context.infer_node(self.annotation_node)
479+
480+
481+
class DataclassSignature(AbstractSignature):
482+
"""
483+
It represents the ``__init__`` signature of a class with dataclass semantics.
484+
485+
.. code:: python
486+
487+
"""
488+
def __init__(self, value, param_names):
489+
super().__init__(value)
490+
self._param_names = param_names
491+
492+
def get_param_names(self, resolve_stars=False):
493+
return self._param_names
494+
495+
496+
class DataclassDecorator(ValueWrapper, FunctionMixin):
497+
"""
498+
A dataclass(-like) decorator with custom parameters.
499+
500+
.. code:: python
501+
502+
@dataclass(init=True) # this
503+
class A: ...
504+
505+
@dataclass_transform
506+
def create_model(*, init=False): pass
507+
508+
@create_model(init=False) # or this
509+
class B: ...
510+
"""
511+
512+
def __init__(self, function, arguments, default_init: bool = True):
513+
"""
514+
Args:
515+
function: Decoratee | function
516+
arguments: The parameters to the dataclass function decorator
517+
default_init: Boolean to indicate the default init value
518+
"""
519+
super().__init__(function)
520+
argument_init = self._init_param_value(arguments)
521+
self.init_param_mode = (
522+
argument_init if argument_init is not None else default_init
523+
)
524+
525+
def _init_param_value(self, arguments) -> Optional[bool]:
526+
if not arguments.argument_node:
527+
return None
528+
529+
arg_nodes = (
530+
arguments.argument_node.children
531+
if arguments.argument_node.type == "arglist"
532+
else [arguments.argument_node]
533+
)
534+
535+
return init_param_value(arg_nodes)
536+
537+
538+
class DataclassTransformer(ValueWrapper, ClassMixin):
539+
"""
540+
A class decorated with the ``dataclass_transform`` decorator. dataclass-like
541+
semantics will be assumed for any class that directly or indirectly derives
542+
from the decorated class or uses the decorated class as a metaclass.
543+
Attributes on the decorated class and its base classes are not considered to
544+
be fields.
545+
"""
546+
def __init__(self, wrapped_value):
547+
super().__init__(wrapped_value)
548+
549+
def init_mode_from_new(self) -> bool:
550+
"""Default value if missing is ``True``"""
551+
new_methods = self._wrapped_value.py__getattribute__("__new__")
552+
553+
if not new_methods:
554+
return True
555+
556+
new_method = list(new_methods)[0]
557+
558+
for param in new_method.get_param_names():
559+
if (
560+
param.string_name == "init"
561+
and param.default_node
562+
and param.default_node.type == "keyword"
563+
):
564+
if param.default_node.value == "False":
565+
return False
566+
elif param.default_node.value == "True":
567+
return True
568+
569+
return True
570+
571+
@property
572+
def init_mode_from_init_subclass(self) -> Optional[bool]:
573+
# def __init_subclass__(cls) -> None: ... is hardcoded in the typeshed
574+
# so the extra parameters can not be inferred.
575+
return True
576+
577+
578+
class DataclassWrapper(ValueWrapper, ClassMixin):
579+
"""
580+
A class with dataclass semantics from a decorator. The init parameters are
581+
only from the current class and parent classes decorated where the ``init``
582+
parameter was ``True``.
583+
584+
.. code:: python
585+
586+
@dataclass
587+
class A: ... # this
588+
589+
@dataclass_transform
590+
def create_model(): pass
591+
592+
@create_model()
593+
class B: ... # or this
594+
"""
595+
596+
def __init__(
597+
self, wrapped_value, should_generate_init: bool
598+
):
599+
super().__init__(wrapped_value)
600+
self.should_generate_init = should_generate_init
601+
602+
def get_signatures(self):
603+
param_names = []
604+
for cls in reversed(list(self.py__mro__())):
605+
if (
606+
isinstance(cls, DataclassWrapper)
607+
and cls.should_generate_init
608+
):
609+
param_names.extend(get_dataclass_param_names(cls))
610+
return [DataclassSignature(cls, param_names)]
611+
612+
322613
class ClassValue(ClassMixin, FunctionAndClassBase, metaclass=CachedMetaClass):
323614
api_type = 'class'
324615

@@ -385,6 +676,19 @@ def get_metaclasses(self):
385676
return values
386677
return NO_VALUES
387678

679+
def init_param_mode(self) -> Optional[bool]:
680+
"""
681+
It returns ``True`` if ``class X(init=False):`` else ``False``.
682+
"""
683+
bases_arguments = self._get_bases_arguments()
684+
685+
if bases_arguments.argument_node.type != "arglist":
686+
# If it is not inheriting from the base model and having
687+
# extra parameters, then init behavior is not changed.
688+
return None
689+
690+
return init_param_value(bases_arguments.argument_node.children)
691+
388692
@plugin_manager.decorate()
389693
def get_metaclass_signatures(self, metaclasses):
390694
return []

0 commit comments

Comments
 (0)