Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue #2810] Connect all the components of the /users/token endpoint together #3004

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/local.env
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ LOG_ENABLE_AUDIT=FALSE
# The auth token used by the local endpoints
API_AUTH_TOKEN=LOCAL_AUTH_12345678,LOCAL_AUTH_87654321,LOCAL_1234

LOGIN_GOV_JWK_ENDPOINT=http://localhost:5001/issuer1/jwks
LOGIN_GOV_JWK_ENDPOINT=http://mock-oauth2-server:5001/issuer1/jwks
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should look into defining this different in some way? Got a weird email directly from GitHub complaining about:

GitGuardian has detected the following Generic High Entropy Secret exposed within your GitHub account.

I can mark as a false positive, but maybe this is a problem for some reason.

LOGIN_GOV_ENDPOINT=http://localhost:5001
LOGIN_GOV_CLIENT_ID=TODO

Expand Down
27 changes: 10 additions & 17 deletions api/src/api/users/user_routes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging

import src.adapters.db as db
import src.adapters.db.flask_db as flask_db
from src.api import response
from src.api.route_utils import raise_flask_error
from src.api.users import user_schemas
from src.api.users.user_blueprint import user_blueprint
from src.auth.api_key_auth import api_key_auth
from src.services.users.login_gov_token_handler import process_login_gov_token

logger = logging.getLogger(__name__)

Expand All @@ -15,22 +17,13 @@
)
@user_blueprint.output(user_schemas.UserTokenResponseSchema)
@user_blueprint.auth_required(api_key_auth)
def user_token(x_oauth_login_gov: dict) -> response.ApiResponse:
@flask_db.with_db_session()
def user_token(db_session: db.Session, x_oauth_login_gov: dict) -> response.ApiResponse:
logger.info("POST /v1/users/token")

if x_oauth_login_gov:
data = {
"token": "the token goes here!",
"user": {
"user_id": "abc-...",
"email": "[email protected]",
"external_user_type": "login_gov",
},
"is_user_new": True,
}
return response.ApiResponse(message="Success", data=data)
with db_session.begin():
# UserTokenHeaderSchema validates that the header is present, so safe to fetch this way
result = process_login_gov_token(x_oauth_login_gov["x_oauth_login_gov"], db_session)

message = "Missing X-OAuth-login-gov header"
logger.info(message)

