Skip to content
This repository has been archived by the owner on Jul 7, 2022. It is now read-only.

Commit

Permalink
alternative csrf (#66)
Browse files Browse the repository at this point in the history
* alternative csrf

* use response module to mock http requests
  • Loading branch information
djrobstep authored Nov 16, 2016
1 parent 90a3572 commit 35d700e
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 78 deletions.
27 changes: 21 additions & 6 deletions dmutils/csrf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
import binascii
from flask import session, request
from flask import session, request, current_app


TOKEN = '_csrf_token'
OLD_TOKEN = 'csrf_token'
REACT_HEADER_NAME = 'X-CSRFToken'


Expand All @@ -17,8 +18,22 @@ def get_csrf_token():
return session[TOKEN]


def check_valid_header_csrf():
try:
return session[TOKEN] == request.headers[REACT_HEADER_NAME]
except KeyError:
return False
def check_valid_csrf():
if not current_app.config.get('CSRF_ENABLED') and not current_app.config.get('CSRF_FAKED'):
return True

tokens_received = [
request.form.get(OLD_TOKEN, None),
request.form.get(TOKEN, None),
request.headers.get(REACT_HEADER_NAME, None)
]
tokens_received = set(filter(None, tokens_received))

tokens_from_session = [
session.get(TOKEN, None),
session.get(OLD_TOKEN, None)
]
tokens_from_session = set(filter(None, tokens_from_session))

intersect = tokens_received.intersection(tokens_from_session)
return bool(intersect)
8 changes: 4 additions & 4 deletions dmutils/flask_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dmutils import terms_of_use
from dmutils.forms import is_csrf_token_valid

from .csrf import check_valid_header_csrf
from .csrf import check_valid_csrf


def init_app(
Expand Down Expand Up @@ -99,10 +99,10 @@ def load_user(user_id):
@application.before_request
def check_csrf_token():
if request.method in ('POST', 'PATCH', 'PUT', 'DELETE'):
flask_csrf_valid = is_csrf_token_valid()
react_csrf_valid = check_valid_header_csrf()
old_csrf_valid = is_csrf_token_valid()
new_csrf_valid = check_valid_csrf()

if not (flask_csrf_valid or react_csrf_valid):
if not (old_csrf_valid or new_csrf_valid):
current_app.logger.info(
u'csrf.invalid_token: Aborting request, user_id: {user_id}',
extra={'user_id': session.get('user_id', '<unknown')})
Expand Down
6 changes: 5 additions & 1 deletion react/render_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ def render(self):


class RenderServer(object):
@property
def url(self):
return current_app.config.get('REACT_RENDER_URL', '')

def render(self, path, props=None, to_static_markup=False, request_headers=None):
url = current_app.config.get('REACT_RENDER_URL', '')
url = self.url

if props is None:
props = {}
Expand Down
1 change: 1 addition & 0 deletions requirements_for_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ pytest-sugar
cffi
tox
flake8
responses
45 changes: 40 additions & 5 deletions tests/test_flask_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,53 @@ def test_csrf_wrong(self):
)
assert res.status_code == 400

def test_alternate_csrf(self):
def test_new_style_csrf(self):
with self.app.session_transaction() as sess:
sess['_csrf_token'] = 'abc123'
sess['csrf_token'] = 'abc123'
sess['_csrf_token'] = 'def456'

for t in ['abc123', 'def456']:
res = self.app.post(
'/thing',
headers={
'X-CSRFToken': t
}
)
assert res.status_code == 200

res = self.app.post(
'/thing',
data={'csrf_token': t}
)
assert res.status_code == 200

res = self.app.post(
'/thing',
data={'_csrf_token': t}
)
assert res.status_code == 200

BAD = 'bad'

res = self.app.post(
'/thing',
data={'csrf_token': 'nope'},
headers={
'X-CSRFToken': 'abc123'
'X-CSRFToken': BAD
}
)
assert res.status_code == 200
assert res.status_code == 400

res = self.app.post(
'/thing',
data={'csrf_token': BAD}
)
assert res.status_code == 400

res = self.app.post(
'/thing',
data={'_csrf_token': BAD}
)
assert res.status_code == 400


class TestTemplateFilters(BaseApplicationTest):
Expand Down
42 changes: 21 additions & 21 deletions tests/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
import tempfile
import logging
import mock
from requests import Response
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
import responses
import six
import json

from dmutils import request_id
Expand All @@ -16,6 +13,11 @@

from tests.helpers import BaseApplicationTest, Config

if six.PY2:
from io import BytesIO as StringIO
else:
from io import StringIO


def test_request_id_filter_not_in_app_context():
assert RequestIdFilter().request_id == 'no-request-id'
Expand Down Expand Up @@ -170,19 +172,18 @@ class TestNotifyTeam(BaseApplicationTest):

config = NotifyTeamConfig()

@responses.activate
@mock.patch('dmutils.logging.send_email')
@mock.patch('dmutils.logging.requests')
def test_notify(self, requests, send_email):
def test_notify(self, send_email):
with self.flask.app_context():
slack_response = Response()
slack_response.status_code = 200
requests.post.return_value = slack_response
responses.add(responses.POST, url=self.config.DM_TEAM_SLACK_WEBHOOK, body='')

notify_team('Something Happened', 'It happened', 'https://example.com/it')

requests.post.assert_called_with(
self.config.DM_TEAM_SLACK_WEBHOOK,
json=mock.ANY,
)
resp = responses.calls[0].response

assert resp.url == self.config.DM_TEAM_SLACK_WEBHOOK
assert resp.json == mock.ANY

send_email.assert_called_once_with(
self.config.DM_TEAM_EMAIL,
Expand All @@ -192,18 +193,17 @@ def test_notify(self, requests, send_email):
self.config.DM_GENERIC_ADMIN_NAME,
)

@responses.activate
@mock.patch('dmutils.logging.send_email')
@mock.patch('dmutils.logging.requests')
def test_slack_error_path(self, requests, send_email):
def test_slack_error_path(self, send_email):
with self.flask.app_context():
error_response = Response()
error_response.status_code = 400
requests.post.return_value = error_response
responses.add(responses.POST, url=self.config.DM_TEAM_SLACK_WEBHOOK, status=400)
notify_team('Something Happened', 'It happened', 'https://example.com/it')

@responses.activate
@mock.patch('dmutils.logging.send_email')
@mock.patch('dmutils.logging.requests')
def test_email_error_path(self, requests, send_email):
def test_email_error_path(self, send_email):
with self.flask.app_context():
responses.add(responses.POST, url=self.config.DM_TEAM_SLACK_WEBHOOK, status=400)
send_email.side_effect = EmailError(':(')
notify_team('Something Happened', 'It happened', 'https://example.com/it')
80 changes: 39 additions & 41 deletions tests/test_render_server.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,105 @@
from __future__ import absolute_import
from __future__ import absolute_import, unicode_literals

from mock import patch
from .helpers import BaseApplicationTest, Config
from react.render_server import render_server
from requests import Response
from hashlib import sha1
import pytest
from react.exceptions import RenderServerError, ReactRenderingError
from react.response import validate_form_data, from_response
from flask import request
from werkzeug.datastructures import MultiDict
import requests
import responses
from six.moves.urllib import parse as urls


class RenderConfig(Config):
REACT_RENDER = True
REACT_RENDER_URL = '/render'
REACT_RENDER_URL = 'http://example.com/render'
SERVER_NAME = 'http://api'


class TestRenderServer(BaseApplicationTest):
config = RenderConfig()

@responses.activate
@patch('react.render_server.hashlib')
@patch('react.render_server.requests')
@patch('react.render_server.get_csrf_token')
def test_render_server_success(self, get_csrf_token, requests, hashlib):
def test_render_server_success(self, get_csrf_token, hashlib):
get_csrf_token.return_value = 'abc123'

with self.flask.test_request_context('/test'):
sha = sha1()
hashlib.sha1.return_value = sha

res = Response()
res.status_code = 200
markup = 'hello world!'
res.json = lambda: {'markup': markup}
requests.post.return_value = res

path = '/widget/component.js'
result = render_server.render(path)
params = {'hash': sha.hexdigest()}

responses.add(responses.POST, render_server.url, json={'markup': markup})

result = render_server.render(path)
assert result.render() == markup
requests.post.assert_called_with(
'/render',
headers={'content-type': 'application/json'},
params={'hash': sha.hexdigest()},
data='{"path": "' + path + '", ''"serializedProps": "{\\"_serverContext\\": '
'{\\"location\\": \\"/test\\"}, \\"form_options\\": {\\"csrf_token\\": \\"abc123\\"}, '
'\\"options\\": '
'{\\"apiUrl\\": \\"http://api\\", \\"serverRender\\": true}}", '
'"toStaticMarkup": false}'
)

assert len(responses.calls) == 1
req = responses.calls[0].request

assert req.url == self.config.REACT_RENDER_URL + '?' + urls.urlencode(params)
assert req.headers['content-type'] == 'application/json'
assert req.body == '{"path": "' + path + '", ''"serializedProps": "{\\"_serverContext\\": ' \
'{\\"location\\": \\"/test\\"}, \\"form_options\\": {\\"csrf_token\\": \\"abc123\\"}, ' \
'\\"options\\": ' \
'{\\"apiUrl\\": \\"http://api\\", \\"serverRender\\": true}}", ' \
'"toStaticMarkup": false}'

@responses.activate
@patch('react.render_server.get_csrf_token')
def test_react_render_not_set(self, get_csrf_token):
get_csrf_token.return_value = 'abc123'

self.flask.config.update({'REACT_RENDER': None})

with self.flask.test_request_context('/test'):
responses.add(responses.POST, render_server.url, json={})

result = render_server.render('/widget/component.js')
assert result.render() == ''
assert result.get_props() == '{"_serverContext": ' \
'{"location": "/test"}, "form_options": {"csrf_token": "abc123"}, ' \
'"options": {"apiUrl": "http://api", ' \
'"serverRender": true}}'

@patch('react.render_server.requests')
def test_connection_error(self, requests):
@responses.activate
def test_connection_error(self):
e = requests.exceptions.ConnectionError('mock connection error!')

with self.flask.test_request_context('/test'):
requests.post.side_effect = requests.exceptions.ConnectionError
responses.add(responses.POST, render_server.url, body=e)

with pytest.raises(RenderServerError):
render_server.render('/path')

@patch('react.render_server.requests')
def test_non_200_status_code(self, requests):
@responses.activate
def test_non_200_status_code(self):
with self.flask.test_request_context('/test'):
res = Response()
res.status_code = 400
requests.post.return_value = res
responses.add(responses.POST, render_server.url, status=400)

with pytest.raises(RenderServerError):
render_server.render('/path')

@patch('react.render_server.requests')
def test_no_markup(self, requests):
@responses.activate
def test_no_markup(self):
with self.flask.test_request_context('/test'):
res = Response()
res.status_code = 200
res.json = lambda: {'markup': None}
requests.post.return_value = res
responses.add(responses.POST, render_server.url, json={'markup': None})

with pytest.raises(ReactRenderingError):
render_server.render('/path')

@patch('react.render_server.requests')
def test_render_error(self, requests):
@responses.activate
def test_render_error(self,):
with self.flask.test_request_context('/test'):
res = Response()
res.status_code = 200
res.json = lambda: {'error': 'an error'}
requests.post.return_value = res
responses.add(responses.POST, render_server.url, json={'error': 'an error'})

with pytest.raises(ReactRenderingError):
render_server.render('/path')
Expand Down

0 comments on commit 35d700e

Please sign in to comment.