diff --git a/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md b/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md index 65dfcef23ad52..316d5bdfdab4a 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md +++ b/crates/red_knot_python_semantic/resources/mdtest/assignment/augmented.md @@ -18,6 +18,14 @@ class C: x = C() x -= 1 reveal_type(x) # revealed: str + +class C: + def __iadd__(self, other: str) -> float: + return "Hello, world!" + +x = C() +x += "Hello" +reveal_type(x) # revealed: float ``` ## Unsupported types diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 5080bd27c45f7..e6ed20fb5be0d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -33,7 +33,7 @@ use itertools::Itertools; use ruff_db::files::File; use ruff_db::parsed::parsed_module; use ruff_python_ast::name::Name; -use ruff_python_ast::{self as ast, AnyNodeRef, Expr, ExprContext, Operator, UnaryOp}; +use ruff_python_ast::{self as ast, AnyNodeRef, Expr, ExprContext, UnaryOp}; use ruff_text_size::Ranged; use rustc_hash::FxHashMap; use salsa; @@ -1425,10 +1425,9 @@ impl<'db> TypeInferenceBuilder<'db> { }; let value_type = self.infer_expression(value); - // TODO(charlie): Add remaining branches for different types of augmented assignments. - if let (Operator::Sub, Type::Instance(class)) = (*op, target_type) { - let class_member = class.class_member(self.db, "__isub__"); - let call = class_member.call(self.db, &[value_type]); + if let Type::Instance(class) = target_type { + let class_member = class.class_member(self.db, op.in_place_dunder()); + let call = class_member.call(self.db, &[target_type, value_type]); return match call.return_ty_result(self.db, AnyNodeRef::StmtAugAssign(assignment), self) { @@ -2516,9 +2515,7 @@ impl<'db> TypeInferenceBuilder<'db> { } = attribute; let value_ty = self.infer_expression(value); - let member_ty = value_ty.member(self.db, &Name::new(&attr.id)); - - member_ty + value_ty.member(self.db, &Name::new(&attr.id)) } fn infer_attribute_expression(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { diff --git a/crates/ruff_benchmark/benches/red_knot.rs b/crates/ruff_benchmark/benches/red_knot.rs index 6e52ccc409965..b059df6af7b6e 100644 --- a/crates/ruff_benchmark/benches/red_knot.rs +++ b/crates/ruff_benchmark/benches/red_knot.rs @@ -28,22 +28,30 @@ static EXPECTED_DIAGNOSTICS: &[&str] = &[ // We don't support terminal statements in control flow yet: "/src/tomllib/_parser.py:66:18: Name `s` used when possibly not defined", "/src/tomllib/_parser.py:98:12: Name `char` used when possibly not defined", + "/src/tomllib/_parser.py:99:13: Operator `+=` is unsupported for type `int` with type `Literal[1]`", "/src/tomllib/_parser.py:101:12: Name `char` used when possibly not defined", "/src/tomllib/_parser.py:104:14: Name `char` used when possibly not defined", "/src/tomllib/_parser.py:108:17: Conflicting declared types for `second_char`: Unknown, str | None", "/src/tomllib/_parser.py:115:14: Name `char` used when possibly not defined", "/src/tomllib/_parser.py:126:12: Name `char` used when possibly not defined", + "/src/tomllib/_parser.py:130:9: Operator `+=` is unsupported for type `int` with type `Literal[1]`", "/src/tomllib/_parser.py:267:9: Conflicting declared types for `char`: Unknown, str | None", "/src/tomllib/_parser.py:348:20: Name `nest` used when possibly not defined", "/src/tomllib/_parser.py:353:5: Name `nest` used when possibly not defined", "/src/tomllib/_parser.py:353:5: Method `__getitem__` of type `Unbound | @Todo` is not callable on object of type `Unbound | @Todo`", "/src/tomllib/_parser.py:364:9: Conflicting declared types for `char`: Unknown, str | None", + "/src/tomllib/_parser.py:367:5: Operator `+=` is unsupported for type `int` with type `Literal[1]`", "/src/tomllib/_parser.py:381:13: Conflicting declared types for `char`: Unknown, str | None", + "/src/tomllib/_parser.py:384:9: Operator `+=` is unsupported for type `int` with type `Literal[1]`", "/src/tomllib/_parser.py:395:9: Conflicting declared types for `char`: Unknown, str | None", + "/src/tomllib/_parser.py:429:9: Operator `+=` is unsupported for type `int` with type `Literal[1]`", "/src/tomllib/_parser.py:453:24: Name `nest` used when possibly not defined", "/src/tomllib/_parser.py:455:9: Name `nest` used when possibly not defined", "/src/tomllib/_parser.py:455:9: Method `__getitem__` of type `Unbound | @Todo` is not callable on object of type `Unbound | @Todo`", + "/src/tomllib/_parser.py:464:9: Operator `+=` is unsupported for type `int` with type `Literal[1]`", "/src/tomllib/_parser.py:482:16: Name `char` used when possibly not defined", + "/src/tomllib/_parser.py:484:13: Operator `+=` is unsupported for type `int` with type `Literal[1]`", + "/src/tomllib/_parser.py:545:5: Operator `+=` is unsupported for type `int` with type `Literal[1]`", "/src/tomllib/_parser.py:566:12: Name `char` used when possibly not defined", "/src/tomllib/_parser.py:573:12: Name `char` used when possibly not defined", "/src/tomllib/_parser.py:579:12: Name `char` used when possibly not defined", diff --git a/crates/ruff_python_ast/src/nodes.rs b/crates/ruff_python_ast/src/nodes.rs index a35868d6a0371..b0eb7f3e8a423 100644 --- a/crates/ruff_python_ast/src/nodes.rs +++ b/crates/ruff_python_ast/src/nodes.rs @@ -2972,6 +2972,7 @@ impl Operator { } } + /// Returns the dunder method name for the operator. pub const fn dunder(self) -> &'static str { match self { Operator::Add => "__add__", @@ -2990,6 +2991,26 @@ impl Operator { } } + /// Returns the in-place dunder method name for the operator. + pub const fn in_place_dunder(self) -> &'static str { + match self { + Operator::Add => "__iadd__", + Operator::Sub => "__isub__", + Operator::Mult => "__imul__", + Operator::MatMult => "__imatmul__", + Operator::Div => "__itruediv__", + Operator::Mod => "__imod__", + Operator::Pow => "__ipow__", + Operator::LShift => "__ilshift__", + Operator::RShift => "__irshift__", + Operator::BitOr => "__ior__", + Operator::BitXor => "__ixor__", + Operator::BitAnd => "__iand__", + Operator::FloorDiv => "__ifloordiv__", + } + } + + /// Returns the reflected dunder method name for the operator. pub const fn reflected_dunder(self) -> &'static str { match self { Operator::Add => "__radd__",