raise_flask_error(400, message)
logger.info("Successfully generated token for user")
return response.ApiResponse(message="Success", data=result)
1 change: 1 addition & 0 deletions api/src/api/users/user_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
class UserTokenHeaderSchema(Schema):
x_oauth_login_gov = fields.String(
data_key="X-OAuth-login-gov",
required=True,
metadata={
"description": "The login_gov header token",
},
Expand Down
2 changes: 2 additions & 0 deletions api/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from src.app_config import AppConfig
from src.auth.api_jwt_auth import initialize_jwt_auth
from src.auth.auth_utils import get_app_security_scheme
from src.auth.login_gov_jwt_auth import initialize_login_gov_config
from src.data_migration.data_migration_blueprint import data_migration_blueprint
from src.search.backend.load_search_data_blueprint import load_search_data_blueprint
from src.task import task_blueprint
Expand Down Expand Up @@ -63,6 +64,7 @@ def create_app() -> APIFlask:
# will reuse the config from it, for now we'll do this a bit hacky
# This cannot be removed non-locally until we setup RSA keys for non-local envs
if os.getenv("ENVIRONMENT") == "local":
initialize_login_gov_config()
initialize_jwt_auth()

return app
Expand Down
36 changes: 31 additions & 5 deletions api/src/auth/login_gov_jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,34 @@ class LoginGovConfig(PydanticBaseEnvConfig):
client_id: str = Field(alias="LOGIN_GOV_CLIENT_ID")


def get_login_gov_config() -> LoginGovConfig:
config = LoginGovConfig()
_refresh_keys(config)
return config
# Initialize a config at startup
_config: LoginGovConfig | None = None


def initialize_login_gov_config() -> None:
global _config
if not _config:
_config = LoginGovConfig()
_refresh_keys(_config)

logger.info(
"Constructed login.gov configuration",
extra={
"login_gov_jwk_endpoint": _config.login_gov_jwk_endpoint,
"login_gov_endpoint": _config.login_gov_endpoint,
},
)


def get_config() -> LoginGovConfig:
global _config

if _config is None:
raise Exception(
"No Login.gov configuration - initialize_login_gov_config() must be run first"
)

return _config


@dataclasses.dataclass
Expand Down Expand Up @@ -65,7 +89,9 @@ def _refresh_keys(config: LoginGovConfig) -> None:
config.public_keys = list(public_keys)


def validate_token(token: str, config: LoginGovConfig) -> LoginGovUser:
def validate_token(token: str, config: LoginGovConfig | None = None) -> LoginGovUser:
if not config:
config = get_config()

# TODO - this iteration approach won't be necessary if the JWT we get
# from login.gov does actually set the KID in the header
Expand Down
13 changes: 7 additions & 6 deletions api/src/db/models/user_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,28 @@
from sqlalchemy.orm import Mapped, mapped_column, relationship

from src.adapters.db.type_decorators.postgres_type_decorators import LookupColumn
from src.constants.lookup_constants import ExternalUserType
from src.db.models.base import ApiSchemaTable, TimestampMixin
from src.db.models.lookup_models import LkExternalUserType


class User(ApiSchemaTable, TimestampMixin):
__tablename__ = "user"

user_id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
user_id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True, default=uuid.uuid4)


class LinkExternalUser(ApiSchemaTable, TimestampMixin):
__tablename__ = "link_external_user"

link_external_user_id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
link_external_user_id: Mapped[uuid.UUID] = mapped_column(primary_key=True, default=uuid.uuid4)

external_user_id: Mapped[str] = mapped_column(index=True, unique=True)

user_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey(User.user_id), index=True)
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey(User.user_id), index=True)
user: Mapped[User] = relationship(User)

