How flask_wtf's CSRFProtect works

I. Usage

Instantiation

from flask_wtf import CSRFProtect
csrf = CSRFProtect()

Initialization

from flask import Flask
app = Flask(__name__)
...
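app.config['WTF_CSRF_SECRET_KEY'] = 'xxx'  # secret key used to sign the token (falls back to SECRET_KEY); the salt itself is fixed in the library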
...
csrf.init_app(app)

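    By default, CSRFProtect validates the token on requests that use the ['POST', 'PUT', 'PATCH', 'DELETE'] methods. The list of protected methods can be changed by setting WTF_CSRF_METHODS in the config.

    To exclude a particular view from validation, decorate it with csrf.exempt.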

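II. Implementation

1. Generating the token

generate_csrf() creates a random value and stores it in the session. It then signs that value with URLSafeTimedSerializer to produce the token, and caches the token on g under the key csrf_token (or the name configured via WTF_CSRF_FIELD_NAME). The signed token is what gets delivered to the client in the response, typically in the body (for example a hidden form field) or in a header.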

def generate_csrf(secret_key=None, token_key=None):
    """Generate a CSRF token. The token is cached for a request, so multiple
    calls to this function will generate the same token.

    During testing, it might be useful to access the signed token in
    ``g.csrf_token`` and the raw token in ``session['csrf_token']``.

    :param secret_key: Used to securely sign the token. Default is
        ``WTF_CSRF_SECRET_KEY`` or ``SECRET_KEY``.
    :param token_key: Key where token is stored in session for comparison.
        Default is ``WTF_CSRF_FIELD_NAME`` or ``'csrf_token'``.
    """

    secret_key = _get_config(
        secret_key, 'WTF_CSRF_SECRET_KEY', current_app.secret_key,
        message='A secret key is required to use CSRF.'
    )
    field_name = _get_config(
        token_key, 'WTF_CSRF_FIELD_NAME', 'csrf_token',
        message='A field name is required to use CSRF.'
    )

    if field_name not in g:
        s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')

        if field_name not in session:
            session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()

        try:
            token = s.dumps(session[field_name])
        except TypeError:
            session[field_name] = hashlib.sha1(os.urandom(64)).hexdigest()
            token = s.dumps(session[field_name])

        setattr(g, field_name, token)

    return g.get(field_name)
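To make the raw value / signed token distinction concrete, here is a small sketch that repeats the signing step outside of Flask using itsdangerous directly. The secret strings and the helper name verify are hypothetical; the salt and the serializer class are the ones used in generate_csrf() above, and the 3600-second limit mirrors the default WTF_CSRF_TIME_LIMIT.

```python
import hashlib
import os

from itsdangerous import BadSignature, URLSafeTimedSerializer

SECRET_KEY = 'dev-only-secret'  # placeholder, not a real key

# Raw value: what generate_csrf() keeps in session['csrf_token'].
raw_token = hashlib.sha1(os.urandom(64)).hexdigest()

# Signed token: what is cached on g and handed to the client.
serializer = URLSafeTimedSerializer(SECRET_KEY, salt='wtf-csrf-token')
signed_token = serializer.dumps(raw_token)


def verify(token, secret_key, max_age=3600):
    """Hypothetical helper: return the raw value, or None if the
    signature is invalid or older than max_age seconds."""
    s = URLSafeTimedSerializer(secret_key, salt='wtf-csrf-token')
    try:
        return s.loads(token, max_age=max_age)
    except BadSignature:  # also covers SignatureExpired
        return None
```

Loading the signed token with the same key gives back the raw value, which can then be compared with the copy in the session. A token signed with a different key fails to load.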

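2. Validating the token

When the app receives a request, CSRFProtect looks for the token in the submitted form data and in the request headers, loads it (verifying the signature), and compares the result with the value stored in the session. This is done mainly by the protect() function.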

def protect(self):
        if request.method not in current_app.config['WTF_CSRF_METHODS']:
            return

        try:
            validate_csrf(self._get_csrf_token())
        except ValidationError as e:
            logger.info(e.args[0])
            self._error_response(e.args[0])

        if request.is_secure and current_app.config['WTF_CSRF_SSL_STRICT']:
            if not request.referrer:
                self._error_response('The referrer header is missing.')

            good_referrer = 'https://{0}/'.format(request.host)

            if not same_origin(request.referrer, good_referrer):
                self._error_response('The referrer does not match the host.')

        g.csrf_valid = True  # mark this request as CSRF valid

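3. How csrf.exempt works

Registration: when the app starts up, every view (endpoint) decorated with exempt is recorded in a set, keyed by its module and function name. Exempt blueprints are recorded by name in a separate set.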

def exempt(self, view):
        """Mark a view or blueprint to be excluded from CSRF protection.

        ::

            @app.route('/some-view', methods=['POST'])
            @csrf.exempt
            def some_view():
                ...

        ::

            bp = Blueprint(...)
            csrf.exempt(bp)

        """

        if isinstance(view, Blueprint):
            self._exempt_blueprints.add(view.name)
            return view

        if isinstance(view, string_types):
            view_location = view
        else:
            if isinstance(view, views.MethodViewType):
                view_location = '.'.join((view.__module__, view.__name__.lower()))
            else:
                view_location = '.'.join((view.__module__, view.__name__))

        self._exempt_views.add(view_location)
        return view

When a route is requested, a before_request hook first checks whether the target view was registered as exempt; if it was, the protect() check is skipped.

        @app.before_request
        def csrf_protect():
            if not app.config['WTF_CSRF_ENABLED']:
                return

            if not app.config['WTF_CSRF_CHECK_DEFAULT']:
                return

            if request.method not in app.config['WTF_CSRF_METHODS']:
                return

            if not request.endpoint:
                return

            if request.blueprint in self._exempt_blueprints:
                return

            view = app.view_functions.get(request.endpoint)
            dest = '{0}.{1}'.format(view.__module__, view.__name__)

            if dest in self._exempt_views:
                return

            self.protect()
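The registration and lookup steps boil down to a set of "module.name" strings. The following standalone sketch mimics that bookkeeping without Flask; the class ExemptRegistry, the method is_exempt, and the sample views are hypothetical names, not part of Flask-WTF.

```python
class ExemptRegistry:
    """Toy stand-in for the exempt bookkeeping shown above."""

    def __init__(self):
        self._exempt_views = set()

    def exempt(self, view):
        # Accept either a dotted string or a callable, as exempt() does.
        if isinstance(view, str):
            location = view
        else:
            location = '{0}.{1}'.format(view.__module__, view.__name__)
        self._exempt_views.add(location)
        return view  # return the view unchanged so it works as a decorator

    def is_exempt(self, view):
        # Same key construction as the before_request hook uses.
        dest = '{0}.{1}'.format(view.__module__, view.__name__)
        return dest in self._exempt_views


registry = ExemptRegistry()


@registry.exempt
def webhook():
    return 'ok'


def profile():
    return 'ok'


def plain_wrapper(view):
    """A decorator that does not preserve __name__ (no functools.wraps)."""
    def inner(*args, **kwargs):
        return view(*args, **kwargs)
    return inner


wrapped_webhook = plain_wrapper(webhook)
```

Because the lookup key is built from __module__ and __name__, wrapping an exempt view in a decorator that does not preserve those attributes changes the key, and the wrapped view is no longer recognized as exempt. Using functools.wraps in custom decorators, or applying exempt outermost before the route is registered, avoids this in the sketch.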


Reposted from blog.csdn.net/ypgsh/article/details/84402363