Skip to content

Commit

Permalink
🎉 Source Okta: added parameter 'start_date' (#15050)
Browse files Browse the repository at this point in the history
* Added parameter 'start_date' in Okta source

added: parameter 'start_date' to source Okta
changed: unit tests

* changes: fix in the case of ISSUE: #14196

* Okta documentation in new format

* changes: fix to use super() instead of instance of stream parent

* changes: additional changes into OKTA documentaton

* changes: switch release to beta

* changed: set dockerImageTag -> 0.1.11

* changed:  source_specs

* ...

* ...

* Rollback releaseStage

* Refactored start date logic

* Deleted microseconds from state

* Add start date to all streams

* Updated to linter

* Fixed unit tests

* Updated unit tests

* auto-bump connector version [ci skip]

Co-authored-by: Serhii <serglazebny@gmail.com>
Co-authored-by: Octavia Squidington III <octavia-squidington-iii@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 8, 2022
1 parent 8282a45 commit 8d9a3aa
Show file tree
Hide file tree
Showing 13 changed files with 431 additions and 269 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@
- name: Okta
sourceDefinitionId: 1d4fdb25-64fc-4569-92da-fcdca79a8372
dockerRepository: airbyte/source-okta
dockerImageTag: 0.1.11
dockerImageTag: 0.1.12
documentationUrl: https://docs.airbyte.io/integrations/sources/okta
icon: okta.svg
sourceType: api
Expand Down
18 changes: 13 additions & 5 deletions airbyte-config/init/src/main/resources/seed/source_specs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6031,7 +6031,7 @@
- - "client_secret"
oauthFlowOutputParameters:
- - "access_token"
- dockerImage: "airbyte/source-okta:0.1.11"
- dockerImage: "airbyte/source-okta:0.1.12"
spec:
documentationUrl: "https://docs.airbyte.io/integrations/sources/okta"
connectionSpecification:
Expand All @@ -6047,6 +6047,14 @@
description: "The Okta domain. See the <a href=\"https://docs.airbyte.io/integrations/sources/okta\"\
>docs</a> for instructions on how to find it."
airbyte_secret: false
start_date:
type: "string"
pattern: "^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$"
description: "UTC date and time in the format YYYY-MM-DDTHH:MM:SSZ. Any\
\ data before this date will not be replicated."
examples:
- "2022-07-22T00:00:00Z"
title: "Start Date"
credentials:
title: "Authorization Method *"
type: "object"
Expand Down Expand Up @@ -6107,15 +6115,15 @@
oauth_config_specification:
oauth_user_input_from_connector_config_specification:
type: "object"
additionalProperties: false
additionalProperties: true
properties:
domain:
type: "string"
path_in_connector_config:
- "domain"
complete_oauth_output_specification:
type: "object"
additionalProperties: false
additionalProperties: true
properties:
refresh_token:
type: "string"
Expand All @@ -6124,15 +6132,15 @@
- "refresh_token"
complete_oauth_server_input_specification:
type: "object"
additionalProperties: false
additionalProperties: true
properties:
client_id:
type: "string"
client_secret:
type: "string"
complete_oauth_server_output_specification:
type: "object"
additionalProperties: false
additionalProperties: true
properties:
client_id:
type: "string"
Expand Down
2 changes: 1 addition & 1 deletion airbyte-integrations/connectors/source-okta/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ RUN pip install .
ENV AIRBYTE_ENTRYPOINT "python /airbyte/integration_code/main.py"
ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]

