Skip to content

Commit

Permalink
chore(general): support custom headers in platform integration (#6054)
Browse files Browse the repository at this point in the history
* chore(general): support custom headers in platform integration

* remove todo

* fix mypy issue

---------

Co-authored-by: Steve Vaknin <svaknin@paloaltonetworks.com>
  • Loading branch information
SteveVaknin and SteveVaknin authored Mar 3, 2024
1 parent f8ab3e3 commit a474920
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def _get_fixes_for_file(

headers = merge_dicts(
get_default_post_headers(self.bc_integration.bc_source, self.bc_integration.bc_source_version),
{"Authorization": self.bc_integration.get_auth_token()}
{"Authorization": self.bc_integration.get_auth_token()},
self.bc_integration.custom_auth_headers
)

if not self.bc_integration.http:
Expand Down
31 changes: 21 additions & 10 deletions checkov/common/bridgecrew/platform_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def clean(self) -> None:
self.prisma_api_url = normalize_prisma_url(os.getenv('PRISMA_API_URL', 'https://api0.prismacloud.io'))
self.prisma_policies_url: str | None = None
self.prisma_policy_filters_url: str | None = None
self.custom_auth_headers: dict[str, str] = {}
self.setup_api_urls()
self.customer_run_config_response = None
self.runtime_run_config_response = None
Expand Down Expand Up @@ -163,6 +164,7 @@ def init_instance(self, platform_integration_data: dict[str, Any]) -> None:
self.credentials = platform_integration_data["credentials"]
self.platform_integration_configured = platform_integration_data["platform_integration_configured"]
self.prisma_api_url = platform_integration_data["prisma_api_url"]
self.custom_auth_headers = platform_integration_data["custom_auth_headers"]
self.repo_branch = platform_integration_data["repo_branch"]
self.repo_id = platform_integration_data["repo_id"]
self.repo_path = platform_integration_data["repo_path"]
Expand All @@ -187,6 +189,7 @@ def generate_instance_data(self) -> dict[str, Any]:
"credentials": self.credentials,
"platform_integration_configured": self.platform_integration_configured,
"prisma_api_url": self.prisma_api_url,
"custom_auth_headers": self.custom_auth_headers,
"repo_branch": self.repo_branch,
"repo_id": self.repo_id,
"repo_path": self.repo_path,
Expand Down Expand Up @@ -479,7 +482,8 @@ def _get_s3_creds(self, repo_id: str, token: str) -> dict[str, Any]:
request = self.http.request("POST", self.integrations_api_url, # type:ignore[union-attr]
body=json.dumps({"repoId": repo_id, "support": self.support_flag_enabled}),
headers=merge_dicts({"Authorization": token, "Content-Type": "application/json"},
get_user_agent_header()))
get_user_agent_header(),
self.custom_auth_headers))
logging.debug(f'Request ID: {request.headers.get("x-amzn-requestid")}')
logging.debug(f'Trace ID: {request.headers.get("x-amzn-trace-id")}')
if request.status == 403:
Expand Down Expand Up @@ -834,7 +838,8 @@ def commit_repository(self, branch: str) -> str | None:
"Content-Type": "application/json",
'x-api-client': self.bc_source.name,
'x-api-checkov-version': checkov_version},
get_user_agent_header()
get_user_agent_header(),
self.custom_auth_headers
))
response = json.loads(request.data.decode("utf8"))
logging.debug(f'Request ID: {request.headers.get("x-amzn-requestid")}')
Expand Down Expand Up @@ -939,7 +944,8 @@ def get_customer_run_config(self) -> None:
try:
token = self.get_auth_token()
headers = merge_dicts(get_auth_header(token),
get_default_get_headers(self.bc_source, self.bc_source_version))
get_default_get_headers(self.bc_source, self.bc_source_version),
self.custom_auth_headers)

self.setup_http_manager()
if not self.http:
Expand Down Expand Up @@ -989,7 +995,8 @@ def get_reachability_run_config(self) -> Union[Dict[str, Any], None]:
try:
token = self.get_auth_token()
headers = merge_dicts(get_auth_header(token),
get_default_get_headers(self.bc_source, self.bc_source_version))
get_default_get_headers(self.bc_source, self.bc_source_version),
self.custom_auth_headers)

self.setup_http_manager()
if not self.http:
Expand Down Expand Up @@ -1030,7 +1037,8 @@ def get_runtime_run_config(self) -> None:

token = self.get_auth_token()
headers = merge_dicts(get_auth_header(token),
get_default_get_headers(self.bc_source, self.bc_source_version))
get_default_get_headers(self.bc_source, self.bc_source_version),
self.custom_auth_headers)

self.setup_http_manager()
if not self.http:
Expand Down Expand Up @@ -1075,7 +1083,7 @@ def get_prisma_build_policies(self, policy_filter: str) -> None:

try:
token = self.get_auth_token()
headers = merge_dicts(get_prisma_auth_header(token), get_prisma_get_headers())
headers = merge_dicts(get_prisma_auth_header(token), get_prisma_get_headers(), self.custom_auth_headers)

self.setup_http_manager()
if not self.http:
Expand Down Expand Up @@ -1107,7 +1115,7 @@ def get_prisma_policy_filters(self) -> Dict[str, Dict[str, Any]]:
request = None
try:
token = self.get_auth_token()
headers = merge_dicts(get_prisma_auth_header(token), get_prisma_get_headers())
headers = merge_dicts(get_prisma_auth_header(token), get_prisma_get_headers(), self.custom_auth_headers)

self.setup_http_manager()
if not self.http:
Expand Down Expand Up @@ -1301,10 +1309,12 @@ def get_default_headers(self, request_type: str) -> dict[str, Any]:

if request_type.upper() == "GET":
return merge_dicts(get_default_get_headers(self.bc_source, self.bc_source_version),
{"Authorization": self.get_auth_token()})
{"Authorization": self.get_auth_token()},
self.custom_auth_headers)
elif request_type.upper() == "POST":
return merge_dicts(get_default_post_headers(self.bc_source, self.bc_source_version),
{"Authorization": self.get_auth_token()})
{"Authorization": self.get_auth_token()},
self.custom_auth_headers)

logging.info(f"Unsupported request {request_type}")
return {}
Expand All @@ -1316,7 +1326,8 @@ def get_sso_prismacloud_url(self, report_url: str) -> str:
url_saml_config = f"{bc_integration.prisma_api_url}/saml/config"
token = self.get_auth_token()
headers = merge_dicts(get_auth_header(token),
get_default_get_headers(self.bc_source, self.bc_source_version))
get_default_get_headers(self.bc_source, self.bc_source_version),
bc_integration.custom_auth_headers)

request = self.http.request("GET", url_saml_config, headers=headers, timeout=10) # type:ignore[no-untyped-call]
if request.status >= 300:
Expand Down

0 comments on commit a474920

Please sign in to comment.