Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/azure_mail/get_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Get the app access token for Azure Mail API."""

from azure_mail.main import _get_app_access_token

if __name__ == "__main__":
token = _get_app_access_token() # print(json.dumps(token))
68 changes: 60 additions & 8 deletions src/azure_mail/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Create an email to send with an Azure app."""

import atexit
import concurrent.futures
import datetime
import json
import os
import pathlib
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor

import dateutil.parser
import exchangelib
Expand Down Expand Up @@ -42,6 +45,16 @@ def _check_or_set_up_cache() -> msal.SerializableTokenCache:
return cache


def initialise_app(
client_id: str, authority: str, token_cache: msal.SerializableTokenCache
) -> msal.PublicClientApplication:
return msal.PublicClientApplication(
client_id,
authority=authority,
token_cache=token_cache,
)


def _get_app_access_token() -> dict:
"""
Acquire an access token for the Azure app through the MSAL library.
Expand All @@ -52,11 +65,25 @@ def _get_app_access_token() -> dict:

"""
authority = "https://login.microsoftonline.com/" + os.environ["TENANT_ID"]
global_token_cache = _check_or_set_up_cache()
app = msal.PublicClientApplication(

def check_cache() -> msal.SerializableTokenCache:
global_token_cache = _check_or_set_up_cache()
if not global_token_cache.has_state_changed:
return global_token_cache
return None

with ThreadPoolExecutor() as executor:
future = executor.submit(check_cache)
try:
global_token_cache = future.result(timeout=10)
except ThreadPoolExecutor as err:
msg = "Token cache check timed out."
raise RuntimeError(msg) from err

app = initialise_app(
os.environ["CLIENT_ID"],
authority=authority,
token_cache=global_token_cache,
authority,
global_token_cache,
)

accounts = app.get_accounts(username=os.environ["ACCOUNT"])
Expand All @@ -70,11 +97,11 @@ def interactive_auth() -> dict:
[os.environ["SCOPE"]], login_hint=os.environ["ACCOUNT"]
)

with concurrent.futures.ThreadPoolExecutor() as executor:
with ThreadPoolExecutor() as executor:
future = executor.submit(interactive_auth)
try:
result = future.result(timeout=10) # Timeout set to 10 seconds
except concurrent.futures.TimeoutError as err:
except ThreadPoolExecutor as err:
msg = "Interactive authentication timed out."
raise RuntimeError(msg) from err

Expand Down Expand Up @@ -114,6 +141,30 @@ def _setup_email_account(
)


def get_token_with_timeout(timeout: int) -> dict:
try:
# Find get_token.py in the same directory as this file
script_path = (pathlib.Path(__file__).parent / "get_token.py").resolve()
if not script_path.is_file():
message = f"Script not found: {script_path}"
raise RuntimeError(message)
result = subprocess.run( # noqa: S603
[os.sys.executable, str(script_path)],
capture_output=True,
text=True,
timeout=timeout,
check=False,
)
if result.returncode != 0:
message = f"Token script failed: {result.stderr}"
raise RuntimeError(message)
print("DEBUG: result.stdout =", repr(result.stdout), file=sys.stderr) # noqa: T201
return json.loads(result.stdout)
except subprocess.TimeoutExpired as err:
message = "Token acquisition timed out."
raise RuntimeError(message) from err


def create_email_list(
limit: str,
recipients: list[str],
Expand All @@ -124,7 +175,8 @@ def create_email_list(
If you wish to send an email using the members of the distribution list, you can
create a list with [member.mailbox for member in distribution_list.members].
"""
access_token = _get_app_access_token()
access_token = get_token_with_timeout(timeout=10)

account = _setup_email_account(
access_token=access_token,
)
Expand Down