From b7b1cdee3daa79c1707dc1400b67957a230a5b4b Mon Sep 17 00:00:00 2001 From: Oscar Date: Fri, 7 Feb 2025 12:48:16 -0500 Subject: [PATCH 1/4] Add name use checks to validation layer --- moto/dynamodb/models/__init__.py | 22 +- .../parsing/key_condition_expression.py | 28 +- moto/dynamodb/responses.py | 262 +++++++++++++++--- .../exceptions/test_dynamodb_exceptions.py | 259 +++++++++++++++++ .../test_key_condition_expression_parser.py | 29 +- tests/test_dynamodb/test_dynamodb.py | 2 +- .../test_dynamodb_condition_expressions.py | 2 +- 7 files changed, 533 insertions(+), 71 deletions(-) diff --git a/moto/dynamodb/models/__init__.py b/moto/dynamodb/models/__init__.py index d15cd904596b..8608032a7264 100644 --- a/moto/dynamodb/models/__init__.py +++ b/moto/dynamodb/models/__init__.py @@ -44,10 +44,7 @@ from moto.dynamodb.models.table_import import TableImport from moto.dynamodb.parsing import partiql from moto.dynamodb.parsing.executors import UpdateExpressionExecutor -from moto.dynamodb.parsing.expressions import ( # type: ignore - ExpressionAttributeName, - UpdateExpressionParser, -) +from moto.dynamodb.parsing.expressions import UpdateExpressionParser # type: ignore from moto.dynamodb.parsing.validators import UpdateExpressionValidator @@ -530,23 +527,6 @@ def update_item( except ItemSizeTooLarge: raise ItemSizeToUpdateTooLarge() - # Ensure all ExpressionAttributeNames are requested - # Either in the Condition, or in the UpdateExpression - attr_name_clauses = update_expression_ast.find_clauses( - [ExpressionAttributeName] - ) - attr_names_in_expression = [ - attr.get_attribute_name_placeholder() for attr in attr_name_clauses - ] - attr_names_in_condition = condition_expression_parser.expr_attr_names_found - for attr_name in expression_attribute_names or []: - if ( - attr_name not in attr_names_in_expression - and attr_name not in attr_names_in_condition - ): - raise MockValidationException( - f"Value provided in ExpressionAttributeNames unused in expressions: keys: {{{attr_name}}}" - ) else: item.update_with_attribute_updates(attribute_updates) # type: ignore if table.stream_shard is not None: diff --git a/moto/dynamodb/parsing/key_condition_expression.py b/moto/dynamodb/parsing/key_condition_expression.py index fb2fa8ff34ff..e7454c88e889 100644 --- a/moto/dynamodb/parsing/key_condition_expression.py +++ b/moto/dynamodb/parsing/key_condition_expression.py @@ -23,7 +23,7 @@ def parse_expression( expression_attribute_values: Dict[str, Dict[str, str]], expression_attribute_names: Dict[str, str], schema: List[Dict[str, str]], -) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]]]: +) -> Tuple[Dict[str, Any], Optional[str], List[Dict[str, Any]], List[str]]: """ Parse a KeyConditionExpression using the provided expression attribute names/values @@ -37,6 +37,7 @@ def parse_expression( current_phrase = "" key_name = comparison = "" key_values: List[Union[Dict[str, str], str]] = [] + expression_attribute_names_used: List[str] = [] results: List[Tuple[str, str, Any]] = [] tokenizer = GenericTokenizer(key_condition_expression) for crnt_char in tokenizer: @@ -50,9 +51,11 @@ def parse_expression( else: # start_date < :sk and primary = :pk # ^ - key_name = expression_attribute_names.get( - current_phrase, current_phrase - ) + if expression_attribute_names.get(current_phrase): + key_name = expression_attribute_names[current_phrase] + expression_attribute_names_used.append(current_phrase) + else: + key_name = current_phrase current_phrase = "" current_stage = EXPRESSION_STAGES.COMPARISON tokenizer.skip_white_space() @@ -103,9 +106,11 @@ def parse_expression( EXPRESSION_STAGES.KEY_NAME, EXPRESSION_STAGES.INITIAL_STAGE, ]: - key_name = expression_attribute_names.get( - current_phrase, current_phrase - ) + if expression_attribute_names.get(current_phrase): + key_name = expression_attribute_names[current_phrase] + expression_attribute_names_used.append(current_phrase) + else: + key_name = current_phrase current_phrase = "" if crnt_char in ["<", ">"] and tokenizer.peek() == "=": comparison = crnt_char + tokenizer.__next__() @@ -118,9 +123,11 @@ def parse_expression( if current_stage == EXPRESSION_STAGES.KEY_NAME: # hashkey = :id and begins_with(sortkey, :sk) # ^ --> ^ - key_name = expression_attribute_names.get( - current_phrase, current_phrase - ) + if expression_attribute_names.get(current_phrase): + key_name = expression_attribute_names[current_phrase] + expression_attribute_names_used.append(current_phrase) + else: + key_name = current_phrase current_phrase = "" current_stage = EXPRESSION_STAGES.KEY_VALUE tokenizer.skip_white_space() @@ -192,6 +199,7 @@ def parse_expression( hash_value, range_comparison.upper() if range_comparison else None, range_values, + expression_attribute_names_used, ) diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 3bada2e53376..73107a4b201a 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -6,9 +6,13 @@ from moto.core.common_types import TYPE_RESPONSE from moto.core.responses import BaseResponse +from moto.dynamodb.comparisons import create_condition_expression_parser from moto.dynamodb.models import DynamoDBBackend, Table, dynamodb_backends from moto.dynamodb.models.utilities import dynamo_json_dump -from moto.dynamodb.parsing.expressions import UpdateExpressionParser # type: ignore +from moto.dynamodb.parsing.expressions import ( # type: ignore + ExpressionAttributeName, + UpdateExpressionParser, +) from moto.dynamodb.parsing.key_condition_expression import parse_expression from moto.dynamodb.parsing.reserved_keywords import ReservedKeywords from moto.utilities.aws_headers import amz_crc32 @@ -153,6 +157,17 @@ def validate_put_has_gsi_keys_set_to_none(item: Dict[str, Any], table: Table) -> ) +def validate_attribute_names_used( + attribute_names: Optional[Dict[str, str]], names_used: List[str] +) -> None: + if attribute_names: + for name in attribute_names: + if name not in names_used: + raise MockValidationException( + f"Value provided in ExpressionAttributeNames unused in expressions: keys: {{{name}}}" + ) + + def check_projection_expression(expression: str) -> None: if expression.upper() in ReservedKeywords.get_reserved_keywords(): raise MockValidationException( @@ -168,6 +183,45 @@ def check_projection_expression(expression: str) -> None: ) +class ProjectionExpressionParser: + def __init__( + self, + projection_expression: Optional[str], + expression_attribute_names: Optional[Dict[str, str]], + ): + self.projection_expression = projection_expression + self.expression_attribute_names = ( + expression_attribute_names if expression_attribute_names else {} + ) + + self.expr_attr_names_found: List[str] = [] + + def parse(self) -> List[List[str]]: + """ + lvl1.lvl2.attr1,lvl1.attr2 --> [["lvl1", "lvl2", "attr1"], ["lvl1", "attr2]] + """ + + if self.projection_expression: + expressions = [x.strip() for x in self.projection_expression.split(",")] + duplicates = extract_duplicates(expressions) + if duplicates: + raise InvalidProjectionExpression(duplicates) + for expression in expressions: + check_projection_expression(expression) + output = [] + for nested_expr in expressions: + nested_array = [] + for expr in nested_expr.split("."): + if self.expression_attribute_names.get(expr): + self.expr_attr_names_found.append(expr) + nested_array.append(self.expression_attribute_names[expr]) + else: + nested_array.append(expr) + output.append(nested_array) + return output + return [] + + class DynamoHandler(BaseResponse): def __init__(self) -> None: super().__init__(service_name="dynamodb") @@ -499,6 +553,16 @@ def put_item(self) -> str: expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) expression_attribute_values = self._get_expr_attr_values() + parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + parser.parse() + validate_attribute_names_used( + expression_attribute_names, parser.expr_attr_names_found + ) + if condition_expression: overwrite = False @@ -597,9 +661,14 @@ def get_item(self) -> str: raise ProvidedKeyDoesNotExist expression_attribute_names = expression_attribute_names or {} - projection_expressions = self._adjust_projection_expression( + parser = ProjectionExpressionParser( projection_expression, expression_attribute_names ) + projection_expressions = parser.parse() + + validate_attribute_names_used( + expression_attribute_names, parser.expr_attr_names_found + ) item = self.dynamodb_backend.get_item(name, key, projection_expressions) if item: @@ -649,9 +718,13 @@ def batch_get_item(self) -> str: "ExpressionAttributeNames", {} ) - projection_expressions = self._adjust_projection_expression( + parser = ProjectionExpressionParser( projection_expression, expression_attribute_names ) + projection_expressions = parser.parse() + validate_attribute_names_used( + expression_attribute_names, parser.expr_attr_names_found + ) results["Responses"][table_name] = [] for key in keys: @@ -693,9 +766,17 @@ def query(self) -> str: filter_expression = self._get_filter_expression() expression_attribute_values = self._get_expr_attr_values() - projection_expressions = self._adjust_projection_expression( + condition_parser = create_condition_expression_parser( + filter_expression, expression_attribute_names, expression_attribute_values + ) + condition_parser.parse() + expression_attribute_names_used = condition_parser.expr_attr_names_found + + projection_parser = ProjectionExpressionParser( projection_expression, expression_attribute_names ) + projection_expressions = projection_parser.parse() + expression_attribute_names_used += projection_parser.expr_attr_names_found filter_kwargs = {} @@ -704,12 +785,20 @@ def query(self) -> str: schema = self.dynamodb_backend.get_schema( table_name=name, index_name=index_name ) - hash_key, range_comparison, range_values = parse_expression( + ( + hash_key, + range_comparison, + range_values, + expression_attribute_names_used_by_key_condition, + ) = parse_expression( key_condition_expression=key_condition_expression, expression_attribute_names=expression_attribute_names, expression_attribute_values=expression_attribute_values, schema=schema, ) + expression_attribute_names_used += ( + expression_attribute_names_used_by_key_condition + ) else: # 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}} key_conditions = self.body.get("KeyConditions") @@ -749,6 +838,10 @@ def query(self) -> str: range_values = [] if query_filters: filter_kwargs.update(query_filters) + + validate_attribute_names_used( + expression_attribute_names, expression_attribute_names_used + ) index_name = self.body.get("IndexName") exclusive_start_key = self.body.get("ExclusiveStartKey") limit = self.body.get("Limit") @@ -785,30 +878,6 @@ def query(self) -> str: return dynamo_json_dump(result) - def _adjust_projection_expression( - self, projection_expression: Optional[str], expr_attr_names: Dict[str, str] - ) -> List[List[str]]: - """ - lvl1.lvl2.attr1,lvl1.attr2 --> [["lvl1", "lvl2", "attr1"], ["lvl1", "attr2]] - """ - - def _adjust(expression: str) -> str: - return (expr_attr_names or {}).get(expression, expression) - - if projection_expression: - expressions = [x.strip() for x in projection_expression.split(",")] - duplicates = extract_duplicates(expressions) - if duplicates: - raise InvalidProjectionExpression(duplicates) - for expression in expressions: - check_projection_expression(expression) - return [ - [_adjust(expr) for expr in nested_expr.split(".")] - for nested_expr in expressions - ] - - return [] - @include_consumed_capacity() def scan(self) -> str: name = self.body["TableName"] @@ -825,7 +894,25 @@ def scan(self) -> str: filter_expression = self._get_filter_expression() expression_attribute_values = self._get_expr_attr_values() expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + + filter_parser = create_condition_expression_parser( + filter_expression, + expression_attribute_names, + expression_attribute_values, + ) + try: + filter_parser.parse() + except ValueError as err: + raise MockValidationException(f"Bad Filter Expression: {err}") + expression_attribute_names_used = filter_parser.expr_attr_names_found + projection_expression = self._get_projection_expression() + projection_parser = ProjectionExpressionParser( + projection_expression, expression_attribute_names + ) + projection_expressions = projection_parser.parse() + expression_attribute_names_used += projection_parser.expr_attr_names_found + exclusive_start_key = self.body.get("ExclusiveStartKey") limit = self.body.get("Limit") index_name = self.body.get("IndexName") @@ -849,8 +936,8 @@ def scan(self) -> str: f"The Segment parameter is zero-based and must be less than parameter TotalSegments: Segment: {segment} is not less than TotalSegments: {total_segments}" ) - projection_expressions = self._adjust_projection_expression( - projection_expression, expression_attribute_names + validate_attribute_names_used( + expression_attribute_names, expression_attribute_names_used ) try: @@ -900,6 +987,16 @@ def delete_item(self) -> str: "ReturnValuesOnConditionCheckFailure" ) + parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + parser.parse() + validate_attribute_names_used( + expression_attribute_names, parser.expr_attr_names_found + ) + item = self.dynamodb_backend.delete_item( name, key, @@ -937,8 +1034,17 @@ def update_item(self) -> str: raise MockValidationException( "Invalid UpdateExpression: The expression can not be empty;" ) + update_expression_ast = UpdateExpressionParser.make(update_expression) + attr_name_clauses = update_expression_ast.find_clauses( + [ExpressionAttributeName] + ) + expression_attribute_names_used = [ + attr.get_attribute_name_placeholder() for attr in attr_name_clauses + ] + else: update_expression = "" + expression_attribute_names_used = [] return_values_on_condition_check_failure = self.body.get( "ReturnValuesOnConditionCheckFailure" @@ -969,6 +1075,16 @@ def update_item(self) -> str: condition_expression = self.body.get("ConditionExpression") expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) expression_attribute_values = self._get_expr_attr_values() + condition_parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + condition_parser.parse() + expression_attribute_names_used += condition_parser.expr_attr_names_found + validate_attribute_names_used( + expression_attribute_names, expression_attribute_names_used + ) item = self.dynamodb_backend.update_item( name, @@ -1142,6 +1258,25 @@ def transact_write_items(self) -> str: item_attrs = item["Put"]["Item"] table = self.dynamodb_backend.get_table(item["Put"]["TableName"]) validate_put_has_empty_keys(item_attrs, table) + + condition_expression = item["Put"].get("ConditionExpression") + expression_attribute_names = item["Put"].get( + "ExpressionAttributeNames", {} + ) + expression_attribute_values = item["Put"].get( + "ExpressionAttributeValues", {} + ) + + parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + parser.parse() + validate_attribute_names_used( + expression_attribute_names, parser.expr_attr_names_found + ) + if "Update" in item: if item["Update"].get("ExpressionAttributeValues") == {}: raise ExpressionAttributeValuesEmpty @@ -1153,10 +1288,75 @@ def transact_write_items(self) -> str: table, custom_error_msg="One or more parameter values are not valid. The AttributeValue for a key attribute cannot contain an empty string value. Key: {}", ) + update_expression = item["Update"]["UpdateExpression"] UpdateExpressionParser.make(update_expression).validate( limit_set_actions=True ) + update_expression_ast = UpdateExpressionParser.make(update_expression) + update_expression_ast.validate(limit_set_actions=True) + attr_name_clauses = update_expression_ast.find_clauses( + [ExpressionAttributeName] + ) + expression_attribute_names_used = [ + attr.get_attribute_name_placeholder() for attr in attr_name_clauses + ] + + condition_expression = item["Update"].get("ConditionExpression") + expression_attribute_names = item["Update"].get( + "ExpressionAttributeNames", {} + ) + expression_attribute_values = item["Update"].get( + "ExpressionAttributeValues", {} + ) + parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + parser.parse() + expression_attribute_names_used += parser.expr_attr_names_found + + validate_attribute_names_used( + expression_attribute_names, expression_attribute_names_used + ) + if "Delete" in item: + condition_expression = item["Delete"].get("ConditionExpression") + expression_attribute_names = item["Delete"].get( + "ExpressionAttributeNames", {} + ) + expression_attribute_values = item["Delete"].get( + "ExpressionAttributeValues", {} + ) + + parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + parser.parse() + validate_attribute_names_used( + expression_attribute_names, parser.expr_attr_names_found + ) + if "ConditionCheck" in item: + condition_expression = item["ConditionCheck"].get("ConditionExpression") + expression_attribute_names = item["ConditionCheck"].get( + "ExpressionAttributeNames", {} + ) + expression_attribute_values = item["ConditionCheck"].get( + "ExpressionAttributeValues", {} + ) + + parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + parser.parse() + validate_attribute_names_used( + expression_attribute_names, parser.expr_attr_names_found + ) + self.dynamodb_backend.transact_write_items(transact_items) response: Dict[str, Any] = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}} return dynamo_json_dump(response) diff --git a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py index 8861f36e7694..dba7cd9adada 100644 --- a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py +++ b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py @@ -307,6 +307,265 @@ def test_update_item_unused_attribute_name(table_name=None): ) +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_put_item_unused_attribute_name(table_name=None): + ddb = boto3.resource("dynamodb", region_name="us-east-1") + + table = ddb.Table(table_name) + + with pytest.raises(ClientError) as exc: + table.put_item( + Item={"pk": "pk1", "spec": {}, "am": 0}, + ConditionExpression="attribute_not_exists(body)", + ExpressionAttributeNames={"#count": "count"}, + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_get_item_unused_attribute_name(table_name=None): + ddb = boto3.resource("dynamodb", region_name="us-east-1") + + table = ddb.Table(table_name) + + with pytest.raises(ClientError) as exc: + table.get_item( + Key={"pk": "example_id"}, + ProjectionExpression="pk", + ExpressionAttributeNames={"#count": "count"}, + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_query_unused_attribute_name(table_name=None): + ddb = boto3.resource("dynamodb", region_name="us-east-1") + + table = ddb.Table(table_name) + + with pytest.raises(ClientError) as exc: + table.query( + KeyConditionExpression="(#0 = 1) AND (begins_with(#1, a))", + ExpressionAttributeNames={"#0": "pk", "#1": "sk", "#count": "count"}, + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_scan_unused_attribute_name(table_name=None): + ddb = boto3.resource("dynamodb", region_name="us-east-1") + + table = ddb.Table(table_name) + + with pytest.raises(ClientError) as exc: + table.scan( + TableName=table_name, + FilterExpression="#h = :h", + ExpressionAttributeNames={"#h": "pk", "#count": "count"}, + ExpressionAttributeValues={":h": {"S": "hash_value"}}, + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_delete_unused_attribute_name(table_name=None): + ddb = boto3.resource("dynamodb", region_name="us-east-1") + + table = ddb.Table(table_name) + + with pytest.raises(ClientError) as exc: + table.delete_item( + Key={"pk": "pk1"}, + ConditionExpression="attribute_not_exists(body)", + ExpressionAttributeNames={"#count": "count"}, + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_batch_get_item_unused_attribute_name(table_name=None): + ddb = boto3.resource("dynamodb", region_name="us-east-1") + + with pytest.raises(ClientError) as exc: + ddb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + "ProjectionExpression": "#rl", + "ExpressionAttributeNames": {"#rl": "username", "#count": "count"}, + } + } + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_transact_write_item_put_unused_attribute_name(table_name=None): + ddb = boto3.client("dynamodb", region_name="us-east-1") + + with pytest.raises(ClientError) as exc: + ddb.transact_write_items( + TransactItems=[ + { + "Put": { + "Item": { + "pk": {"S": "foo"}, + "foo": {"S": "bar"}, + }, + "TableName": table_name, + "ConditionExpression": "#i <> foo", + "ExpressionAttributeNames": {"#i": "pk", "#count": "count"}, + }, + } + ] + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_transact_write_item_update_unused_attribute_name(table_name=None): + ddb = boto3.client("dynamodb", region_name="us-east-1") + + with pytest.raises(ClientError) as exc: + ddb.transact_write_items( + TransactItems=[ + { + "Update": { + "Key": {"id": {"S": "foo"}}, + "TableName": table_name, + "UpdateExpression": "SET #e = test", + "ExpressionAttributeNames": { + "#e": "email_address", + "#count": "count", + }, + } + } + ] + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_transact_write_item_delete_unused_attribute_name(table_name=None): + ddb = boto3.client("dynamodb", region_name="us-east-1") + + with pytest.raises(ClientError) as exc: + ddb.transact_write_items( + TransactItems=[ + { + "Delete": { + "Key": { + "pk": {"S": "foo"}, + "foo": {"S": "bar"}, + }, + "TableName": table_name, + "ConditionExpression": "#i <> foo", + "ExpressionAttributeNames": {"#i": "pk", "#count": "count"}, + }, + } + ] + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + +@pytest.mark.aws_verified +@dynamodb_aws_verified() +def test_transact_write_item_unused_attribute_name_in_condition_check(table_name=None): + ddb = boto3.client("dynamodb", region_name="us-east-1") + + with pytest.raises(ClientError) as exc: + ddb.transact_write_items( + TransactItems=[ + { + "ConditionCheck": { + "Key": {"id": {"S": "foo"}}, + "TableName": table_name, + "ConditionExpression": "attribute_exists(#e)", + "ExpressionAttributeNames": { + "#e": "email_address", + "#count": "count", + }, + } + }, + { + "Put": { + "Item": { + "id": {"S": "bar"}, + "email_address": {"S": "bar@moto.com"}, + }, + "TableName": table_name, + } + }, + ] + ) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "Value provided in ExpressionAttributeNames unused in expressions: keys: {#count}" + ) + + @mock_aws def test_batch_get_item_non_existing_table(): client = boto3.client("dynamodb", region_name="us-west-2") diff --git a/tests/test_dynamodb/models/test_key_condition_expression_parser.py b/tests/test_dynamodb/models/test_key_condition_expression_parser.py index 30545a3d61a1..bdacd5e299e1 100644 --- a/tests/test_dynamodb/models/test_key_condition_expression_parser.py +++ b/tests/test_dynamodb/models/test_key_condition_expression_parser.py @@ -10,7 +10,7 @@ class TestHashKey: @pytest.mark.parametrize("expression", ["job_id = :id", "job_id = :id "]) def test_hash_key_only(self, expression): eav = {":id": {"S": "asdasdasd"}} - desired_hash_key, comparison, range_values = parse_expression( + desired_hash_key, comparison, range_values, _ = parse_expression( expression_attribute_values=eav, key_condition_expression=expression, schema=self.schema, @@ -84,7 +84,7 @@ def test_unknown_range_key(self, expr): ) def test_begin_with(self, expr): eav = {":id": "pk", ":sk": "19"} - desired_hash_key, comparison, range_values = parse_expression( + desired_hash_key, comparison, range_values, _ = parse_expression( expression_attribute_values=eav, key_condition_expression=expr, schema=self.schema, @@ -119,7 +119,7 @@ def test_begin_with__wrong_case(self, fn): ) def test_in_between(self, expr): eav = {":id": "pk", ":sk1": "19", ":sk2": "21"} - desired_hash_key, comparison, range_values = parse_expression( + desired_hash_key, comparison, range_values, _ = parse_expression( expression_attribute_values=eav, key_condition_expression=expr, schema=self.schema, @@ -133,7 +133,7 @@ def test_in_between(self, expr): def test_numeric_comparisons(self, operator): eav = {":id": "pk", ":sk": "19"} expr = f"job_id = :id and start_date{operator}:sk" - desired_hash_key, comparison, range_values = parse_expression( + desired_hash_key, comparison, range_values, _ = parse_expression( expression_attribute_values=eav, key_condition_expression=expr, schema=self.schema, @@ -154,7 +154,7 @@ def test_numeric_comparisons(self, operator): ) def test_reverse_keys(self, expr): eav = {":id": "pk", ":sk1": "19", ":sk2": "21"} - desired_hash_key, comparison, range_values = parse_expression( + desired_hash_key, comparison, range_values, _ = parse_expression( expression_attribute_values=eav, key_condition_expression=expr, schema=self.schema, @@ -171,7 +171,7 @@ def test_reverse_keys(self, expr): ], ) def test_brackets(self, expr): - desired_hash_key, comparison, range_values = parse_expression( + desired_hash_key, comparison, range_values, _ = parse_expression( expression_attribute_values={":id": "pk", ":sk": "19"}, key_condition_expression=expr, schema=self.schema, @@ -187,7 +187,7 @@ def test_names_and_values(self): kce = ":j = :id" ean = {":j": "job_id"} eav = {":id": {"S": "asdasdasd"}} - desired_hash_key, comparison, range_values = parse_expression( + desired_hash_key, comparison, range_values, _ = parse_expression( expression_attribute_values=eav, key_condition_expression=kce, schema=self.schema, @@ -196,3 +196,18 @@ def test_names_and_values(self): assert desired_hash_key == eav[":id"] assert comparison is None assert range_values == [] + + +def test_expression_attribute_names_found(): + kce = ":j = :id" + ean = {":j": "job_id"} + eav = {":id": {"S": "asdasdasd"}} + desired_hash_key, comparison, range_values, expression_attribute_names_used = ( + parse_expression( + expression_attribute_values=eav, + key_condition_expression=kce, + schema=[{"AttributeName": "job_id", "KeyType": "HASH"}], + expression_attribute_names=ean, + ) + ) + assert expression_attribute_names_used == [":j"] diff --git a/tests/test_dynamodb/test_dynamodb.py b/tests/test_dynamodb/test_dynamodb.py index 9f5741926af4..7cc32d891240 100644 --- a/tests/test_dynamodb/test_dynamodb.py +++ b/tests/test_dynamodb/test_dynamodb.py @@ -1131,7 +1131,7 @@ def test_nested_projection_expression_using_scan_with_attr_expression_names(): # Test a scan results = table.scan( FilterExpression=Key("forum_name").eq("key1"), - ProjectionExpression="nested.level1.id, nested.level2", + ProjectionExpression="#nst.level1.id, #nst.#lvl2", ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level2"}, )["Items"] assert results == [ diff --git a/tests/test_dynamodb/test_dynamodb_condition_expressions.py b/tests/test_dynamodb/test_dynamodb_condition_expressions.py index 6cd6c40b4bb2..0cf1b94d47e9 100644 --- a/tests/test_dynamodb/test_dynamodb_condition_expressions.py +++ b/tests/test_dynamodb/test_dynamodb_condition_expressions.py @@ -227,7 +227,7 @@ def test_condition_expressions(): Key={"client": {"S": "client1"}, "app": {"S": "app1"}}, ConditionExpression="attribute_not_exists(#existing)", ExpressionAttributeValues={":match": {"S": "match"}}, - ExpressionAttributeNames={"#existing": "existing", "#match": "match"}, + ExpressionAttributeNames={"#existing": "existing"}, ) From 4b7be48d6d93e5a6583fd3f96cc1bcc6dd29fc1a Mon Sep 17 00:00:00 2001 From: Oscar Date: Tue, 25 Feb 2025 09:40:52 -0500 Subject: [PATCH 2/4] Fixing syntax error in one of the attribute name tests --- tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py index dba7cd9adada..c4e2a4697fef 100644 --- a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py +++ b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py @@ -358,7 +358,7 @@ def test_query_unused_attribute_name(table_name=None): with pytest.raises(ClientError) as exc: table.query( - KeyConditionExpression="(#0 = 1) AND (begins_with(#1, a))", + KeyConditionExpression="(#0 = x) AND (begins_with(#1, a))", ExpressionAttributeNames={"#0": "pk", "#1": "sk", "#count": "count"}, ) err = exc.value.response["Error"] From 6b8b1714453847de573db59f69c59ad99359c656 Mon Sep 17 00:00:00 2001 From: Oscar Date: Tue, 25 Feb 2025 17:50:40 -0500 Subject: [PATCH 3/4] Common parsing of condition expression --- moto/dynamodb/responses.py | 93 ++++++++------------------------------ 1 file changed, 20 insertions(+), 73 deletions(-) diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 73107a4b201a..fd54aaf4ce51 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -1251,6 +1251,24 @@ def transact_write_items(self) -> str: transact_items = self.body["TransactItems"] # Validate first - we should error before we start the transaction for item in transact_items: + # This logic is common among all types of write items + item_values = list(item.values())[0] # Each item only has one of Put, Update, Delete, ConditionCheck + condition_expression = item_values.get("ConditionExpression") + expression_attribute_names = item_values.get( + "ExpressionAttributeNames", {} + ) + expression_attribute_values = item_values.get( + "ExpressionAttributeValues", {} + ) + + parser = create_condition_expression_parser( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + parser.parse() + expression_attribute_names_used = parser.expr_attr_names_found + if "Put" in item: if item["Put"].get("ExpressionAttributeValues") == {}: raise ExpressionAttributeValuesEmpty @@ -1259,24 +1277,6 @@ def transact_write_items(self) -> str: table = self.dynamodb_backend.get_table(item["Put"]["TableName"]) validate_put_has_empty_keys(item_attrs, table) - condition_expression = item["Put"].get("ConditionExpression") - expression_attribute_names = item["Put"].get( - "ExpressionAttributeNames", {} - ) - expression_attribute_values = item["Put"].get( - "ExpressionAttributeValues", {} - ) - - parser = create_condition_expression_parser( - condition_expression, - expression_attribute_names, - expression_attribute_values, - ) - parser.parse() - validate_attribute_names_used( - expression_attribute_names, parser.expr_attr_names_found - ) - if "Update" in item: if item["Update"].get("ExpressionAttributeValues") == {}: raise ExpressionAttributeValuesEmpty @@ -1298,64 +1298,11 @@ def transact_write_items(self) -> str: attr_name_clauses = update_expression_ast.find_clauses( [ExpressionAttributeName] ) - expression_attribute_names_used = [ + expression_attribute_names_used += [ attr.get_attribute_name_placeholder() for attr in attr_name_clauses ] - condition_expression = item["Update"].get("ConditionExpression") - expression_attribute_names = item["Update"].get( - "ExpressionAttributeNames", {} - ) - expression_attribute_values = item["Update"].get( - "ExpressionAttributeValues", {} - ) - parser = create_condition_expression_parser( - condition_expression, - expression_attribute_names, - expression_attribute_values, - ) - parser.parse() - expression_attribute_names_used += parser.expr_attr_names_found - - validate_attribute_names_used( - expression_attribute_names, expression_attribute_names_used - ) - if "Delete" in item: - condition_expression = item["Delete"].get("ConditionExpression") - expression_attribute_names = item["Delete"].get( - "ExpressionAttributeNames", {} - ) - expression_attribute_values = item["Delete"].get( - "ExpressionAttributeValues", {} - ) - - parser = create_condition_expression_parser( - condition_expression, - expression_attribute_names, - expression_attribute_values, - ) - parser.parse() - validate_attribute_names_used( - expression_attribute_names, parser.expr_attr_names_found - ) - if "ConditionCheck" in item: - condition_expression = item["ConditionCheck"].get("ConditionExpression") - expression_attribute_names = item["ConditionCheck"].get( - "ExpressionAttributeNames", {} - ) - expression_attribute_values = item["ConditionCheck"].get( - "ExpressionAttributeValues", {} - ) - - parser = create_condition_expression_parser( - condition_expression, - expression_attribute_names, - expression_attribute_values, - ) - parser.parse() - validate_attribute_names_used( - expression_attribute_names, parser.expr_attr_names_found - ) + validate_attribute_names_used(expression_attribute_names, expression_attribute_names_used) self.dynamodb_backend.transact_write_items(transact_items) response: Dict[str, Any] = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}} From e85cdb48d8bb1cc612f27180f91d5f6cd9123698 Mon Sep 17 00:00:00 2001 From: BlizzardOfOzzy Date: Thu, 27 Feb 2025 20:42:30 +0000 Subject: [PATCH 4/4] formatting --- moto/dynamodb/responses.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index fd54aaf4ce51..948b53e61e30 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -1252,11 +1252,11 @@ def transact_write_items(self) -> str: # Validate first - we should error before we start the transaction for item in transact_items: # This logic is common among all types of write items - item_values = list(item.values())[0] # Each item only has one of Put, Update, Delete, ConditionCheck + item_values = list(item.values())[ + 0 + ] # Each item only has one of Put, Update, Delete, ConditionCheck condition_expression = item_values.get("ConditionExpression") - expression_attribute_names = item_values.get( - "ExpressionAttributeNames", {} - ) + expression_attribute_names = item_values.get("ExpressionAttributeNames", {}) expression_attribute_values = item_values.get( "ExpressionAttributeValues", {} ) @@ -1302,7 +1302,9 @@ def transact_write_items(self) -> str: attr.get_attribute_name_placeholder() for attr in attr_name_clauses ] - validate_attribute_names_used(expression_attribute_names, expression_attribute_names_used) + validate_attribute_names_used( + expression_attribute_names, expression_attribute_names_used + ) self.dynamodb_backend.transact_write_items(transact_items) response: Dict[str, Any] = {"ConsumedCapacity": [], "ItemCollectionMetrics": {}}