diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md index 237bf5a6b2208..cb5fab1840105 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/unsupported.md @@ -7,11 +7,21 @@ reveal_type(a) # revealed: bool b = 0 not in 10 # error: "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`" reveal_type(b) # revealed: bool -c = object() < 5 # error: "Operator `<` is not supported for types `object` and `Literal[5]`" +c = object() < 5 # error: "Operator `<` is not supported for types `object` and `int`" reveal_type(c) # revealed: Unknown # TODO should error, need to check if __lt__ signature is valid for right operand d = 5 < object() # TODO: should be `Unknown` reveal_type(d) # revealed: bool + +int_literal_or_str_literal = 1 if flag else "foo" +# error: "Operator `in` is not supported for types `Literal[42]` and `Literal[1]`, in comparing `Literal[42]` with `Literal[1] | Literal["foo"]`" +e = 42 in int_literal_or_str_literal +reveal_type(e) # revealed: bool + +# TODO: should error, need to check if __lt__ signature is valid for right operand +# error may be "Operator `<` is not supported for types `int` and `str`, in comparing `tuple[Literal[1], Literal[2]]` with `tuple[Literal[1], Literal["hello"]]` +f = (1, 2) < (1, "hello") +reveal_type(f) # revealed: @Todo ``` diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 3e0edef49bed8..90b56788c720b 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2776,18 +2776,28 @@ impl<'db> TypeInferenceBuilder<'db> { let right_ty = self.expression_ty(right); self.infer_binary_type_comparison(left_ty, *op, right_ty) - .unwrap_or_else(|| { + .unwrap_or_else(|error| { // Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome) self.add_diagnostic( AnyNodeRef::ExprCompare(compare), "operator-unsupported", format_args!( - "Operator `{}` is not supported for types `{}` and `{}`", - op, - left_ty.display(self.db), - right_ty.display(self.db) + "Operator `{}` is not supported for types `{}` and `{}`{}", + error.op, + error.left_ty.display(self.db), + error.right_ty.display(self.db), + if (left_ty, right_ty) == (error.left_ty, error.right_ty) { + String::new() + } else { + format!( + ", in comparing `{}` with `{}`", + left_ty.display(self.db), + right_ty.display(self.db) + ) + } ), ); + match op { // `in, not in, is, is not` always return bool instances ast::CmpOp::In @@ -2814,7 +2824,7 @@ impl<'db> TypeInferenceBuilder<'db> { left: Type<'db>, op: ast::CmpOp, right: Type<'db>, - ) -> Option> { + ) -> Result, CompareUnsupportedError<'db>> { // Note: identity (is, is not) for equal builtin types is unreliable and not part of the // language spec. // - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal @@ -2825,39 +2835,43 @@ impl<'db> TypeInferenceBuilder<'db> { for element in union.elements(self.db) { builder = builder.add(self.infer_binary_type_comparison(*element, op, other)?); } - Some(builder.build()) + Ok(builder.build()) } (other, Type::Union(union)) => { let mut builder = UnionBuilder::new(self.db); for element in union.elements(self.db) { builder = builder.add(self.infer_binary_type_comparison(other, op, *element)?); } - Some(builder.build()) + Ok(builder.build()) } (Type::IntLiteral(n), Type::IntLiteral(m)) => match op { - ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)), - ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)), - ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)), - ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)), - ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)), - ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)), + ast::CmpOp::Eq => Ok(Type::BooleanLiteral(n == m)), + ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(n != m)), + ast::CmpOp::Lt => Ok(Type::BooleanLiteral(n < m)), + ast::CmpOp::LtE => Ok(Type::BooleanLiteral(n <= m)), + ast::CmpOp::Gt => Ok(Type::BooleanLiteral(n > m)), + ast::CmpOp::GtE => Ok(Type::BooleanLiteral(n >= m)), ast::CmpOp::Is => { if n == m { - Some(KnownClass::Bool.to_instance(self.db)) + Ok(KnownClass::Bool.to_instance(self.db)) } else { - Some(Type::BooleanLiteral(false)) + Ok(Type::BooleanLiteral(false)) } } ast::CmpOp::IsNot => { if n == m { - Some(KnownClass::Bool.to_instance(self.db)) + Ok(KnownClass::Bool.to_instance(self.db)) } else { - Some(Type::BooleanLiteral(true)) + Ok(Type::BooleanLiteral(true)) } } // Undefined for (int, int) - ast::CmpOp::In | ast::CmpOp::NotIn => None, + ast::CmpOp::In | ast::CmpOp::NotIn => Err(CompareUnsupportedError { + op, + left_ty: left, + right_ty: right, + }), }, (Type::IntLiteral(_), Type::Instance(_)) => { self.infer_binary_type_comparison(KnownClass::Int.to_instance(self.db), op, right) @@ -2888,26 +2902,26 @@ impl<'db> TypeInferenceBuilder<'db> { let s1 = salsa_s1.value(self.db); let s2 = salsa_s2.value(self.db); match op { - ast::CmpOp::Eq => Some(Type::BooleanLiteral(s1 == s2)), - ast::CmpOp::NotEq => Some(Type::BooleanLiteral(s1 != s2)), - ast::CmpOp::Lt => Some(Type::BooleanLiteral(s1 < s2)), - ast::CmpOp::LtE => Some(Type::BooleanLiteral(s1 <= s2)), - ast::CmpOp::Gt => Some(Type::BooleanLiteral(s1 > s2)), - ast::CmpOp::GtE => Some(Type::BooleanLiteral(s1 >= s2)), - ast::CmpOp::In => Some(Type::BooleanLiteral(s2.contains(s1.as_ref()))), - ast::CmpOp::NotIn => Some(Type::BooleanLiteral(!s2.contains(s1.as_ref()))), + ast::CmpOp::Eq => Ok(Type::BooleanLiteral(s1 == s2)), + ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(s1 != s2)), + ast::CmpOp::Lt => Ok(Type::BooleanLiteral(s1 < s2)), + ast::CmpOp::LtE => Ok(Type::BooleanLiteral(s1 <= s2)), + ast::CmpOp::Gt => Ok(Type::BooleanLiteral(s1 > s2)), + ast::CmpOp::GtE => Ok(Type::BooleanLiteral(s1 >= s2)), + ast::CmpOp::In => Ok(Type::BooleanLiteral(s2.contains(s1.as_ref()))), + ast::CmpOp::NotIn => Ok(Type::BooleanLiteral(!s2.contains(s1.as_ref()))), ast::CmpOp::Is => { if s1 == s2 { - Some(KnownClass::Bool.to_instance(self.db)) + Ok(KnownClass::Bool.to_instance(self.db)) } else { - Some(Type::BooleanLiteral(false)) + Ok(Type::BooleanLiteral(false)) } } ast::CmpOp::IsNot => { if s1 == s2 { - Some(KnownClass::Bool.to_instance(self.db)) + Ok(KnownClass::Bool.to_instance(self.db)) } else { - Some(Type::BooleanLiteral(true)) + Ok(Type::BooleanLiteral(true)) } } } @@ -2930,30 +2944,30 @@ impl<'db> TypeInferenceBuilder<'db> { let b1 = &**salsa_b1.value(self.db); let b2 = &**salsa_b2.value(self.db); match op { - ast::CmpOp::Eq => Some(Type::BooleanLiteral(b1 == b2)), - ast::CmpOp::NotEq => Some(Type::BooleanLiteral(b1 != b2)), - ast::CmpOp::Lt => Some(Type::BooleanLiteral(b1 < b2)), - ast::CmpOp::LtE => Some(Type::BooleanLiteral(b1 <= b2)), - ast::CmpOp::Gt => Some(Type::BooleanLiteral(b1 > b2)), - ast::CmpOp::GtE => Some(Type::BooleanLiteral(b1 >= b2)), + ast::CmpOp::Eq => Ok(Type::BooleanLiteral(b1 == b2)), + ast::CmpOp::NotEq => Ok(Type::BooleanLiteral(b1 != b2)), + ast::CmpOp::Lt => Ok(Type::BooleanLiteral(b1 < b2)), + ast::CmpOp::LtE => Ok(Type::BooleanLiteral(b1 <= b2)), + ast::CmpOp::Gt => Ok(Type::BooleanLiteral(b1 > b2)), + ast::CmpOp::GtE => Ok(Type::BooleanLiteral(b1 >= b2)), ast::CmpOp::In => { - Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some())) + Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_some())) } ast::CmpOp::NotIn => { - Some(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none())) + Ok(Type::BooleanLiteral(memchr::memmem::find(b2, b1).is_none())) } ast::CmpOp::Is => { if b1 == b2 { - Some(KnownClass::Bool.to_instance(self.db)) + Ok(KnownClass::Bool.to_instance(self.db)) } else { - Some(Type::BooleanLiteral(false)) + Ok(Type::BooleanLiteral(false)) } } ast::CmpOp::IsNot => { if b1 == b2 { - Some(KnownClass::Bool.to_instance(self.db)) + Ok(KnownClass::Bool.to_instance(self.db)) } else { - Some(Type::BooleanLiteral(true)) + Ok(Type::BooleanLiteral(true)) } } } @@ -2991,7 +3005,7 @@ impl<'db> TypeInferenceBuilder<'db> { ).expect("infer_binary_type_comparison should never return None for `CmpOp::Eq`"); match eq_result { - Type::Todo => return Some(Type::Todo), + Type::Todo => return Ok(Type::Todo), ty => match ty.bool(self.db) { Truthiness::AlwaysTrue => eq_count += 1, Truthiness::AlwaysFalse => not_eq_count += 1, @@ -3001,11 +3015,11 @@ impl<'db> TypeInferenceBuilder<'db> { } if eq_count >= 1 { - Some(Type::BooleanLiteral(op.is_in())) + Ok(Type::BooleanLiteral(op.is_in())) } else if not_eq_count == rhs_elements.len() { - Some(Type::BooleanLiteral(op.is_not_in())) + Ok(Type::BooleanLiteral(op.is_not_in())) } else { - Some(KnownClass::Bool.to_instance(self.db)) + Ok(KnownClass::Bool.to_instance(self.db)) } } ast::CmpOp::Is | ast::CmpOp::IsNot => { @@ -3016,7 +3030,7 @@ impl<'db> TypeInferenceBuilder<'db> { "infer_binary_type_comparison should never return None for `CmpOp::Eq`", ); - Some(match eq_result { + Ok(match eq_result { Type::Todo => Type::Todo, ty => match ty.bool(self.db) { Truthiness::AlwaysFalse => Type::BooleanLiteral(op.is_is_not()), @@ -3029,16 +3043,19 @@ impl<'db> TypeInferenceBuilder<'db> { // Lookup the rich comparison `__dunder__` methods on instances (Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op { - ast::CmpOp::Lt => { - perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__") - } + ast::CmpOp::Lt => perform_rich_comparison( + self.db, + left_class_ty, + right_class_ty, + RichCompareOperator::Lt, + ), // TODO: implement mapping from `ast::CmpOp` to rich comparison methods - _ => Some(Type::Todo), + _ => Ok(Type::Todo), }, // TODO: handle more types _ => match op { - ast::CmpOp::Is | ast::CmpOp::IsNot => Some(KnownClass::Bool.to_instance(self.db)), - _ => Some(Type::Todo), + ast::CmpOp::Is | ast::CmpOp::IsNot => Ok(KnownClass::Bool.to_instance(self.db)), + _ => Ok(Type::Todo), }, } } @@ -3053,7 +3070,7 @@ impl<'db> TypeInferenceBuilder<'db> { left: &[Type<'db>], op: RichCompareOperator, right: &[Type<'db>], - ) -> Option> { + ) -> Result, CompareUnsupportedError<'db>> { // Compare paired elements from left and right slices for (l_ty, r_ty) in left.iter().copied().zip(right.iter().copied()) { let eq_result = self @@ -3062,7 +3079,7 @@ impl<'db> TypeInferenceBuilder<'db> { match eq_result { // If propagation is required, return the result as is - Type::Todo => return Some(Type::Todo), + Type::Todo => return Ok(Type::Todo), ty => match ty.bool(self.db) { // Types are equal, continue to the next pair Truthiness::AlwaysTrue => continue, @@ -3072,7 +3089,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // If the intermediate result is ambiguous, we cannot determine the final result as BooleanLiteral. // In this case, we simply return a bool instance. - Truthiness::Ambiguous => return Some(KnownClass::Bool.to_instance(self.db)), + Truthiness::Ambiguous => return Ok(KnownClass::Bool.to_instance(self.db)), }, } } @@ -3082,7 +3099,7 @@ impl<'db> TypeInferenceBuilder<'db> { // We return a comparison of the slice lengths based on the operator. let (left_len, right_len) = (left.len(), right.len()); - Some(Type::BooleanLiteral(match op { + Ok(Type::BooleanLiteral(match op { RichCompareOperator::Eq => left_len == right_len, RichCompareOperator::Ne => left_len != right_len, RichCompareOperator::Lt => left_len < right_len, @@ -3556,6 +3573,26 @@ impl From for ast::CmpOp { } } +impl RichCompareOperator { + const fn dunder_name(self) -> &'static str { + match self { + RichCompareOperator::Eq => "__eq__", + RichCompareOperator::Ne => "__ne__", + RichCompareOperator::Lt => "__lt__", + RichCompareOperator::Le => "__le__", + RichCompareOperator::Gt => "__gt__", + RichCompareOperator::Ge => "__ge__", + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct CompareUnsupportedError<'db> { + op: ast::CmpOp, + left_ty: Type<'db>, + right_ty: Type<'db>, +} + fn format_import_from_module(level: u32, module: Option<&str>) -> String { format!( "{}{}", @@ -3636,8 +3673,8 @@ fn perform_rich_comparison<'db>( db: &'db dyn Db, left: ClassType<'db>, right: ClassType<'db>, - dunder_name: &str, -) -> Option> { + op: RichCompareOperator, +) -> Result, CompareUnsupportedError<'db>> { // The following resource has details about the rich comparison algorithm: // https://snarky.ca/unravelling-rich-comparison-operators/ // @@ -3645,17 +3682,26 @@ fn perform_rich_comparison<'db>( // l.h.s. // TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined - let dunder = left.class_member(db, dunder_name); + let dunder = left.class_member(db, op.dunder_name()); if !dunder.is_unbound() { // TODO: this currently gives the return type even if the arg types are invalid // (e.g. int.__lt__ with string instance should be None, currently bool) return dunder .call(db, &[Type::Instance(left), Type::Instance(right)]) - .return_ty(db); + .return_ty(db) + .ok_or_else(|| CompareUnsupportedError { + op: op.into(), + left_ty: Type::Instance(left), + right_ty: Type::Instance(right), + }); } // TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=) - None + Err(CompareUnsupportedError { + op: op.into(), + left_ty: Type::Instance(left), + right_ty: Type::Instance(right), + }) } #[cfg(test)]