From 57d586571f2d468f9243a84fe26916d9bc37398d Mon Sep 17 00:00:00 2001 From: sarneaud Date: Mon, 8 Aug 2016 14:26:22 +1000 Subject: [PATCH 1/2] Put more reusable code into utils repo from buyer-frontend Most of this was being copy-pasted, or not even copy-pasted. A significant amount of test utility code is still being copy-pasted. I also took the opportunity to improve test coverage a bit. We could still do with some more test coverage of the flask_init.py file. --- dmutils/flask_init.py | 64 ++++++++++++++++++++--- dmutils/forms.py | 96 +++++++++++++++++++++++++++++++++- requirements.txt | 3 +- tests/helpers.py | 25 +++++++++ tests/templates/test_form.html | 4 ++ tests/test_forms.py | 69 ++++++++++++++++++++++++ 6 files changed, 252 insertions(+), 9 deletions(-) create mode 100644 tests/templates/test_form.html create mode 100644 tests/test_forms.py diff --git a/dmutils/flask_init.py b/dmutils/flask_init.py index ae292114..bc01d86e 100644 --- a/dmutils/flask_init.py +++ b/dmutils/flask_init.py @@ -1,8 +1,11 @@ import os from flask_featureflags.contrib.inline import InlineFeatureFlag from . import config, logging, proxy_fix, request_id, formats, filters -from flask import Markup +from flask import Markup, redirect, request, session from flask.ext.script import Manager, Server +from flask_login import current_user + +from user import User def init_app( @@ -42,11 +45,64 @@ def init_app( if search_api_client: search_api_client.init_app(application) + @application.before_request + def set_scheme(): + request.environ['wsgi.url_scheme'] = application.config['DM_HTTP_PROTO'] + @application.after_request def add_header(response): response.headers['X-Frame-Options'] = 'DENY' return response + @application.after_request + def add_cache_control(response): + if request.method != 'GET' or response.status_code in (301, 302): + return response + + vary = response.headers.get('Vary', None) + if vary: + response.headers['Vary'] = vary + ', Cookie' + else: + response.headers['Vary'] = 'Cookie' + + if current_user.is_authenticated: + response.cache_control.private = True + if response.cache_control.max_age is None: + response.cache_control.max_age = application.config['DM_DEFAULT_CACHE_MAX_AGE'] + + return response + + @application.context_processor + def inject_global_template_variables(): + return dict( + pluralize=pluralize, + **(application.config['BASE_TEMPLATE_DATA'] or {})) + + +def init_frontend_app(application, login_manager): + @login_manager.user_loader + def load_user(user_id): + return User.load_user(data_api_client, user_id) + + @application.before_request + def refresh_session(): + session.permanent = True + session.modified = True + + @application.before_request + def remove_trailing_slash(): + if request.path != application.config['URL_PREFIX'] + '/' and request.path.endswith('/'): + if request.query_string: + return redirect( + '{}?{}'.format( + request.path[:-1], + request.query_string.decode('utf-8') + ), + code=301 + ) + else: + return redirect(request.path[:-1], code=301) + @application.template_filter('markdown') def markdown_filter_flask(data): return Markup(filters.markdown_filter(data)) @@ -57,12 +113,6 @@ def markdown_filter_flask(data): application.add_template_filter(formats.datetimeformat) application.add_template_filter(filters.smartjoin) - @application.context_processor - def inject_global_template_variables(): - return dict( - pluralize=pluralize, - **(application.config['BASE_TEMPLATE_DATA'] or {})) - def pluralize(count, singular, plural): return singular if count == 1 else plural diff --git a/dmutils/forms.py b/dmutils/forms.py index 4383dcdd..1e348e26 100644 --- a/dmutils/forms.py +++ b/dmutils/forms.py @@ -1,4 +1,17 @@ -from wtforms import StringField +from datetime import timedelta +from functools import wraps +import re + +from flask import abort, current_app, render_template, request, Response, session +from wtforms import Form, StringField +from wtforms.csrf.core import CSRF +from wtforms.csrf.session import SessionCSRF +from wtforms.validators import Regexp + + +email_regex = Regexp(r'^[^@^\s]+@[\d\w-]+(\.[\d\w-]+)+$', + flags=re.UNICODE, + message='You must provide a valid email address') class StripWhitespaceStringField(StringField): @@ -12,3 +25,84 @@ def strip_whitespace(value): if value is not None and hasattr(value, 'strip'): return value.strip() return value + + +class FakeCsrf(CSRF): + """ + For testing purposes only. + """ + + valid_token = 'valid_fake_csrf_token' + + def generate_csrf_token(self, csrf_token): + return self.valid_token + + def validate_csrf_token(self, form, field): + if field.data != self.valid_token: + raise ValueError('Invalid (fake) CSRF token') + + +class DmForm(Form): + + class Meta: + csrf = True + csrf_class = SessionCSRF + csrf_secret = None + csrf_time_limit = None + + @property + def csrf_context(self): + return session + + def __init__(self, *args, **kwargs): + if current_app.config['CSRF_ENABLED']: + self.Meta.csrf_secret = current_app.config['SECRET_KEY'] + self.Meta.csrf_time_limit = timedelta(seconds=current_app.config['CSRF_TIME_LIMIT']) + elif current_app.config.get('CSRF_FAKED', False): + self.Meta.csrf_class = FakeCsrf + else: + # FIXME: deprecated + self.Meta.csrf = False + self.Meta.csrf_class = None + super(DmForm, self).__init__(*args, **kwargs) + + +def render_template_with_csrf(template_name, status_code=200, **kwargs): + if 'form' not in kwargs: + kwargs['form'] = DmForm() + response = Response(render_template(template_name, **kwargs)) + + # CSRF tokens are user-specific, even if the user isn't logged in + response.cache_control.private = True + + max_age = current_app.config['DM_DEFAULT_CACHE_MAX_AGE'] + max_age = min(max_age, current_app.config.get('CSRF_TIME_LIMIT', max_age)) + response.cache_control.max_age = max_age + + return response, status_code + + +def is_csrf_token_valid(): + if not current_app.config['CSRF_ENABLED'] and not current_app.config.get('CSRF_FAKED', False): + return True + if 'csrf_token' not in request.form: + return False + form = DmForm(csrf_token=request.form['csrf_token']) + return form.validate() + + +def valid_csrf_or_abort(): + if is_csrf_token_valid(): + return + current_app.logger.info( + u'csrf.invalid_token: Aborting request, user_id: {user_id}', + extra={'user_id': session.get('user_id', ' +{{ form.csrf_token }} +{{ form.stripped_field }} + diff --git a/tests/test_forms.py b/tests/test_forms.py new file mode 100644 index 00000000..3cb3e961 --- /dev/null +++ b/tests/test_forms.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +from dmutils.forms import DmForm, email_regex, FakeCsrf, render_template_with_csrf, StripWhitespaceStringField + +from helpers import BaseApplicationTest + + +class TestForm(DmForm): + stripped_field = StripWhitespaceStringField('Stripped', id='stripped_field') + + +class TestFormHandling(BaseApplicationTest): + + def test_whitespace_stripping(self): + with self.flask.app_context(): + form = TestForm(stripped_field=' asdf ', csrf_token=FakeCsrf.valid_token) + assert form.validate() + assert form.stripped_field.data == 'asdf' + + def test_csrf_protection(self): + with self.flask.app_context(): + form = TestForm(stripped_field='asdf', csrf_token='bad') + assert not form.validate() + assert 'csrf_token' in form.errors + + def test_does_not_crash_on_missing_csrf_token(self): + with self.flask.app_context(): + form = TestForm(stripped_field='asdf') + assert not form.validate() + assert 'csrf_token' in form.errors + + def test_render_template_with_csrf(self): + with self.flask.app_context(): + response, status_code = render_template_with_csrf('test_form.html', 123) + assert status_code == 123 + assert response.cache_control.private + assert response.cache_control.max_age == self.flask.config['CSRF_TIME_LIMIT'] + assert FakeCsrf.valid_token in response.data + + +def test_valid_email_formats(): + cases = [ + 'good@example.com', + 'good-email@example.com', + 'good-email+plus@example.com', + 'good@subdomain.example.com', + 'good@hyphenated-subdomain.example.com', + ] + for address in cases: + assert email_regex.regex.match(address) is not None, address + + +def test_invalid_email_formats(): + cases = [ + '', + 'bad', + 'bad@@example.com', + 'bad @example.com', + 'bad@.com', + 'bad.example.com', + '@', + '@example.com', + 'bad@', + 'bad@example.com,bad2@example.com', + 'bad@example.com bad2@example.com', + 'bad@example.com,other.example.com', + ] + for address in cases: + assert email_regex.regex.match(address) is None, address From b05c6518f16b26ae2a7106690b9cc42065cf4b1a Mon Sep 17 00:00:00 2001 From: sarneaud Date: Mon, 8 Aug 2016 14:51:29 +1000 Subject: [PATCH 2/2] Add frontend-specific request logging to utils --- dmutils/flask_init.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/dmutils/flask_init.py b/dmutils/flask_init.py index bc01d86e..4eee4b9a 100644 --- a/dmutils/flask_init.py +++ b/dmutils/flask_init.py @@ -5,7 +5,7 @@ from flask.ext.script import Manager, Server from flask_login import current_user -from user import User +from user import User, user_logging_string def init_app( @@ -80,6 +80,17 @@ def inject_global_template_variables(): def init_frontend_app(application, login_manager): + + def request_log_handler(response): + params = { + 'method': request.method, + 'url': request.url, + 'status': response.status_code, + 'user': user_logging_string(current_user), + } + application.logger.info('{method} {url} {status} {user}', extra=params) + application.extensions['request_log_handler'] = request_log_handler + @login_manager.user_loader def load_user(user_id): return User.load_user(data_api_client, user_id)