diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/bool-call.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/bool-call.md new file mode 100644 index 0000000000000..d7ae47b4fdd07 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/bool-call.md @@ -0,0 +1,32 @@ +## Narrowing for `bool(..)` checks + +```py +def flag() -> bool: ... + +x = 1 if flag() else None + +# valid invocation, positive +reveal_type(x) # revealed: Literal[1] | None +if bool(x is not None): + reveal_type(x) # revealed: Literal[1] + +# valid invocation, negative +reveal_type(x) # revealed: Literal[1] | None +if not bool(x is not None): + reveal_type(x) # revealed: None + +# no args/narrowing +reveal_type(x) # revealed: Literal[1] | None +if not bool(): + reveal_type(x) # revealed: Literal[1] | None + +# invalid invocation, too many positional args +reveal_type(x) # revealed: Literal[1] | None +if bool(x is not None, 5): # TODO diagnostic + reveal_type(x) # revealed: Literal[1] | None + +# invalid invocation, too many kwargs +reveal_type(x) # revealed: Literal[1] | None +if bool(x is not None, y=5): # TODO diagnostic + reveal_type(x) # revealed: Literal[1] | None +``` diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 7ce84cb7fb553..6005ab781acc2 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -385,46 +385,58 @@ impl<'db> NarrowingConstraintsBuilder<'db> { let scope = self.scope(); let inference = infer_expression_types(self.db, expression); + let callable_ty = + inference.expression_ty(expr_call.func.scoped_expression_id(self.db, scope)); + // TODO: add support for PEP 604 union types on the right hand side of `isinstance` // and `issubclass`, for example `isinstance(x, str | (int | float))`. - match inference - .expression_ty(expr_call.func.scoped_expression_id(self.db, scope)) - .into_function_literal() - .and_then(|f| f.known(self.db)) - .and_then(KnownFunction::constraint_function) - { - Some(function) if expr_call.arguments.keywords.is_empty() => { - if let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] = + match callable_ty { + Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => { + let function = function_type + .known(self.db) + .and_then(KnownFunction::constraint_function)?; + + let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] = &*expr_call.arguments.args - { - let symbol = self.symbols().symbol_id_by_name(id).unwrap(); + else { + return None; + }; - let class_info_ty = - inference.expression_ty(class_info.scoped_expression_id(self.db, scope)); + let symbol = self.symbols().symbol_id_by_name(id).unwrap(); - let to_constraint = match function { - KnownConstraintFunction::IsInstance => { - |class_literal: ClassLiteralType<'db>| { - Type::instance(class_literal.class) - } - } - KnownConstraintFunction::IsSubclass => { - |class_literal: ClassLiteralType<'db>| { - Type::subclass_of(class_literal.class) - } - } - }; + let class_info_ty = + inference.expression_ty(class_info.scoped_expression_id(self.db, scope)); - generate_classinfo_constraint(self.db, &class_info_ty, to_constraint).map( - |constraint| { - let mut constraints = NarrowingConstraints::default(); - constraints.insert(symbol, constraint.negate_if(self.db, !is_positive)); - constraints - }, - ) - } else { - None - } + let to_constraint = match function { + KnownConstraintFunction::IsInstance => { + |class_literal: ClassLiteralType<'db>| Type::instance(class_literal.class) + } + KnownConstraintFunction::IsSubclass => { + |class_literal: ClassLiteralType<'db>| { + Type::subclass_of(class_literal.class) + } + } + }; + + generate_classinfo_constraint(self.db, &class_info_ty, to_constraint).map( + |constraint| { + let mut constraints = NarrowingConstraints::default(); + constraints.insert(symbol, constraint.negate_if(self.db, !is_positive)); + constraints + }, + ) + } + // for the expression `bool(E)`, we further narrow the type based on `E` + Type::ClassLiteral(class_type) + if expr_call.arguments.args.len() == 1 + && expr_call.arguments.keywords.is_empty() + && class_type.class.is_known(self.db, KnownClass::Bool) => + { + self.evaluate_expression_node_constraint( + &expr_call.arguments.args[0], + expression, + is_positive, + ) } _ => None, }