Skip to content

Commit

Permalink
Source-S3: Add support for EC2 instance profile
Browse files Browse the repository at this point in the history
  • Loading branch information
sidartha committed Feb 27, 2022
1 parent f4d54a9 commit 1d6a0a7
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 0 deletions.
6 changes: 6 additions & 0 deletions airbyte-config/init/src/main/resources/seed/source_specs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7003,6 +7003,12 @@
\ AWS."
default: ""
type: "string"
use_aws_default_credential_provider_chain:
title: "Use Aws default credential provider chain"
description: "Use default AWS credential provider chain (such as EC2\
\ instance profile). Leave the Access Key ID and Secret Access Key\
\ blank if setting this to true."
type: "boolean"
use_ssl:
title: "Use Ssl"
description: "Is remote server using secure SSL/TLS connection"
Expand Down
10 changes: 10 additions & 0 deletions airbyte-integrations/connectors/source-s3/source_s3/s3file.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def _setup_boto_session(self) -> None:
aws_secret_access_key=self._provider.get("aws_secret_access_key"),
)
self._boto_s3_resource = make_s3_resource(self._provider, session=self._boto_session)
elif self.use_aws_default_credential_provider_chain:
self._boto_session = boto3session.Session()
self._boto_s3_resource = make_s3_resource(self._provider, config=Config(), session=self._boto_session)
else:
self._boto_session = boto3session.Session()
self._boto_s3_resource = make_s3_resource(self._provider, config=Config(signature_version=UNSIGNED), session=self._boto_session)
Expand All @@ -43,6 +46,10 @@ def use_aws_account(provider: Mapping[str, str]) -> bool:
aws_secret_access_key = provider.get("aws_secret_access_key")
return True if (aws_access_key_id is not None and aws_secret_access_key is not None) else False

@staticmethod
def use_aws_default_credential_provider_chain(provider: Mapping[str, str]) -> bool:
    """
    Look up whether the provider config opts into the default AWS credential provider chain.

    :param provider: provider configuration mapping
    :return: the value stored under "use_aws_default_credential_provider_chain", passed through
        unchanged (so None comes back when the key is absent or explicitly set to None)
    """
    chain_flag = provider.get("use_aws_default_credential_provider_chain")
    return chain_flag

@contextmanager
def open(self, binary: bool) -> Iterator[Union[TextIO, BinaryIO]]:
"""
Expand All @@ -55,6 +62,9 @@ def open(self, binary: bool) -> Iterator[Union[TextIO, BinaryIO]]:
bucket = self._provider.get("bucket")
if self.use_aws_account(self._provider):
params = {"client": make_s3_client(self._provider, session=self._boto_session)}
elif self.use_aws_default_credential_provider_chain(self._provider):
config = ClientConfig()
params = {"client": make_s3_client(self._provider, config=config)}
else:
config = ClientConfig(signature_version=UNSIGNED)
params = {"client": make_s3_client(self._provider, config=config)}
Expand Down
4 changes: 4 additions & 0 deletions airbyte-integrations/connectors/source-s3/source_s3/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class Config:
)

endpoint: str = Field("", description="Endpoint to an S3 compatible service. Leave empty to use AWS.")
use_aws_default_credential_provider_chain: bool = Field(
default=None,
description="Use default AWS credential provider chain (such as EC2 instance profile). Leave the Access Key ID and Secret Access Key blank if setting this to true.",
)
use_ssl: bool = Field(default=None, description="Is remote server using secure SSL/TLS connection")
verify_ssl_cert: bool = Field(default=None, description="Allow self signed certificates")

Expand Down
3 changes: 3 additions & 0 deletions airbyte-integrations/connectors/source-s3/source_s3/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def _list_bucket(self, accept_key: Callable = lambda k: True) -> Iterator[FileIn
session = boto3session.Session(
aws_access_key_id=provider["aws_access_key_id"], aws_secret_access_key=provider["aws_secret_access_key"]
)
elif S3File.use_aws_default_credential_provider_chain(provider):
session = boto3session.Session()
client_config = Config()
else:
session = boto3session.Session()
client_config = Config(signature_version=UNSIGNED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,23 @@ class TestS3File:
)
def test_use_aws_account(self, provider: Mapping[str, str], return_true: bool) -> None:
assert S3File.use_aws_account(provider) is return_true

@pytest.mark.parametrize(
    "provider, return_true",
    [
        (
            {"storage": "S3", "bucket": "dummy", "path_prefix": "dummy", "use_aws_default_credential_provider_chain": flag},
            flag,
        )
        # One case per value the config flag can take: enabled, disabled, unset.
        for flag in (True, False, None)
    ],
)
def test_use_aws_default_credential_provider_chain(self, provider: Mapping[str, str], return_true: bool) -> None:
    """The helper must hand back the configured flag unchanged, checked by identity."""
    outcome = S3File.use_aws_default_credential_provider_chain(provider)
    assert outcome is return_true

0 comments on commit 1d6a0a7

Please sign in to comment.