external_user_type: Mapped[int] = mapped_column(
external_user_type: Mapped[ExternalUserType] = mapped_column(
"external_user_type_id",
LookupColumn(LkExternalUserType),
ForeignKey(LkExternalUserType.external_user_type_id),
Expand All @@ -39,10 +40,10 @@ class LinkExternalUser(ApiSchemaTable, TimestampMixin):
class UserTokenSession(ApiSchemaTable, TimestampMixin):
__tablename__ = "user_token_session"

user_id: Mapped[uuid.UUID] = mapped_column(UUID, ForeignKey(User.user_id), primary_key=True)
user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey(User.user_id), primary_key=True)
user: Mapped[User] = relationship(User)

token_id: Mapped[uuid.UUID] = mapped_column(UUID, primary_key=True)
token_id: Mapped[uuid.UUID] = mapped_column(primary_key=True)

expires_at: Mapped[datetime]

Expand Down
Empty file.
78 changes: 78 additions & 0 deletions api/src/services/users/login_gov_token_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import logging

from sqlalchemy import select
from sqlalchemy.orm import selectinload

import src.adapters.db as db
from src.api.route_utils import raise_flask_error
from src.auth.api_jwt_auth import create_jwt_for_user
from src.auth.auth_errors import JwtValidationError
from src.auth.login_gov_jwt_auth import validate_token
from src.constants.lookup_constants import ExternalUserType
from src.db.models.user_models import LinkExternalUser, User

logger = logging.getLogger(__name__)


def process_login_gov_token(token: str, db_session: db.Session) -> dict:

try:
login_gov_user = validate_token(token)
except JwtValidationError as e:
logger.info("Login.gov token validation failed", extra={"auth.issue": e.message})
raise_flask_error(401, e.message)

external_user: LinkExternalUser | None = db_session.execute(
select(LinkExternalUser)
.where(LinkExternalUser.external_user_id == login_gov_user.user_id)
# We only support login.gov right now, so this does nothing, but let's
# be explicit just in case.
.where(LinkExternalUser.external_user_type == ExternalUserType.LOGIN_GOV)
.options(selectinload("*"))
).scalar()

is_user_new = external_user is None

# If we didn't find anything, we want to create the user
if external_user is None:
external_user = _create_login_gov_user(login_gov_user.user_id, db_session)

# Update fields on the external user table
external_user.email = login_gov_user.email

# Flush the records to the DB so any auto-generated IDs and similar are populated
# prior to us trying to work with the user further.
# NOTE: This doesn't commit yet - but effectively moves the cache from memory to the DB transaction
db_session.flush()

token, _ = create_jwt_for_user(external_user.user, db_session)

# TODO - make a pydantic object? return token + user? Figure it out
return _build_response(token, external_user, is_user_new)


def _create_login_gov_user(external_user_id: str, db_session: db.Session) -> LinkExternalUser:
user = User()
db_session.add(user)

external_user = LinkExternalUser(
user=user,
external_user_type=ExternalUserType.LOGIN_GOV,
external_user_id=external_user_id,
# note we set other params in the calling method to also handle updates
)
db_session.add(external_user)

return external_user


def _build_response(token: str, external_user: LinkExternalUser, is_user_new: bool) -> dict:
return {
"token": token,
"user": {
"user_id": external_user.user_id,
"email": external_user.email,
"external_user_type": external_user.external_user_type,
},
"is_user_new": is_user_new,
}
59 changes: 58 additions & 1 deletion api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import moto
import pytest
from apiflask import APIFlask
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from sqlalchemy import text

import src.adapters.db as db
import src.app as app_entry
import src.auth.login_gov_jwt_auth as login_gov_jwt_auth
import tests.src.db.models.factories as factories
from src.adapters import search
from src.constants.schema import Schemas
Expand Down Expand Up @@ -217,6 +220,60 @@ def opportunity_index_alias(search_client, monkeypatch_session):
return alias


####################
# Auth
####################


def _generate_rsa_key_pair():
# Rather than define a private/public key, generate one for the tests
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)

private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)

public_key = key.public_key().public_bytes(
encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo
)

return private_key, public_key


@pytest.fixture(scope="session")
def rsa_key_pair():
return _generate_rsa_key_pair()


@pytest.fixture(scope="session")
def private_rsa_key(rsa_key_pair):
return rsa_key_pair[0]


@pytest.fixture(scope="session")
def public_rsa_key(rsa_key_pair):
return rsa_key_pair[1]


@pytest.fixture(scope="session")
def other_rsa_key_pair():
return _generate_rsa_key_pair()


@pytest.fixture(scope="session")
def setup_login_gov_auth(monkeypatch_session, public_rsa_key):
# TODO - describe
def override_method(config):
config.public_keys = [public_rsa_key]

monkeypatch_session.setattr(login_gov_jwt_auth, "_refresh_keys", override_method)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for self - what if I instead made a dummy endpoint that returned the keys? That would at least validate the parsing logic


monkeypatch_session.setenv("LOGIN_GOV_ENDPOINT", "http://localhost:3000")
monkeypatch_session.setenv("LOGIN_GOV_CLIENT_ID", "AUDIENCE_TEST")


####################
# Test App & Client
####################
Expand All @@ -225,7 +282,7 @@ def opportunity_index_alias(search_client, monkeypatch_session):
# Make app session scoped so the database connection pool is only created once
# for the test session. This speeds up the tests.
@pytest.fixture(scope="session")
def app(db_client, opportunity_index_alias) -> APIFlask:
def app(db_client, opportunity_index_alias, setup_login_gov_auth) -> APIFlask:
return app_entry.create_app()


Expand Down
Loading
Loading