Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IAM based auth for S3 policy repo #691

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
113 changes: 110 additions & 3 deletions packages/opal-common/opal_common/sources/api_policy_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
from pathlib import Path
from typing import Optional, Tuple
from urllib.parse import urlparse
from xml.etree import ElementTree

import aiohttp
import aiofiles
from fastapi import status
from fastapi.exceptions import HTTPException
from opal_common.git_utils.tar_file_to_local_git_extractor import (
Expand All @@ -17,6 +19,7 @@
hash_file,
throw_if_bad_status_code,
tuple_to_dict,
async_time_cache,
)
from opal_server.config import PolicyBundleServerType
from tenacity import AsyncRetrying
Expand All @@ -43,6 +46,9 @@ class ApiPolicySource(BasePolicySource):
token (str, optional): auth token to include in connections to bundle server. Defaults to POLICY_BUNDLE_SERVER_TOKEN.
token_id (str, optional): auth token ID to include in connections to bundle server. Defaults to POLICY_BUNDLE_SERVER_TOKEN_ID.
bundle_server_type (PolicyBundleServerType, optional): the type of bundle server
region (str, optional): the aws region of s3 bucket containing the bundle
aws_role_arn (str, optional): the aws iam role to assume when accessing the s3 bucket. Only required when using temporary sts credentials.
aws_web_id_token_file (str, optional): the file containing a web id token for the target aws iam role. Only required when using temporary sts credentials.
"""

def __init__(
Expand All @@ -53,6 +59,8 @@ def __init__(
token: Optional[str] = None,
token_id: Optional[str] = None,
region: Optional[str] = None,
aws_role_arn: Optional[str] = None,
aws_web_id_token_file: Optional[str] = None,
bundle_server_type: Optional[PolicyBundleServerType] = None,
policy_bundle_path=".",
policy_bundle_git_add_pattern="*",
Expand All @@ -66,6 +74,8 @@ def __init__(
self.token_id = token_id
self.server_type = bundle_server_type
self.region = region
self.aws_role_arn = aws_role_arn
self.aws_web_id_token_file = aws_web_id_token_file
self.bundle_hash = None
self.etag = None
self.tmp_bundle_path = Path(policy_bundle_path)
Expand Down Expand Up @@ -126,7 +136,84 @@ async def api_update_policy(self) -> Tuple[bool, str, str]:
)
raise

def build_auth_headers(self, token=None, path=None):
@async_time_cache(ttl=3000)
async def get_temporary_sts_credentials(self) -> tuple[str, str, str]:
"""
This function will fetch a set of temporary credentials for a IAM role
from Amazon STS. It requires an aws region, the arn for the target role
and the file containing the web token.

This function will return the id and secret key required for login.
When using temporary credentials, AWS also requires a session token
which this function also provides.

