From 3b27d5dbad18b2849834fe5c3736ed8b9c9ae2da Mon Sep 17 00:00:00 2001 From: InSync Date: Mon, 23 Dec 2024 01:02:28 +0700 Subject: [PATCH] [red-knot] More precise inference for chained boolean expressions (#15089) ## Summary Resolves #13632. ## Test Plan Markdown tests. --- .../mdtest/comparison/non_bool_returns.md | 4 +- .../resources/mdtest/expression/boolean.md | 8 +-- .../resources/mdtest/narrow/truthiness.md | 48 ++++++++++++++++++ .../src/types/infer.rs | 50 +++++++++++-------- 4 files changed, 84 insertions(+), 26 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/comparison/non_bool_returns.md b/crates/red_knot_python_semantic/resources/mdtest/comparison/non_bool_returns.md index bc535a5acf038b..e34afd6a05e16b 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/comparison/non_bool_returns.md +++ b/crates/red_knot_python_semantic/resources/mdtest/comparison/non_bool_returns.md @@ -31,10 +31,10 @@ class C: def __lt__(self, other) -> C: ... x = A() < B() < C() -reveal_type(x) # revealed: A | B +reveal_type(x) # revealed: A & ~AlwaysTruthy | B y = 0 < 1 < A() < 3 -reveal_type(y) # revealed: bool | A +reveal_type(y) # revealed: Literal[False] | A z = 10 < 0 < A() < B() < C() reveal_type(z) # revealed: Literal[False] diff --git a/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md b/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md index 7ce689164248a5..ad9075b8f876e7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md +++ b/crates/red_knot_python_semantic/resources/mdtest/expression/boolean.md @@ -10,8 +10,8 @@ def _(foo: str): reveal_type(False or "z") # revealed: Literal["z"] reveal_type(False or True) # revealed: Literal[True] reveal_type(False or False) # revealed: Literal[False] - reveal_type(foo or False) # revealed: str | Literal[False] - reveal_type(foo or True) # revealed: str | Literal[True] + reveal_type(foo or False) # revealed: str & ~AlwaysFalsy | Literal[False] + reveal_type(foo or True) # revealed: str & ~AlwaysFalsy | Literal[True] ``` ## AND @@ -20,8 +20,8 @@ def _(foo: str): def _(foo: str): reveal_type(True and False) # revealed: Literal[False] reveal_type(False and True) # revealed: Literal[False] - reveal_type(foo and False) # revealed: str | Literal[False] - reveal_type(foo and True) # revealed: str | Literal[True] + reveal_type(foo and False) # revealed: str & ~AlwaysTruthy | Literal[False] + reveal_type(foo and True) # revealed: str & ~AlwaysTruthy | Literal[True] reveal_type("x" and "y" and "z") # revealed: Literal["z"] reveal_type("x" and "y" and "") # revealed: Literal[""] reveal_type("" and "y") # revealed: Literal[""] diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md index 9e391e4ef55f88..a252ba32dc82fc 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/truthiness.md @@ -219,3 +219,51 @@ else: # TODO: It should be A. We should improve UnionBuilder or IntersectionBuilder. (issue #15023) reveal_type(y) # revealed: A & ~AlwaysTruthy | A & ~AlwaysFalsy ``` + +## Narrowing in chained boolean expressions + +```py +from typing import Literal + +class A: ... + +def _(x: Literal[0, 1]): + reveal_type(x or A()) # revealed: Literal[1] | A + reveal_type(x and A()) # revealed: Literal[0] | A + +def _(x: str): + reveal_type(x or A()) # revealed: str & ~AlwaysFalsy | A + reveal_type(x and A()) # revealed: str & ~AlwaysTruthy | A + +def _(x: bool | str): + reveal_type(x or A()) # revealed: Literal[True] | str & ~AlwaysFalsy | A + reveal_type(x and A()) # revealed: Literal[False] | str & ~AlwaysTruthy | A + +class Falsy: + def __bool__(self) -> Literal[False]: ... + +class Truthy: + def __bool__(self) -> Literal[True]: ... + +def _(x: Falsy | Truthy): + reveal_type(x or A()) # revealed: Truthy | A + reveal_type(x and A()) # revealed: Falsy | A + +class MetaFalsy(type): + def __bool__(self) -> Literal[False]: ... + +class MetaTruthy(type): + def __bool__(self) -> Literal[False]: ... + +class FalsyClass(metaclass=MetaFalsy): ... +class TruthyClass(metaclass=MetaTruthy): ... + +def _(x: type[FalsyClass] | type[TruthyClass]): + # TODO: Should be `type[TruthyClass] | A` + # revealed: type[FalsyClass] & ~AlwaysFalsy | type[TruthyClass] & ~AlwaysFalsy | A + reveal_type(x or A()) + + # TODO: Should be `type[FalsyClass] | A` + # revealed: type[FalsyClass] & ~AlwaysTruthy | type[TruthyClass] & ~AlwaysTruthy | A + reveal_type(x and A()) +``` diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 74cb987d79030b..d689e303383390 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -3582,27 +3582,37 @@ impl<'db> TypeInferenceBuilder<'db> { n_values: usize, ) -> Type<'db> { let mut done = false; - UnionType::from_elements( - db, - values.into_iter().enumerate().map(|(i, ty)| { - if done { - Type::Never - } else { - let is_last = i == n_values - 1; - match (ty.bool(db), is_last, op) { - (Truthiness::Ambiguous, _, _) => ty, - (Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never, - (Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never, - (Truthiness::AlwaysFalse, _, ast::BoolOp::And) - | (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => { - done = true; - ty - } - (_, true, _) => ty, - } + + let elements = values.into_iter().enumerate().map(|(i, ty)| { + if done { + return Type::Never; + } + + let is_last = i == n_values - 1; + + match (ty.bool(db), is_last, op) { + (Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never, + (Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never, + + (Truthiness::AlwaysFalse, _, ast::BoolOp::And) + | (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => { + done = true; + ty } - }), - ) + + (Truthiness::Ambiguous, false, _) => IntersectionBuilder::new(db) + .add_positive(ty) + .add_negative(match op { + ast::BoolOp::And => Type::AlwaysTruthy, + ast::BoolOp::Or => Type::AlwaysFalsy, + }) + .build(), + + (_, true, _) => ty, + } + }); + + UnionType::from_elements(db, elements) } fn infer_compare_expression(&mut self, compare: &ast::ExprCompare) -> Type<'db> {