93 lines
3.0 KiB
Python
93 lines
3.0 KiB
Python
"""
|
|
A provided CSRF implementation which puts CSRF data in a session.
|
|
|
|
This can be used fairly comfortably with many `request.session` type
|
|
objects, including the Werkzeug/Flask session store, Django sessions, and
|
|
potentially other similar objects which use a dict-like API for storing
|
|
session keys.
|
|
|
|
The basic concept is a randomly generated value is stored in the user's
|
|
session, and an hmac-sha1 of it (along with an optional expiration time,
|
|
for extra security) is used as the value of the csrf_token. If this token
|
|
validates with the hmac of the random value + expiration time, and the
|
|
expiration time is not passed, the CSRF validation will pass.
|
|
"""
|
|
import hmac
|
|
import os
|
|
from datetime import datetime
|
|
from datetime import timedelta
|
|
from hashlib import sha1
|
|
|
|
from ..validators import ValidationError
|
|
from .core import CSRF
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ("SessionCSRF",)
|
|
|
|
|
|
class SessionCSRF(CSRF):
    """CSRF implementation that keeps its random value in a session object.

    A random value is stored under the ``"csrf"`` key of a dict-like session.
    The token given to the client has the form ``"<expires>##<hexdigest>"``,
    where ``hexdigest`` is an HMAC-SHA1, keyed with ``Meta.csrf_secret``, of
    the session value concatenated with ``expires``.  ``expires`` is an empty
    string when no time limit is configured.
    """

    # strftime format of the expiry stamp embedded in the token.  The fields
    # are zero-padded and ordered from year down to second, so (for four-digit
    # years) comparing two stamps as strings orders them chronologically;
    # validate_csrf_token relies on that.
    TIME_FORMAT = "%Y%m%d%H%M%S"

    def setup_form(self, form):
        """Remember ``form.meta`` for later use, then defer to the base class.

        :param form: The form this CSRF implementation is being attached to.
        :return: Whatever the parent ``setup_form`` returns.
        """
        self.form_meta = form.meta
        return super().setup_form(form)

    def generate_csrf_token(self, csrf_token_field):
        """Build the token string to place in the CSRF field.

        If the session has no ``"csrf"`` entry yet, a new random value is
        stored there as a side effect.

        :param csrf_token_field: The CSRF field being populated (not used by
            this implementation).
        :return: ``"<expires>##<hexdigest>"``; ``expires`` is empty when
            ``time_limit`` is falsy.
        :raises Exception: If ``csrf_secret`` on the form meta is ``None``.
        :raises TypeError: If ``csrf_context`` on the form meta is ``None``.
        """
        meta = self.form_meta
        if meta.csrf_secret is None:
            raise Exception(
                "must set `csrf_secret` on class Meta for SessionCSRF to work"
            )
        if meta.csrf_context is None:
            raise TypeError("Must provide a session-like object as csrf context")

        session = self.session

        if "csrf" not in session:
            session["csrf"] = sha1(os.urandom(64)).hexdigest()

        time_limit = self.time_limit
        if time_limit:
            expires = (self.now() + time_limit).strftime(self.TIME_FORMAT)
            csrf_build = "{}{}".format(session["csrf"], expires)
        else:
            expires = ""
            csrf_build = session["csrf"]

        hmac_csrf = hmac.new(
            meta.csrf_secret, csrf_build.encode("utf8"), digestmod=sha1
        )
        return f"{expires}##{hmac_csrf.hexdigest()}"

    def validate_csrf_token(self, form, field):
        """Check the submitted token against the value held in the session.

        :param form: The form being validated (not used by this
            implementation).
        :param field: The CSRF field; ``field.data`` holds the submitted token
            and ``field.gettext`` is used to translate error messages.
        :raises ValidationError: If the token is missing or malformed, if the
            session holds no CSRF value, if the HMAC does not match, or if a
            time limit is configured and the token's expiry stamp is in the
            past.
        """
        meta = self.form_meta
        if not field.data or "##" not in field.data:
            raise ValidationError(field.gettext("CSRF token missing."))

        expires, hmac_csrf = field.data.split("##", 1)

        session = self.session
        # A session without a stored value (e.g. one that was reset after the
        # form was rendered) cannot validate any token; report that as a
        # validation failure rather than letting a KeyError escape.
        if "csrf" not in session:
            raise ValidationError(field.gettext("CSRF failed."))

        check_val = (session["csrf"] + expires).encode("utf8")

        expected = hmac.new(meta.csrf_secret, check_val, digestmod=sha1).hexdigest()
        # Constant-time comparison so response timing does not reveal how many
        # leading characters of a guessed digest were correct.  Both sides are
        # encoded to bytes because compare_digest rejects non-ASCII str input.
        if not hmac.compare_digest(
            expected.encode("utf8"), hmac_csrf.encode("utf8")
        ):
            raise ValidationError(field.gettext("CSRF failed."))

        if self.time_limit:
            # The expiry stamp is covered by the HMAC verified above, so it is
            # the value generate_csrf_token produced; string comparison is
            # valid because of the TIME_FORMAT layout.
            now_formatted = self.now().strftime(self.TIME_FORMAT)
            if now_formatted > expires:
                raise ValidationError(field.gettext("CSRF token expired."))

    def now(self):
        """
        Get the current time. Used for test mocking/overriding mainly.

        :return: ``datetime.now()`` (a naive datetime in local time).
        """
        return datetime.now()

    @property
    def time_limit(self):
        """Token lifetime: ``csrf_time_limit`` from the form meta.

        Falls back to 30 minutes when the meta has no such attribute.  A falsy
        value disables the expiry check.
        """
        return getattr(self.form_meta, "csrf_time_limit", timedelta(minutes=30))

    @property
    def session(self):
        """The session object used for storage.

        This is ``csrf_context.session`` when the context has a ``session``
        attribute, otherwise the context itself.
        """
        return getattr(
            self.form_meta.csrf_context, "session", self.form_meta.csrf_context
        )
|