# coding: utf-8
"""
flask_wtf.csrf
~~~~~~~~~~~~~~
CSRF protection for Flask.
:copyright: (c) 2013 by Hsiaoming Yang.
"""
import os
import hmac
import hashlib
import time
from flask import current_app, session, request, abort
from ._compat import to_bytes, string_types
try:
from urlparse import urlparse
except ImportError:
# python 3
from urllib.parse import urlparse
# Public names exported by ``from <this module> import *``.
__all__ = ('generate_csrf', 'validate_csrf', 'CsrfProtect')
def generate_csrf(secret_key=None, time_limit=None):
    """Generate csrf token code.

    The per-session random value is stored in ``session['csrf_token']``
    (created on first use) and signed with HMAC-SHA1. The returned token
    has the form ``'<expires>##<hexdigest>'``; ``<expires>`` is empty when
    no time limit applies.

    :param secret_key: A secret key for mixing in the token,
                       default is ``WTF_CSRF_SECRET_KEY`` or Flask.secret_key.
    :param time_limit: Token valid in the time limit (seconds),
                       default is ``WTF_CSRF_TIME_LIMIT`` or 3600s.
                       A falsy value produces a token without expiry.
    :raises Exception: if no secret key is available.
    """
    key = secret_key or current_app.config.get(
        'WTF_CSRF_SECRET_KEY', current_app.secret_key
    )
    if not key:
        raise Exception('Must provide secret_key to use csrf.')

    if time_limit is None:
        time_limit = current_app.config.get('WTF_CSRF_TIME_LIMIT', 3600)

    # Lazily create the random per-session secret.
    if 'csrf_token' not in session:
        session['csrf_token'] = hashlib.sha1(os.urandom(64)).hexdigest()
    raw_token = session['csrf_token']

    expires = ''
    payload = raw_token
    if time_limit:
        # Bind the expiry timestamp into the signed message.
        expires = time.time() + time_limit
        payload = '%s%s' % (raw_token, expires)

    signature = hmac.new(
        to_bytes(key), to_bytes(payload), digestmod=hashlib.sha1
    ).hexdigest()
    return '%s##%s' % (expires, signature)
def validate_csrf(data, secret_key=None, time_limit=None):
    """Check if the given data is a valid csrf token.

    :param data: The csrf token value to be checked, in the
                 ``'<expires>##<hexdigest>'`` form returned by
                 ``generate_csrf``.
    :param secret_key: A secret key for mixing in the token,
                       default is ``WTF_CSRF_SECRET_KEY`` or Flask.secret_key.
    :param time_limit: Check if the csrf token is expired.
                       default is ``WTF_CSRF_TIME_LIMIT`` or 3600s; a falsy
                       value disables the expiry check, which also accepts
                       tokens generated without a time limit.
    :return: ``True`` if the signature matches the session's token and the
             token is not expired, otherwise ``False``.
    """
    if not data or '##' not in data:
        return False

    expires, hmac_csrf = data.split('##', 1)

    if time_limit is None:
        time_limit = current_app.config.get('WTF_CSRF_TIME_LIMIT', 3600)

    if time_limit:
        # Only time-limited tokens carry a parseable expiry. Without a time
        # limit, ``expires`` stays the raw string ('' for tokens generated
        # with no limit), so the rebuilt message matches generate_csrf.
        try:
            expires = float(expires)
        except ValueError:
            return False
        if time.time() > expires:
            return False

    if not secret_key:
        secret_key = current_app.config.get(
            'WTF_CSRF_SECRET_KEY', current_app.secret_key
        )

    if 'csrf_token' not in session:
        return False

    csrf_build = '%s%s' % (session['csrf_token'], expires)
    hmac_compare = hmac.new(
        to_bytes(secret_key),
        to_bytes(csrf_build),
        digestmod=hashlib.sha1
    ).hexdigest()

    expected = to_bytes(hmac_compare)
    given = to_bytes(hmac_csrf)
    # Constant-time comparison avoids leaking the signature via timing.
    compare_digest = getattr(hmac, 'compare_digest', None)
    if compare_digest is not None:
        return compare_digest(expected, given)
    # Fallback for interpreters without hmac.compare_digest (< 2.7.7 / 3.3).
    return expected == given
