Skip to content
This repository has been archived by the owner on Jul 7, 2022. It is now read-only.

Commit

Permalink
Merge pull request #10 from AusDTO/sharing-is-caring
Browse files Browse the repository at this point in the history
Put more reusable code into utils repo from buyer-frontend
  • Loading branch information
sarneaud authored Aug 8, 2016
2 parents 022e969 + b05c651 commit 3e89a38
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 9 deletions.
75 changes: 68 additions & 7 deletions dmutils/flask_init.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os
from flask_featureflags.contrib.inline import InlineFeatureFlag
from . import config, logging, proxy_fix, request_id, formats, filters
from flask import Markup
from flask import Markup, redirect, request, session
from flask.ext.script import Manager, Server
from flask_login import current_user

from user import User, user_logging_string


def init_app(
Expand Down Expand Up @@ -42,11 +45,75 @@ def init_app(
if search_api_client:
search_api_client.init_app(application)

@application.before_request
def set_scheme():
request.environ['wsgi.url_scheme'] = application.config['DM_HTTP_PROTO']

@application.after_request
def add_header(response):
response.headers['X-Frame-Options'] = 'DENY'
return response

@application.after_request
def add_cache_control(response):
if request.method != 'GET' or response.status_code in (301, 302):
return response

vary = response.headers.get('Vary', None)
if vary:
response.headers['Vary'] = vary + ', Cookie'
else:
response.headers['Vary'] = 'Cookie'

if current_user.is_authenticated:
response.cache_control.private = True
if response.cache_control.max_age is None:
response.cache_control.max_age = application.config['DM_DEFAULT_CACHE_MAX_AGE']

return response

@application.context_processor
def inject_global_template_variables():
return dict(
pluralize=pluralize,
**(application.config['BASE_TEMPLATE_DATA'] or {}))


def init_frontend_app(application, login_manager):

def request_log_handler(response):
params = {
'method': request.method,
'url': request.url,
'status': response.status_code,
'user': user_logging_string(current_user),
}
application.logger.info('{method} {url} {status} {user}', extra=params)
application.extensions['request_log_handler'] = request_log_handler

@login_manager.user_loader
def load_user(user_id):
return User.load_user(data_api_client, user_id)

@application.before_request
def refresh_session():
session.permanent = True
session.modified = True

@application.before_request
def remove_trailing_slash():
if request.path != application.config['URL_PREFIX'] + '/' and request.path.endswith('/'):
if request.query_string:
return redirect(
'{}?{}'.format(
request.path[:-1],
request.query_string.decode('utf-8')
),
code=301
)
else:
return redirect(request.path[:-1], code=301)

@application.template_filter('markdown')
def markdown_filter_flask(data):
return Markup(filters.markdown_filter(data))
Expand All @@ -57,12 +124,6 @@ def markdown_filter_flask(data):
application.add_template_filter(formats.datetimeformat)
application.add_template_filter(filters.smartjoin)

@application.context_processor
def inject_global_template_variables():
return dict(
pluralize=pluralize,
**(application.config['BASE_TEMPLATE_DATA'] or {}))


def pluralize(count, singular, plural):
return singular if count == 1 else plural
Expand Down
96 changes: 95 additions & 1 deletion dmutils/forms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
from wtforms import StringField
from datetime import timedelta
from functools import wraps
import re

from flask import abort, current_app, render_template, request, Response, session
from wtforms import Form, StringField
from wtforms.csrf.core import CSRF
from wtforms.csrf.session import SessionCSRF
from wtforms.validators import Regexp


email_regex = Regexp(r'^[^@^\s]+@[\d\w-]+(\.[\d\w-]+)+$',
flags=re.UNICODE,
message='You must provide a valid email address')


class StripWhitespaceStringField(StringField):
Expand All @@ -12,3 +25,84 @@ def strip_whitespace(value):
if value is not None and hasattr(value, 'strip'):
return value.strip()
return value


class FakeCsrf(CSRF):
"""
For testing purposes only.
"""

valid_token = 'valid_fake_csrf_token'

def generate_csrf_token(self, csrf_token):
return self.valid_token

def validate_csrf_token(self, form, field):
if field.data != self.valid_token:
raise ValueError('Invalid (fake) CSRF token')


class DmForm(Form):

class Meta:
csrf = True
csrf_class = SessionCSRF
csrf_secret = None
csrf_time_limit = None

@property
def csrf_context(self):
return session

def __init__(self, *args, **kwargs):
if current_app.config['CSRF_ENABLED']:
self.Meta.csrf_secret = current_app.config['SECRET_KEY']
self.Meta.csrf_time_limit = timedelta(seconds=current_app.config['CSRF_TIME_LIMIT'])
elif current_app.config.get('CSRF_FAKED', False):
self.Meta.csrf_class = FakeCsrf
else:
# FIXME: deprecated
self.Meta.csrf = False
self.Meta.csrf_class = None
super(DmForm, self).__init__(*args, **kwargs)


def render_template_with_csrf(template_name, status_code=200, **kwargs):
if 'form' not in kwargs:
kwargs['form'] = DmForm()
response = Response(render_template(template_name, **kwargs))

# CSRF tokens are user-specific, even if the user isn't logged in
response.cache_control.private = True

max_age = current_app.config['DM_DEFAULT_CACHE_MAX_AGE']
max_age = min(max_age, current_app.config.get('CSRF_TIME_LIMIT', max_age))
response.cache_control.max_age = max_age

return response, status_code


def is_csrf_token_valid():
if not current_app.config['CSRF_ENABLED'] and not current_app.config.get('CSRF_FAKED', False):
return True
if 'csrf_token' not in request.form:
return False
form = DmForm(csrf_token=request.form['csrf_token'])
return form.validate()


def valid_csrf_or_abort():
if is_csrf_token_valid():
return
current_app.logger.info(
u'csrf.invalid_token: Aborting request, user_id: {user_id}',
extra={'user_id': session.get('user_id', '<unknown')})
abort(400, 'Invalid CSRF token. Please try again.')


def check_csrf(view):
@wraps(view)
def wrapped(*args, **kwargs):
valid_csrf_or_abort()
return view(*args, **kwargs)
return wrapped
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ pyyaml==3.11
python-json-logger==0.1.4
inflection==0.2.1
Flask-FeatureFlags==0.6
Flask-Login==0.3.2
monotonic==0.3
pytz==2015.4
Flask-WTF==0.12
markdown==2.6.2
WTForms==2.1
Flask-Script==2.0.5
waitress==0.9.0
workdays==1.4
25 changes: 25 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from flask import Flask
from flask_login import LoginManager

from dmutils.flask_init import init_app, init_frontend_app

from datetime import datetime
import mock

Expand All @@ -14,3 +19,23 @@ def mock_file(filename, length, name=None):
mock_file.name = name

return mock_file


class Config(object):

CSRF_ENABLED = False
CSRF_FAKED = True
CSRF_TIME_LIMIT = 30
DM_DEFAULT_CACHE_MAX_AGE = 60
SECRET_KEY = 'secret'
BASE_TEMPLATE_DATA = {}


class BaseApplicationTest(object):

def setup(self):
self.flask = Flask('test_app', template_folder='tests/templates/')
self.login_manager = LoginManager()
init_app(self.flask, Config)
init_frontend_app(self.flask, self.login_manager)
self.app = self.flask.test_client()
4 changes: 4 additions & 0 deletions tests/templates/test_form.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
<form>
{{ form.csrf_token }}
{{ form.stripped_field }}
</form>
69 changes: 69 additions & 0 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-

from dmutils.forms import DmForm, email_regex, FakeCsrf, render_template_with_csrf, StripWhitespaceStringField

from helpers import BaseApplicationTest


class TestForm(DmForm):
stripped_field = StripWhitespaceStringField('Stripped', id='stripped_field')


class TestFormHandling(BaseApplicationTest):

def test_whitespace_stripping(self):
with self.flask.app_context():
form = TestForm(stripped_field=' asdf ', csrf_token=FakeCsrf.valid_token)
assert form.validate()
assert form.stripped_field.data == 'asdf'

def test_csrf_protection(self):
with self.flask.app_context():
form = TestForm(stripped_field='asdf', csrf_token='bad')
assert not form.validate()
assert 'csrf_token' in form.errors

def test_does_not_crash_on_missing_csrf_token(self):
with self.flask.app_context():
form = TestForm(stripped_field='asdf')
assert not form.validate()
assert 'csrf_token' in form.errors

def test_render_template_with_csrf(self):
with self.flask.app_context():
response, status_code = render_template_with_csrf('test_form.html', 123)
assert status_code == 123
assert response.cache_control.private
assert response.cache_control.max_age == self.flask.config['CSRF_TIME_LIMIT']
assert FakeCsrf.valid_token in response.data


def test_valid_email_formats():
cases = [
'[email protected]',
'[email protected]',
'[email protected]',
'[email protected]',
'[email protected]',
]
for address in cases:
assert email_regex.regex.match(address) is not None, address


def test_invalid_email_formats():
cases = [
'',
'bad',
'bad@@example.com',
'bad @example.com',
'[email protected]',
'bad.example.com',
'@',
'@example.com',
'bad@',
'[email protected],[email protected]',
'[email protected] [email protected]',
'[email protected],other.example.com',
]
for address in cases:
assert email_regex.regex.match(address) is None, address

0 comments on commit 3e89a38

Please sign in to comment.