Skip to content

Commit

Permalink
Address review? But now it panics
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Aug 23, 2024
1 parent 88858c6 commit c81aa5c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 53 deletions.
12 changes: 9 additions & 3 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::ast_ids::AstIdsBuilder;
use crate::semantic_index::definition::{
AssignmentDefinitionNodeRef, ComprehensionDefinitionNodeRef, Definition, DefinitionNodeKey,
DefinitionNodeRef, ImportFromDefinitionNodeRef,
DefinitionNodeRef, ForStmtDefinitionNodeRef, ImportFromDefinitionNodeRef,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{
Expand Down Expand Up @@ -644,8 +644,14 @@ where
Some(CurrentAssignment::AugAssign(aug_assign)) => {
self.add_definition(symbol, aug_assign);
}
Some(CurrentAssignment::For(for_stmt)) => {
self.add_definition(symbol, for_stmt);
Some(CurrentAssignment::For(node)) => {
self.add_definition(
symbol,
ForStmtDefinitionNodeRef {
iterable: &node.iter,
target: name_node,
},
);
}
Some(CurrentAssignment::Named(named)) => {
// TODO(dhruvmanila): If the current scope is a comprehension, then the
Expand Down
50 changes: 39 additions & 11 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl<'db> Definition<'db> {
pub(crate) enum DefinitionNodeRef<'a> {
Import(&'a ast::Alias),
ImportFrom(ImportFromDefinitionNodeRef<'a>),
For(&'a ast::StmtFor),
For(ForStmtDefinitionNodeRef<'a>),
Function(&'a ast::StmtFunctionDef),
Class(&'a ast::StmtClassDef),
NamedExpression(&'a ast::ExprNamed),
Expand All @@ -51,12 +51,6 @@ pub(crate) enum DefinitionNodeRef<'a> {
WithItem(WithItemDefinitionNodeRef<'a>),
}

impl<'a> From<&'a ast::StmtFor> for DefinitionNodeRef<'a> {
fn from(value: &'a ast::StmtFor) -> Self {
Self::For(value)
}
}

impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::StmtFunctionDef) -> Self {
Self::Function(node)
Expand Down Expand Up @@ -99,6 +93,12 @@ impl<'a> From<ImportFromDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}

impl<'a> From<ForStmtDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(value: ForStmtDefinitionNodeRef<'a>) -> Self {
Self::For(value)
}
}

impl<'a> From<AssignmentDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
fn from(node_ref: AssignmentDefinitionNodeRef<'a>) -> Self {
Self::Assignment(node_ref)
Expand Down Expand Up @@ -141,6 +141,12 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {
pub(crate) target: &'a ast::ExprName,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::Comprehension,
Expand Down Expand Up @@ -181,8 +187,11 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => {
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
}
DefinitionNodeRef::For(for_stmt) => {
DefinitionKind::For(AstNodeRef::new(parsed, for_stmt))
DefinitionNodeRef::For(ForStmtDefinitionNodeRef { iterable, target }) => {
DefinitionKind::For(ForStmtDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { node, first }) => {
DefinitionKind::Comprehension(ComprehensionDefinitionKind {
Expand Down Expand Up @@ -222,7 +231,10 @@ impl DefinitionNodeRef<'_> {
}) => target.into(),
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Self::For(node) => node.into(),
Self::For(ForStmtDefinitionNodeRef {
iterable: _,
target,
}) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { node, first: _ }) => node.into(),
Self::Parameter(node) => match node {
ast::AnyParameterRef::Variadic(parameter) => parameter.into(),
Expand All @@ -243,7 +255,7 @@ pub enum DefinitionKind {
Assignment(AssignmentDefinitionKind),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(AstNodeRef<ast::StmtFor>),
For(ForStmtDefinitionKind),
Comprehension(ComprehensionDefinitionKind),
Parameter(AstNodeRef<ast::Parameter>),
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
Expand Down Expand Up @@ -314,6 +326,22 @@ impl WithItemDefinitionKind {
}
}

#[derive(Clone, Debug)]
pub struct ForStmtDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
}

impl ForStmtDefinitionKind {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct DefinitionNodeKey(NodeKey);

Expand Down
69 changes: 30 additions & 39 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,12 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::AugmentedAssignment(augmented_assignment) => {
self.infer_augment_assignment_definition(augmented_assignment.node(), definition);
}
DefinitionKind::For(for_statement) => {
self.infer_for_statement_definition(for_statement, definition);
DefinitionKind::For(for_statement_definition) => {
self.infer_for_statement_definition(
for_statement_definition.target(),
for_statement_definition.iterable(),
definition,
);
}
DefinitionKind::NamedExpression(named_expression) => {
self.infer_named_expression_definition(named_expression.node(), definition);
Expand Down Expand Up @@ -872,49 +876,36 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_body(orelse);
}

fn add_maybe_multiple_for_definitions(
&mut self,
for_target: &ast::Expr,
loop_var_ty: Type<'db>,
definition: Definition<'db>,
) {
if let ast::Expr::Tuple(ast::ExprTuple { elts, .. })
| ast::Expr::List(ast::ExprList { elts, .. }) = for_target
{
for elt in elts {
// TODO(Alex): unpack `loop_var_ty` as well, in tandem with unpacking `for_target`
self.add_maybe_multiple_for_definitions(elt, loop_var_ty, definition);
}
return;
}

self.types
.expressions
.insert(for_target.scoped_ast_id(self.db, self.scope), loop_var_ty);
self.types.definitions.insert(definition, loop_var_ty);
}

fn infer_for_statement_definition(
&mut self,
for_statement: &ast::StmtFor,
target: &ast::ExprName,
iterable: &ast::Expr,
definition: Definition<'db>,
) {
let ast::StmtFor {
range: _,
is_async: _,
target,
iter,
body: _,
orelse: _,
} = for_statement;
let expression = self.index.expression(&**iter);
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db, expression);
self.extend(result);
// TODO(Alex): the type of the loop var is the result of calling
// `type(x).__next__(x)` where `x` is the value returned by calling
// `type(iter).__iter__(iter)`
let loop_var_ty = Type::Unknown;
self.add_maybe_multiple_for_definitions(target, loop_var_ty, definition);
let iterable_ty = self
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));

// TODO(Alex): only a valid iterable if the *type* of `iterable_ty` has an `__iter__`
// member (dunders are never looked up on an instance)
let _dunder_iter_ty = iterable_ty.member(self.db, &ast::name::Name::from("__iter__"));

// TODO(Alex):
// - infer the return type of the `__iter__` method, which gives us the iterator
// - lookup the `__next__` method on the iterator
// - infer the return type of the iterator's `__next__` method,
// which gives us the type of the variable being bound here
// (...or the type of the object being unpacked into multiple definitions, if it's something like
// `for k, v in d.items(): ...`)
let loop_var_value_ty = Type::Unknown;

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), loop_var_value_ty);
self.types.definitions.insert(definition, loop_var_value_ty);
}

fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) {
Expand Down

0 comments on commit c81aa5c

Please sign in to comment.