This result of this funciton is cached to avoid being rate limited by
STS.
"""
assert self.aws_web_id_token_file
assert self.aws_role_arn
assert self.region

async with aiofiles.open(self.aws_web_id_token_file) as token_file:
token = await token_file.read()

sts_url = f"sts.{self.region}.amazonaws.com"
params: dict[str, str] = {
"Action": "AssumeRoleWithWebIdentity",
"DurationSeconds": "3600",
"RoleSessionName": "Opal",
"RoleArn": self.aws_role_arn,
"WebIdentityToken": token,
"Version": "2011-06-15",
}

async with aiohttp.ClientSession() as session:
try:
async with session.get(
f"https://{sts_url}",
params=params,
headers={"Content-Type": "application/xml"},
) as response:
if response.status == status.HTTP_404_NOT_FOUND:
logger.warning(
"requested url not found: {sts_url}",
sts_url=sts_url,
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"requested url not found: {sts_url}",
)

body = await response.read()

# the default aws xml namespace
ns = {"": "https://sts.amazonaws.com/doc/2011-06-15/"}

et = ElementTree.fromstring(body)
credentials = et.find(
"AssumeRoleWithWebIdentityResult/Credentials", ns
)
assert credentials

id = credentials.findtext("AccessKeyId", namespaces=ns)
key = credentials.findtext("SecretAccessKey", namespaces=ns)
session_token = credentials.findtext("SessionToken", namespaces=ns)

assert id
assert key
assert session_token

except (aiohttp.ClientError, HTTPException) as e:
logger.warning("server connection error: {err}", err=repr(e))
raise
except Exception as e:
logger.error("unexpected server connection error: {err}", err=repr(e))
raise

logger.info("Successfully generated temporary AWS credentials")
return id, key, session_token

async def build_auth_headers(self, token=None, path=None):
# if it's a simple HTTP server with a bearer token
if self.server_type == PolicyBundleServerType.HTTP and token is not None:
return tuple_to_dict(get_authorization_header(token))
Expand All @@ -136,14 +223,34 @@ def build_auth_headers(self, token=None, path=None):
and token is not None
and self.token_id is not None
):
logger.info("Using provided token to log in to AWS_S3")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiesenverseist hi, you don't really need elif here, right as there is a return from each case before moving to the next. so you could just have the if's one after the other starting from line 221.


split_url = urlparse(self.remote_source_url)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiesenverseist this is duplicate with the code below, I would refactor the section of both cases of S3 to use common code. right?

host = split_url.netloc
path = split_url.path + "/" + path

return build_aws_rest_auth_headers(
self.token_id, token, host, path, self.region
)
elif (
self.server_type == PolicyBundleServerType.AWS_S3
and self.aws_role_arn is not None
and self.aws_web_id_token_file is not None
and self.region is not None
):
logger.info("Using IAM Web auth to log in to AWS_S3")

split_url = urlparse(self.remote_source_url)
host = split_url.netloc
path = split_url.path + "/" + path

id, key, session_token = await self.get_temporary_sts_credentials()

return build_aws_rest_auth_headers(
id, key, host, path, self.region, session_token
)
else:
logger.info("Not authenticating on bundle endpoint")
return {}

async def fetch_policy_bundle_from_api_source(
Expand All @@ -166,7 +273,7 @@ async def fetch_policy_bundle_from_api_source(
"""
path = "bundle.tar.gz"

auth_headers = self.build_auth_headers(token=token, path=path)
auth_headers = await self.build_auth_headers(token=token, path=path)
etag_headers = (
{"ETag": self.etag, "If-None-Match": self.etag} if self.etag else {}
)
Expand Down Expand Up @@ -278,4 +385,4 @@ async def check_for_changes(self):
prev_head=prev,
new_head=latest,
)
await self._on_new_policy(old=prev_commit, new=new_commit)
await self._on_new_policy(old=prev_commit, new=new_commit)
46 changes: 44 additions & 2 deletions packages/opal-common/opal_common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import threading
from datetime import datetime
from hashlib import sha1
from typing import Coroutine, Dict, List, Tuple
from typing import Callable, Coroutine, Dict, List, Tuple
import functools
import time

import aiohttp

Expand Down Expand Up @@ -57,7 +59,12 @@ def get_authorization_header(token: str) -> Tuple[str, str]:


def build_aws_rest_auth_headers(
key_id: str, secret_key: str, host: str, path: str, region: str
key_id: str,
secret_key: str,
host: str,
path: str,
region: str,
token: str | None,
):
"""Use the AWS signature algorithm (https://docs.aws.amazon.com/AmazonS3/la
test/userguide/RESTAuthentication.html) to generate the hTTP headers.
Expand All @@ -67,6 +74,7 @@ def build_aws_rest_auth_headers(
secret_key (str): Secret key (aka password) of an account in the S3 service.
host (str): S3 storage host
path (str): path to bundle file in s3 storage (including bucket)
token (str | None): Optional session token when using temporary credential.

Returns: http headers
"""
Expand All @@ -91,6 +99,10 @@ def getSignatureKey(key, dateStamp, regionName, serviceName):
canonical_headers = "host:" + host + "\n" + "x-amz-date:" + amzdate + "\n"
signed_headers = "host;x-amz-date"

if token:
canonical_headers += f"x-amz-security-token:{token}\n"
signed_headers += ";x-amz-security-token"

payload_hash = hashlib.sha256("".encode("utf-8")).hexdigest()