LABEL io.airbyte.version=0.1.11
LABEL io.airbyte.version=0.1.12
LABEL io.airbyte.name=airbyte/source-okta
2 changes: 1 addition & 1 deletion airbyte-integrations/connectors/source-okta/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ Make sure to familiarize yourself with [pytest test discovery](https://docs.pyte
First install test dependencies into your virtual environment:

```shell
pip install .[tests]
pip install .'[tests]'
```

### Unit Tests
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, Mapping, Tuple

import requests
from airbyte_cdk.sources.streams.http.auth import Oauth2Authenticator


class OktaOauth2Authenticator(Oauth2Authenticator):
def get_refresh_request_body(self) -> Mapping[str, Any]:
return {
"grant_type": "refresh_token",
"refresh_token": self.refresh_token,
}

def refresh_access_token(self) -> Tuple[str, int]:
try:
response = requests.request(
method="POST",
url=self.token_refresh_endpoint,
data=self.get_refresh_request_body(),
auth=(self.client_id, self.client_secret),
)
response.raise_for_status()
response_json = response.json()
return response_json["access_token"], response_json["expires_in"]
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e
157 changes: 64 additions & 93 deletions airbyte-integrations/connectors/source-okta/source_okta/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
from airbyte_cdk.sources import AbstractSource
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http import HttpStream
from airbyte_cdk.sources.streams.http.auth import Oauth2Authenticator, TokenAuthenticator

from .utils import datetime_to_string, delete_milliseconds, get_api_endpoint, get_start_date, initialize_authenticator


class OktaStream(HttpStream, ABC):
page_size = 200

def __init__(self, url_base: str, *args, **kwargs):
def __init__(self, url_base: str, start_date: pendulum.datetime, *args, **kwargs):
super().__init__(*args, **kwargs)
# Inject custom url base to the stream
self._url_base = url_base.rstrip("/") + "/"
self.start_date = start_date

@property
def url_base(self) -> str:
Expand Down Expand Up @@ -97,11 +99,10 @@ def request_params(
stream_slice: Mapping[str, any] = None,
next_page_token: Mapping[str, Any] = None,
) -> MutableMapping[str, Any]:
stream_state = stream_state or {}
params = super().request_params(stream_state, stream_slice, next_page_token)
latest_entry = stream_state.get(self.cursor_field)
if latest_entry:
params["filter"] = f'{self.cursor_field} gt "{latest_entry}"'
latest_entry = stream_state.get(self.cursor_field) if stream_state else datetime_to_string(self.start_date)
filter_param = {"filter": f'{self.cursor_field} gt "{latest_entry}"'}
params.update(filter_param)
return params


Expand All @@ -120,7 +121,7 @@ class GroupMembers(IncrementalOktaStream):
use_cache = True

def stream_slices(self, **kwargs):
group_stream = Groups(authenticator=self.authenticator, url_base=self.url_base)
group_stream = Groups(authenticator=self.authenticator, url_base=self.url_base, start_date=self.start_date)
for group in group_stream.read_records(sync_mode=SyncMode.full_refresh):
yield {"group_id": group["id"]}

Expand All @@ -134,10 +135,12 @@ def request_params(
stream_slice: Mapping[str, any] = None,
next_page_token: Mapping[str, Any] = None,
) -> MutableMapping[str, Any]:
params = OktaStream.request_params(self, stream_state, stream_slice, next_page_token)
latest_entry = stream_state.get(self.cursor_field)
if latest_entry:
params["after"] = latest_entry
# Filter param should be ignored SCIM filter expressions can't use the published
# attribute since it may conflict with the logic of the since, after, and until query params.
# Docs: https://developer.okta.com/docs/reference/api/system-log/#expression-filter
params = super(IncrementalOktaStream, self).request_params(stream_state, stream_slice, next_page_token)
latest_entry = stream_state.get(self.cursor_field) if stream_state else self.min_user_id
params["after"] = latest_entry
return params

def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]) -> Mapping[str, Any]:
Expand All @@ -154,7 +157,7 @@ class GroupRoleAssignments(OktaStream):
use_cache = True

def stream_slices(self, **kwargs):
group_stream = Groups(authenticator=self.authenticator, url_base=self.url_base)
group_stream = Groups(authenticator=self.authenticator, url_base=self.url_base, start_date=self.start_date)
for group in group_stream.read_records(sync_mode=SyncMode.full_refresh):
yield {"group_id": group["id"]}

Expand All @@ -168,6 +171,28 @@ class Logs(IncrementalOktaStream):
cursor_field = "published"
primary_key = "uuid"

def __init__(self, url_base, **kwargs):
super().__init__(url_base=url_base, **kwargs)
self._raise_on_http_errors: bool = True

@property
def raise_on_http_errors(self) -> bool:
return self._raise_on_http_errors

def should_retry(self, response: requests.Response) -> bool:
"""
When the connector gets abnormal state API retrun errror with 400 status code
and internal error code E0000001. The connector ignores an error with 400 code
to finish successfully sync and inform the user about an error in logs with an
error message.
"""

if response.status_code == 400 and response.json().get("errorCode") == "E0000001":
self.logger.info(f"{response.json()['errorSummary']}")
self._raise_on_http_errors = False
return False
return HttpStream.should_retry(self, response)

def path(self, **kwargs) -> str:
return "logs"

Expand All @@ -177,24 +202,27 @@ def request_params(
stream_slice: Mapping[str, any] = None,
next_page_token: Mapping[str, Any] = None,
) -> MutableMapping[str, Any]:
# The log stream use a different params to get data
# https://developer.okta.com/docs/reference/api/system-log/#datetime-filter
stream_state = stream_state or {}
params = OktaStream.request_params(self, stream_state, stream_slice, next_page_token)
latest_entry = stream_state.get(self.cursor_field)
if latest_entry:
params["since"] = latest_entry
# [Test-driven Development] Set until When the cursor value from the stream state
# is abnormally large, otherwise the server side that sets now to until
# will throw an error: The "until" date must be later than the "since" date
# https://developer.okta.com/docs/reference/api/system-log/#request-parameters
parsed = pendulum.parse(latest_entry)
utc_now = pendulum.utcnow()
if parsed > utc_now:
params["until"] = latest_entry

