Skip to content

Commit

Permalink
Updated azure-cosmos to 4.7.0, requiring dropped support for obsolete…
Browse files Browse the repository at this point in the history
… CosmosDBStorage class. (#2165)

Co-authored-by: Tracy Boehrer <trboehre@microsoft.com>
  • Loading branch information
tracyboehrer and Tracy Boehrer authored Sep 10, 2024
1 parent b3e7436 commit 50e72c0
Show file tree
Hide file tree
Showing 7 changed files with 105 additions and 775 deletions.
4 changes: 1 addition & 3 deletions libraries/botbuilder-azure/botbuilder/azure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,17 @@

from .about import __version__
from .azure_queue_storage import AzureQueueStorage
from .cosmosdb_storage import CosmosDbStorage, CosmosDbConfig, CosmosDbKeyEscape
from .cosmosdb_partitioned_storage import (
CosmosDbPartitionedStorage,
CosmosDbPartitionedConfig,
CosmosDbKeyEscape,
)
from .blob_storage import BlobStorage, BlobStorageSettings

__all__ = [
"AzureQueueStorage",
"BlobStorage",
"BlobStorageSettings",
"CosmosDbStorage",
"CosmosDbConfig",
"CosmosDbKeyEscape",
"CosmosDbPartitionedStorage",
"CosmosDbPartitionedConfig",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from typing import Dict, List
from threading import Lock
import json

from hashlib import sha256
from azure.core import MatchConditions
from azure.cosmos import documents, http_constants
from jsonpickle.pickler import Pickler
from jsonpickle.unpickler import Unpickler
import azure.cosmos.cosmos_client as cosmos_client # pylint: disable=no-name-in-module,import-error
import azure.cosmos.errors as cosmos_errors # pylint: disable=no-name-in-module,import-error
import azure.cosmos.exceptions as cosmos_exceptions
from botbuilder.core.storage import Storage
from botbuilder.azure import CosmosDbKeyEscape


class CosmosDbPartitionedConfig:
Expand Down Expand Up @@ -63,6 +63,49 @@ def __init__(
self.compatibility_mode = compatibility_mode or kwargs.get("compatibility_mode")


class CosmosDbKeyEscape:
@staticmethod
def sanitize_key(
key: str, key_suffix: str = "", compatibility_mode: bool = True
) -> str:
"""Return the sanitized key.
Replace characters that are not allowed in keys in Cosmos.
:param key: The provided key to be escaped.
:param key_suffix: The string to add a the end of all RowKeys.
:param compatibility_mode: True if keys should be truncated in order to support previous CosmosDb
max key length of 255. This behavior can be overridden by setting
cosmosdb_partitioned_config.compatibility_mode to False.
:return str:
"""
# forbidden characters
bad_chars = ["\\", "?", "/", "#", "\t", "\n", "\r", "*"]
# replace those with with '*' and the
# Unicode code point of the character and return the new string
key = "".join(map(lambda x: "*" + str(ord(x)) if x in bad_chars else x, key))

if key_suffix is None:
key_suffix = ""

return CosmosDbKeyEscape.truncate_key(f"{key}{key_suffix}", compatibility_mode)

@staticmethod
def truncate_key(key: str, compatibility_mode: bool = True) -> str:
max_key_len = 255

if not compatibility_mode:
return key

if len(key) > max_key_len:
aux_hash = sha256(key.encode("utf-8"))
aux_hex = aux_hash.hexdigest()

key = key[0 : max_key_len - len(aux_hex)] + aux_hex

return key


class CosmosDbPartitionedStorage(Storage):
"""A CosmosDB based storage provider using partitioning for a bot."""

Expand Down Expand Up @@ -99,7 +142,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
:return dict:
"""
if not keys:
raise Exception("Keys are required when reading")
# No keys passed in, no result to return. Back-compat with original CosmosDBStorage.
return {}

await self.initialize()

Expand All @@ -111,8 +155,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
key, self.config.key_suffix, self.config.compatibility_mode
)

read_item_response = self.client.ReadItem(
self.__item_link(escaped_key), self.__get_partition_key(escaped_key)
read_item_response = self.container.read_item(
escaped_key, self.__get_partition_key(escaped_key)
)
document_store_item = read_item_response
if document_store_item:
Expand All @@ -122,13 +166,8 @@ async def read(self, keys: List[str]) -> Dict[str, object]:
# When an item is not found a CosmosException is thrown, but we want to
# return an empty collection so in this instance we catch and do not rethrow.
# Throw for any other exception.
except cosmos_errors.HTTPFailure as err:
if (
err.status_code
== cosmos_errors.http_constants.StatusCodes.NOT_FOUND
):
continue
raise err
except cosmos_exceptions.CosmosResourceNotFoundError:
continue
except Exception as err:
raise err
return store_items
Expand Down Expand Up @@ -162,20 +201,16 @@ async def write(self, changes: Dict[str, object]):
if e_tag == "":
raise Exception("cosmosdb_storage.write(): etag missing")

access_condition = {
"accessCondition": {"type": "IfMatch", "condition": e_tag}
}
options = (
access_condition if e_tag != "*" and e_tag and e_tag != "" else None
)
access_condition = e_tag != "*" and e_tag and e_tag != ""

try:
self.client.UpsertItem(
database_or_Container_link=self.__container_link,
document=doc,
options=options,
self.container.upsert_item(
body=doc,
etag=e_tag if access_condition else None,
match_condition=(
MatchConditions.IfNotModified if access_condition else None
),
)
except cosmos_errors.HTTPFailure as err:
raise err
except Exception as err:
raise err

Expand All @@ -192,69 +227,66 @@ async def delete(self, keys: List[str]):
key, self.config.key_suffix, self.config.compatibility_mode
)
try:
self.client.DeleteItem(
document_link=self.__item_link(escaped_key),
options=self.__get_partition_key(escaped_key),
self.container.delete_item(
escaped_key,
self.__get_partition_key(escaped_key),
)
except cosmos_errors.HTTPFailure as err:
if (
err.status_code
== cosmos_errors.http_constants.StatusCodes.NOT_FOUND
):
continue
raise err
except cosmos_exceptions.CosmosResourceNotFoundError:
continue
except Exception as err:
raise err

async def initialize(self):
if not self.container:
if not self.client:
connection_policy = self.config.cosmos_client_options.get(
"connection_policy", documents.ConnectionPolicy()
)

# kwargs 'connection_verify' is to handle CosmosClient overwriting the
# ConnectionPolicy.DisableSSLVerification value.
self.client = cosmos_client.CosmosClient(
self.config.cosmos_db_endpoint,
{"masterKey": self.config.auth_key},
self.config.cosmos_client_options.get("connection_policy", None),
self.config.auth_key,
self.config.cosmos_client_options.get("consistency_level", None),
**{
"connection_policy": connection_policy,
"connection_verify": not connection_policy.DisableSSLVerification,
},
)

if not self.database:
with self.__lock:
try:
if not self.database:
self.database = self.client.CreateDatabase(
{"id": self.config.database_id}
)
except cosmos_errors.HTTPFailure:
self.database = self.client.ReadDatabase(
"dbs/" + self.config.database_id
if not self.database:
self.database = self.client.create_database_if_not_exists(
self.config.database_id
)

self.__get_or_create_container()

def __get_or_create_container(self):
with self.__lock:
container_def = {
"id": self.config.container_id,
"partitionKey": {
"paths": ["/id"],
"kind": documents.PartitionKind.Hash,
},
partition_key = {
"paths": ["/id"],
"kind": documents.PartitionKind.Hash,
}
try:
if not self.container:
self.container = self.client.CreateContainer(
"dbs/" + self.database["id"],
container_def,
{"offerThroughput": self.config.container_throughput},
self.container = self.database.create_container(
self.config.container_id,
partition_key,
offer_throughput=self.config.container_throughput,
)
except cosmos_errors.HTTPFailure as err:
except cosmos_exceptions.CosmosHttpResponseError as err:
if err.status_code == http_constants.StatusCodes.CONFLICT:
self.container = self.client.ReadContainer(
"dbs/" + self.database["id"] + "/colls/" + container_def["id"]
self.container = self.database.get_container_client(
self.config.container_id
)
if "partitionKey" not in self.container:
properties = self.container.read()
if "partitionKey" not in properties:
self.compatability_mode_partition_key = True
else:
paths = self.container["partitionKey"]["paths"]
paths = properties["partitionKey"]["paths"]
if "/partitionKey" in paths:
self.compatability_mode_partition_key = True
elif "/id" not in paths:
Expand All @@ -267,7 +299,7 @@ def __get_or_create_container(self):
raise err

def __get_partition_key(self, key: str) -> str:
return None if self.compatability_mode_partition_key else {"partitionKey": key}
return None if self.compatability_mode_partition_key else key

@staticmethod
def __create_si(result) -> object:
Expand Down Expand Up @@ -303,28 +335,3 @@ def __create_dict(store_item: object) -> Dict:

# loop through attributes and write and return a dict
return json_dict

def __item_link(self, identifier) -> str:
"""Return the item link of a item in the container.
:param identifier:
:return str:
"""
return self.__container_link + "/docs/" + identifier

@property
def __container_link(self) -> str:
"""Return the container link in the database.
:param:
:return str:
"""
return self.__database_link + "/colls/" + self.config.container_id

@property
def __database_link(self) -> str:
"""Return the database link.
:return str:
"""
return "dbs/" + self.config.database_id
Loading

0 comments on commit 50e72c0

Please sign in to comment.