diff --git a/dmutils/csrf.py b/dmutils/csrf.py new file mode 100644 index 00000000..bb8d8e0f --- /dev/null +++ b/dmutils/csrf.py @@ -0,0 +1,24 @@ +import os +import binascii +from flask import session, request + + +TOKEN = '_csrf_token' +REACT_HEADER_NAME = 'X-CSRFToken' + + +def random_string(length=32): + return binascii.b2a_hex(os.urandom(length)).decode('utf-8') + + +def get_csrf_token(): + if TOKEN not in session: + session[TOKEN] = random_string() + return session[TOKEN] + + +def check_valid_header_csrf(): + try: + return session[TOKEN] == request.headers[REACT_HEADER_NAME] + except KeyError: + return False diff --git a/dmutils/flask_init.py b/dmutils/flask_init.py index 25ce8102..958c925d 100644 --- a/dmutils/flask_init.py +++ b/dmutils/flask_init.py @@ -5,8 +5,8 @@ import flask_featureflags from . import config, logging, force_https, request_id, formats, filters -from flask import Markup, redirect, request, session -from flask_script import Manager, Server +from flask import Markup, redirect, request, session, current_app, abort +from flask.ext.script import Manager, Server from flask_login import current_user from werkzeug.contrib.fixers import ProxyFix @@ -14,7 +14,9 @@ from .user import User, user_logging_string from dmutils import terms_of_use -from dmutils.forms import valid_csrf_or_abort +from dmutils.forms import is_csrf_token_valid + +from .csrf import check_valid_header_csrf def init_app( @@ -97,7 +99,14 @@ def load_user(user_id): @application.before_request def check_csrf_token(): if request.method in ('POST', 'PATCH', 'PUT', 'DELETE'): - valid_csrf_or_abort() + flask_csrf_valid = is_csrf_token_valid() + react_csrf_valid = check_valid_header_csrf() + + if not (flask_csrf_valid or react_csrf_valid): + current_app.logger.info( + u'csrf.invalid_token: Aborting request, user_id: {user_id}', + extra={'user_id': session.get('user_id', '