# The log stream use a different params to get data.
# Docs: https://developer.okta.com/docs/reference/api/system-log/#datetime-filter
# Filter param should be ignored SCIM filter expressions can't use the published
# attribute since it may conflict with the logic of the since, after, and until query params.
# Docs: https://developer.okta.com/docs/reference/api/system-log/#expression-filter
params = super(IncrementalOktaStream, self).request_params(stream_state, stream_slice, next_page_token)
latest_entry = stream_state.get(self.cursor_field) if stream_state else self.start_date
params["since"] = latest_entry
return params

def parse_response(
self,
response: requests.Response,
**kwargs,
) -> Iterable[Mapping]:
data = response.json() if isinstance(response.json(), list) else []

for record in data:
record[self.cursor_field] = delete_milliseconds(record[self.cursor_field])
yield record


class Users(IncrementalOktaStream):
cursor_field = "lastUpdated"
Expand Down Expand Up @@ -242,7 +270,7 @@ class UserRoleAssignments(OktaStream):
use_cache = True

def stream_slices(self, **kwargs):
user_stream = Users(authenticator=self.authenticator, url_base=self.url_base)
user_stream = Users(authenticator=self.authenticator, url_base=self.url_base, start_date=self.start_date)
for user in user_stream.read_records(sync_mode=SyncMode.full_refresh):
yield {"user_id": user["id"]}

Expand All @@ -264,7 +292,7 @@ def parse_response(
yield from response.json()["permissions"]

def stream_slices(self, **kwargs):
custom_roles = CustomRoles(authenticator=self.authenticator, url_base=self.url_base)
custom_roles = CustomRoles(authenticator=self.authenticator, url_base=self.url_base, start_date=self.start_date)
for role in custom_roles.read_records(sync_mode=SyncMode.full_refresh):
yield {"role_id": role["id"]}

Expand All @@ -273,66 +301,11 @@ def path(self, stream_slice: Mapping[str, Any] = None, **kwargs) -> str:
return f"iam/roles/{role_id}/permissions"


class OktaOauth2Authenticator(Oauth2Authenticator):
def get_refresh_request_body(self) -> Mapping[str, Any]:
return {
"grant_type": "refresh_token",
"refresh_token": self.refresh_token,
}

def refresh_access_token(self) -> Tuple[str, int]:
try:
response = requests.request(
method="POST",
url=self.token_refresh_endpoint,
data=self.get_refresh_request_body(),
auth=(self.client_id, self.client_secret),
)
response.raise_for_status()
response_json = response.json()
return response_json["access_token"], response_json["expires_in"]
except Exception as e:
raise Exception(f"Error while refreshing access token: {e}") from e


class SourceOkta(AbstractSource):
def initialize_authenticator(self, config: Mapping[str, Any]):
if "token" in config:
return TokenAuthenticator(config["token"], auth_method="SSWS")

creds = config.get("credentials")
if not creds:
raise Exception("Config validation error. `credentials` not specified.")

auth_type = creds.get("auth_type")
if not auth_type:
raise Exception("Config validation error. `auth_type` not specified.")

if auth_type == "api_token":
return TokenAuthenticator(creds["api_token"], auth_method="SSWS")

if auth_type == "oauth2.0":
return OktaOauth2Authenticator(
token_refresh_endpoint=self.get_token_refresh_endpoint(config),
client_secret=creds["client_secret"],
client_id=creds["client_id"],
refresh_token=creds["refresh_token"],
)

@staticmethod
def get_url_base(config: Mapping[str, Any]) -> str:
return config.get("base_url") or f"https://{config['domain']}.okta.com"

def get_api_endpoint(self, config: Mapping[str, Any]) -> str:
return parse.urljoin(self.get_url_base(config), "/api/v1/")

def get_token_refresh_endpoint(self, config: Mapping[str, Any]) -> str:
return parse.urljoin(self.get_url_base(config), "/oauth2/v1/token")

def check_connection(self, logger, config) -> Tuple[bool, any]:
try:
auth = self.initialize_authenticator(config)
api_endpoint = self.get_api_endpoint(config)
auth = initialize_authenticator(config)
api_endpoint = get_api_endpoint(config)
url = parse.urljoin(api_endpoint, "users")

response = requests.get(
Expand All @@ -349,13 +322,11 @@ def check_connection(self, logger, config) -> Tuple[bool, any]:
return False, "Failed to authenticate with the provided credentials"

def streams(self, config: Mapping[str, Any]) -> List[Stream]:
auth = self.initialize_authenticator(config)
api_endpoint = self.get_api_endpoint(config)
auth = initialize_authenticator(config)
api_endpoint = get_api_endpoint(config)
start_date = get_start_date(config)

initialization_params = {
"authenticator": auth,
"url_base": api_endpoint,
}
initialization_params = {"authenticator": auth, "url_base": api_endpoint, "start_date": start_date}

return [
Groups(**initialization_params),
Expand Down
Loading

0 comments on commit 8d9a3aa

Please sign in to comment.