class CsrfProtect(object):
    """Enable csrf protect for Flask.

    Register it with::

        app = Flask(__name__)
        CsrfProtect(app)

    And in the templates, add the token input::

        <input type="hidden" name="csrf_token" value="{{ csrf_token() }}"/>

    If you need to send the token via AJAX, and there is no form::

        <meta name="csrf_token" content="{{ csrf_token() }}" />

    You can grab the csrf token with JavaScript, and send the token together
    in an ``X-CSRFToken`` (or ``X-CSRF-Token``) request header.
    """

    def __init__(self, app=None):
        # "module.function_name" strings of views that skip the CSRF check.
        self._exempt_views = set()
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Register the ``csrf_token`` template global and the CSRF check.

        ``WTF_CSRF_ENABLED`` and ``WTF_CSRF_SSL_STRICT`` are read from
        ``app.config`` on every request (both default to ``True``), so
        changes made after ``init_app`` -- e.g. disabling CSRF in tests --
        take effect.
        """
        app.jinja_env.globals['csrf_token'] = generate_csrf

        @app.before_request
        def _csrf_protect():
            # many things come from django.middleware.csrf
            if not app.config.get('WTF_CSRF_ENABLED', True):
                return

            # Safe methods are never checked.
            if request.method in ('GET', 'HEAD', 'OPTIONS', 'TRACE'):
                return

            if self._exempt_views:
                if not request.endpoint:
                    return
                view = app.view_functions.get(request.endpoint)
                if not view:
                    return
                dest = '%s.%s' % (view.__module__, view.__name__)
                if dest in self._exempt_views:
                    return

            csrf_token = None
            if request.method in ('POST', 'PUT', 'PATCH'):
                # find the ``csrf_token`` field in the submitted form
                # if the form had a prefix, the name will be
                # ``{prefix}-csrf_token``; the last matching field wins.
                for key in request.form:
                    if key.endswith('csrf_token'):
                        csrf_token = request.form[key]
            if not csrf_token:
                # You can get csrf token from header
                # The header name is the same as Django
                csrf_token = request.headers.get('X-CSRFToken')
            if not csrf_token:
                # The header name is the same as Rails
                csrf_token = request.headers.get('X-CSRF-Token')
            if not validate_csrf(csrf_token):
                reason = 'CSRF token missing or incorrect.'
                return self._error_response(reason)

            strict = app.config.get('WTF_CSRF_SSL_STRICT', True)
            if request.is_secure and strict:
                # Over HTTPS, additionally require a same-origin Referer.
                if not request.referrer:
                    reason = 'Referrer checking failed - no Referrer.'
                    return self._error_response(reason)
                good_referrer = 'https://%s/' % request.host
                if not same_origin(request.referrer, good_referrer):
                    reason = 'Referrer checking failed - origin not match.'
                    return self._error_response(reason)

            request.csrf_valid = True  # mark this request is csrf valid

    def exempt(self, view):
        """A decorator that can exclude a view from csrf protection.

        Remember to put the decorator above the `route`::

            csrf = CsrfProtect(app)

            @csrf.exempt
            @app.route('/some-view', methods=['POST'])
            def some_view():
                return

        :param view: the view function, or its ``'module.name'`` string.
        :return: ``view`` unchanged.
        """
        if isinstance(view, string_types):
            view_location = view
        else:
            view_location = '%s.%s' % (view.__module__, view.__name__)
        self._exempt_views.add(view_location)
        return view

    def _error_response(self, reason):
        # Default failure handler: abort the request with HTTP 400.
        # Replaced per instance by ``error_handler``.
        return abort(400, reason)

    def error_handler(self, view):
        """A decorator that set the error response handler.

        It accepts one parameter `reason`::

            @csrf.error_handler
            def csrf_error(reason):
                return render_template('error.html', reason=reason)

        By default, it will return a 400 response.
        """
        self._error_response = view
        return view
def same_origin(current_uri, compare_uri):
    """Return True if both URIs share scheme, hostname and port.

    A missing port is treated as the scheme's default (80 for http, 443
    for https), so ``https://example.com/`` and ``https://example.com:443/``
    are the same origin. URIs that ``urlparse`` rejects with ``ValueError``
    (e.g. a non-numeric or out-of-range port) never match.

    :param current_uri: the URI to check, e.g. the request's Referer.
    :param compare_uri: the trusted URI to compare against.
    """
    default_ports = {'http': 80, 'https': 443}

    def _origin(uri):
        # (scheme, hostname, port) tuple; scheme/hostname are lowercased
        # by urlparse, and ``.port`` may raise ValueError.
        parsed = urlparse(uri)
        port = parsed.port
        if port is None:
            port = default_ports.get(parsed.scheme)
        return parsed.scheme, parsed.hostname, port

    try:
        return _origin(current_uri) == _origin(compare_uri)
    except ValueError:
        # The Referer header is client-controlled; treat garbage as mismatch
        # instead of letting the exception surface as a server error.
        return False