canonical_request = (
Expand Down Expand Up @@ -138,8 +150,13 @@ def getSignatureKey(key, dateStamp, regionName, serviceName):
+ signature
)

token_header: dict[str, str] = {}
if token:
token_header["x-amz-security-token"] = token

return {
"x-amz-date": amzdate,
**token_header,
"x-amz-content-sha256": SHA256_EMPTY,
"Authorization": authorization_header,
}
Expand Down Expand Up @@ -275,3 +292,28 @@ def run_coro(self, coro: Coroutine):
run_coro() is thread-safe.
"""
return asyncio.run_coroutine_threadsafe(coro, loop=self.loop).result()


def async_time_cache(ttl: float):
"""
This decorator is a wrapper around lru_cache that makes it time sensitive.

ttl is in seconds
"""

def decorator(func: Callable):
# instead of directly caching the function, a time "hash" is
# also passed in as a param that will invalidate the cache
# after at most ttl seconds
@functools.lru_cache
def wrapped(*args, __ttl_hash=None, **kwargs):
coro = func(*args, **kwargs)
return asyncio.ensure_future(coro)

def ret(*args, **kwargs):
ttl_hash = round(time.time() / ttl)
return wrapped(*args, **kwargs, __ttl_hash=ttl_hash)

return ret

return decorator
15 changes: 14 additions & 1 deletion packages/opal-server/opal_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class OpalServerConfig(Confi):
AUTH_PRIVATE_KEY_PASSPHRASE = confi.str("AUTH_PRIVATE_KEY_PASSPHRASE", None)

AUTH_PRIVATE_KEY = confi.delay(
lambda AUTH_PRIVATE_KEY_FORMAT=None, AUTH_PRIVATE_KEY_PASSPHRASE="": confi.private_key(
lambda AUTH_PRIVATE_KEY_FORMAT=None,
AUTH_PRIVATE_KEY_PASSPHRASE="": confi.private_key(
"AUTH_PRIVATE_KEY",
default=None,
key_format=AUTH_PRIVATE_KEY_FORMAT,
Expand Down Expand Up @@ -133,6 +134,18 @@ class OpalServerConfig(Confi):
"us-east-1",
description="The AWS region of the S3 bucket",
)
POLICY_BUNDLE_AWS_ROLE_ARN = confi.str(
"AWS_ROLE_ARN",
# default to the env var injected by aws
os.getenv("AWS_ROLE_ARN"),
description="The IAM role to be used when accessing the bundle server. This is set by AWS automatically in EKS, but can be overridden if required.",
)
POLICY_BUNDLE_AWS_WEB_IDENTITY_TOKEN_FILE = confi.str(
"AWS_WEB_IDENTITY_TOKEN_FILE",
# default to the env var injected by aws
os.getenv("AWS_WEB_IDENTITY_TOKEN_FILE"),
description="The oidc token for the IAM role to be used when accessing the bundle server. This is set by AWS automatically in EKS, but can be overridden if required.",
)
POLICY_BUNDLE_TMP_PATH = confi.str(
"POLICY_BUNDLE_TMP_PATH",
"/tmp/bundle.tar.gz",
Expand Down
3 changes: 3 additions & 0 deletions packages/opal-server/opal_server/policy/watcher/factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
from typing import Any, List, Optional
import os

from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint
from opal_common.confi.confi import load_conf_if_none
Expand Down Expand Up @@ -129,6 +130,8 @@ def setup_watcher_task(
policy_bundle_path=opal_server_config.POLICY_BUNDLE_TMP_PATH,
policy_bundle_git_add_pattern=opal_server_config.POLICY_BUNDLE_GIT_ADD_PATTERN,
region=policy_bundle_aws_region,
aws_role_arn=opal_server_config.POLICY_BUNDLE_AWS_ROLE_ARN,
aws_web_id_token_file=opal_server_config.POLICY_BUNDLE_AWS_WEB_IDENTITY_TOKEN_FILE,
)
else:
raise ValueError("Unknown value for OPAL_POLICY_SOURCE_TYPE")
Expand Down
Loading