import datetime
from warnings import warn
from jwt import (
ExpiredSignatureError, InvalidTokenError, InvalidAudienceError,
InvalidIssuerError, DecodeError
)
try:
from flask import _app_ctx_stack as ctx_stack
except ImportError: # pragma: no cover
from flask import _request_ctx_stack as ctx_stack
from flask_jwt_extended.config import config
from flask_jwt_extended.exceptions import (
JWTDecodeError, NoAuthorizationError, InvalidHeaderError, WrongTokenError,
RevokedTokenError, FreshTokenRequired, CSRFError, UserLoadError,
UserClaimsVerificationError
)
from flask_jwt_extended.default_callbacks import (
default_expired_token_callback, default_user_claims_callback,
default_user_identity_callback, default_invalid_token_callback,
default_unauthorized_callback, default_needs_fresh_token_callback,
default_revoked_token_callback, default_user_loader_error_callback,
default_claims_verification_callback, default_verify_claims_failed_callback,
default_decode_key_callback, default_encode_key_callback,
default_jwt_headers_callback)
from flask_jwt_extended.tokens import (
encode_refresh_token, encode_access_token
)
from flask_jwt_extended.utils import get_jwt_identity
class JWTManager(object):
    """
    An object used to hold JWT settings and callback functions for the
    Flask-JWT-Extended extension.

    Instances of :class:`JWTManager` are *not* bound to specific apps, so
    you can create one in the main body of your code and then bind it
    to your app in a factory function.
    """

    def __init__(self, app=None):
        """
        Create the JWTManager instance. You can either pass a flask application
        in directly here to register this extension with the flask app, or
        call init_app after creating this object (in a factory pattern).

        :param app: A flask application
        """
        # Register the default callback functions. These can be overridden
        # with the appropriate loader decorators defined on this class.
        self._user_claims_callback = default_user_claims_callback
        self._user_identity_callback = default_user_identity_callback
        self._expired_token_callback = default_expired_token_callback
        self._invalid_token_callback = default_invalid_token_callback
        self._unauthorized_callback = default_unauthorized_callback
        self._needs_fresh_token_callback = default_needs_fresh_token_callback
        self._revoked_token_callback = default_revoked_token_callback
        self._user_loader_callback = None
        self._user_loader_error_callback = default_user_loader_error_callback
        self._token_in_blacklist_callback = None
        self._claims_verification_callback = default_claims_verification_callback
        self._verify_claims_failed_callback = default_verify_claims_failed_callback
        self._decode_key_callback = default_decode_key_callback
        self._encode_key_callback = default_encode_key_callback
        self._jwt_additional_header_callback = default_jwt_headers_callback

        # Register this extension with the flask app now (if it is provided)
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """
        Register this extension with the flask app.

        Stores this manager under ``app.extensions['flask-jwt-extended']``,
        fills in any ``JWT_*`` configuration values the app has not set, and
        registers this extension's error handlers on the app.

        :param app: A flask application
        """
        # Save this so we can use it later in the extension
        if not hasattr(app, 'extensions'):  # pragma: no cover
            app.extensions = {}
        app.extensions['flask-jwt-extended'] = self

        # Set all the default configurations for this extension
        self._set_default_configuration_options(app)
        self._set_error_handler_callbacks(app)

    def _set_error_handler_callbacks(self, app):
        """
        Register Flask error handlers that turn this extension's exceptions
        (and the PyJWT exceptions it lets propagate) into responses built by
        the currently configured callbacks.

        Each handler reads ``self._*_callback`` when the error is handled,
        not when it is registered, so callbacks replaced via the loader
        decorators after ``init_app`` still take effect.

        :param app: A flask application
        """
        @app.errorhandler(NoAuthorizationError)
        def handle_auth_error(e):
            return self._unauthorized_callback(str(e))

        @app.errorhandler(CSRFError)
        def handle_csrf_error(e):
            return self._unauthorized_callback(str(e))

        @app.errorhandler(ExpiredSignatureError)
        def handle_expired_error(e):
            # NOTE(review): assumes whatever raised ExpiredSignatureError
            # stored the expired token on the app context as ``expired_jwt``.
            token = ctx_stack.top.expired_jwt
            # Only the callback invocation is guarded: a TypeError here is
            # treated as the legacy zero-argument callback signature.
            # Caveat: a TypeError raised *inside* a one-argument callback is
            # indistinguishable from that case and triggers the fallback too.
            try:
                return self._expired_token_callback(token)
            except TypeError:
                msg = (
                    "jwt.expired_token_loader callback now takes the expired token "
                    "as an additional parameter. Example: expired_callback(token)"
                )
                warn(msg, DeprecationWarning)
                return self._expired_token_callback()

        @app.errorhandler(InvalidHeaderError)
        def handle_invalid_header_error(e):
            return self._invalid_token_callback(str(e))

        # Distinct name so this handler does not shadow the
        # InvalidHeaderError handler defined just above.
        @app.errorhandler(DecodeError)
        def handle_decode_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidTokenError)
        def handle_invalid_token_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(JWTDecodeError)
        def handle_jwt_decode_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(WrongTokenError)
        def handle_wrong_token_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidAudienceError)
        def handle_invalid_audience_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(InvalidIssuerError)
        def handle_invalid_issuer_error(e):
            return self._invalid_token_callback(str(e))

        @app.errorhandler(RevokedTokenError)
        def handle_revoked_token_error(e):
            return self._revoked_token_callback()

        @app.errorhandler(FreshTokenRequired)
        def handle_fresh_token_required(e):
            return self._needs_fresh_token_callback()

        @app.errorhandler(UserLoadError)
        def handle_user_load_error(e):
            # The identity is already saved before this exception was raised,
            # otherwise a different exception would be raised, which is why we
            # can safely call get_jwt_identity() here
            identity = get_jwt_identity()
            return self._user_loader_error_callback(identity)

        @app.errorhandler(UserClaimsVerificationError)
        def handle_failed_user_claims_verification(e):
            return self._verify_claims_failed_callback()

    @staticmethod
    def _set_default_configuration_options(app):
        """
        Fill in the default ``JWT_*`` configuration values for this extension.

        Uses ``setdefault``, so any value the app already configured is kept.

        :param app: A flask application (only ``app.config`` is used)
        """
        # Where to look for the JWT. Available options are cookies or headers
        app.config.setdefault('JWT_TOKEN_LOCATION', ('headers',))

        # Options for JWTs when the TOKEN_LOCATION is headers
        app.config.setdefault('JWT_HEADER_NAME', 'Authorization')
        app.config.setdefault('JWT_HEADER_TYPE', 'Bearer')

        # Options for JWTs when the TOKEN_LOCATION is query_string
        app.config.setdefault('JWT_QUERY_STRING_NAME', 'jwt')

        # Options for JWTs when the TOKEN_LOCATION is cookies
        app.config.setdefault('JWT_ACCESS_COOKIE_NAME', 'access_token_cookie')
        app.config.setdefault('JWT_REFRESH_COOKIE_NAME', 'refresh_token_cookie')
        app.config.setdefault('JWT_ACCESS_COOKIE_PATH', '/')
        app.config.setdefault('JWT_REFRESH_COOKIE_PATH', '/')
        app.config.setdefault('JWT_COOKIE_SECURE', False)
        app.config.setdefault('JWT_COOKIE_DOMAIN', None)
        app.config.setdefault('JWT_SESSION_COOKIE', True)
        app.config.setdefault('JWT_COOKIE_SAMESITE', None)

        # Options for JWTs when the TOKEN_LOCATION is json
        app.config.setdefault('JWT_JSON_KEY', 'access_token')
        app.config.setdefault('JWT_REFRESH_JSON_KEY', 'refresh_token')

        # Options for using double submit csrf protection
        app.config.setdefault('JWT_COOKIE_CSRF_PROTECT', True)
        app.config.setdefault('JWT_CSRF_METHODS', ['POST', 'PUT', 'PATCH', 'DELETE'])
        app.config.setdefault('JWT_ACCESS_CSRF_HEADER_NAME', 'X-CSRF-TOKEN')
        app.config.setdefault('JWT_REFRESH_CSRF_HEADER_NAME', 'X-CSRF-TOKEN')
        app.config.setdefault('JWT_CSRF_IN_COOKIES', True)
        app.config.setdefault('JWT_ACCESS_CSRF_COOKIE_NAME', 'csrf_access_token')
        app.config.setdefault('JWT_REFRESH_CSRF_COOKIE_NAME', 'csrf_refresh_token')
        app.config.setdefault('JWT_ACCESS_CSRF_COOKIE_PATH', '/')
        app.config.setdefault('JWT_REFRESH_CSRF_COOKIE_PATH', '/')
        app.config.setdefault('JWT_CSRF_CHECK_FORM', False)
        app.config.setdefault('JWT_ACCESS_CSRF_FIELD_NAME', 'csrf_token')
        app.config.setdefault('JWT_REFRESH_CSRF_FIELD_NAME', 'csrf_token')

        # How long a token will live before it expires.
        app.config.setdefault('JWT_ACCESS_TOKEN_EXPIRES', datetime.timedelta(minutes=15))
        app.config.setdefault('JWT_REFRESH_TOKEN_EXPIRES', datetime.timedelta(days=30))

        # What algorithm to use to sign the token. See here for a list of options:
        # https://github.com/jpadilla/pyjwt/blob/master/jwt/api_jwt.py
        app.config.setdefault('JWT_ALGORITHM', 'HS256')

        # What algorithms are allowed to decode a token
        app.config.setdefault('JWT_DECODE_ALGORITHMS', None)

        # Secret key to sign JWTs with. Only used if a symmetric algorithm is
        # used (such as the HS* algorithms). Presumably the app's SECRET_KEY
        # is used instead when this is None; that fallback lives outside this
        # method (config module) -- confirm there.
        app.config.setdefault('JWT_SECRET_KEY', None)

        # Keys to sign and verify JWTs with when using an asymmetric
        # (public/private key) algorithm, such as RS* or EC*
        app.config.setdefault('JWT_PRIVATE_KEY', None)
        app.config.setdefault('JWT_PUBLIC_KEY', None)

        # Options for blacklisting/revoking tokens
        app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
        app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ('access', 'refresh'))

        app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
        app.config.setdefault('JWT_USER_CLAIMS', 'user_claims')
        app.config.setdefault('JWT_DECODE_AUDIENCE', None)
        app.config.setdefault('JWT_ENCODE_ISSUER', None)
        app.config.setdefault('JWT_DECODE_ISSUER', None)
        app.config.setdefault('JWT_DECODE_LEEWAY', 0)

        app.config.setdefault('JWT_CLAIMS_IN_REFRESH_TOKEN', False)

        app.config.setdefault('JWT_ERROR_MESSAGE_KEY', 'msg')

    def user_claims_loader(self, callback):
        """
        This decorator sets the callback function for adding custom claims to an
        access token when :func:`~flask_jwt_extended.create_access_token` is
        called. By default, no extra user claims will be added to the JWT.

        *HINT*: The callback function must be a function that takes only **one** argument,
        which is the object passed into
        :func:`~flask_jwt_extended.create_access_token`, and returns the custom
        claims you want included in the access tokens. The returned claims
        must be *JSON serializable*.
        """
        self._user_claims_callback = callback
        return callback

    def user_identity_loader(self, callback):
        """
        This decorator sets the callback function for getting the JSON
        serializable identity out of whatever object is passed into
        :func:`~flask_jwt_extended.create_access_token` and
        :func:`~flask_jwt_extended.create_refresh_token`. By default, this will
        return the unmodified object that is passed in as the `identity` kwarg
        to the above functions.

        *HINT*: The callback function must be a function that takes only **one** argument,
        which is the object passed into
        :func:`~flask_jwt_extended.create_access_token` or
        :func:`~flask_jwt_extended.create_refresh_token`, and returns the
        *JSON serializable* identity of this token.
        """
        self._user_identity_callback = callback
        return callback

    def expired_token_loader(self, callback):
        """
        This decorator sets the callback function that will be called if an
        expired JWT attempts to access a protected endpoint. The default
        implementation will return a 401 status code with the JSON:

        {"msg": "Token has expired"}

        *HINT*: The callback must be a function that takes **one** argument,
        which is a dictionary containing the data for the expired token,
        and returns a *Flask response*. (A callback taking no arguments is
        still accepted, but emits a DeprecationWarning.)
        """
        self._expired_token_callback = callback
        return callback

    def invalid_token_loader(self, callback):
        """
        This decorator sets the callback function that will be called if an
        invalid JWT attempts to access a protected endpoint. The default
        implementation will return a 422 status code with the JSON:

        {"msg": "<error description>"}

        *HINT*: The callback must be a function that takes only **one** argument, which is
        a string which contains the reason why a token is invalid, and returns
        a *Flask response*.
        """
        self._invalid_token_callback = callback
        return callback

    def unauthorized_loader(self, callback):
        """
        This decorator sets the callback function that will be called if
        no JWT can be found when attempting to access a protected endpoint.
        It is also called when CSRF protection fails.
        The default implementation will return a 401 status code with the JSON:

        {"msg": "<error description>"}

        *HINT*: The callback must be a function that takes only **one** argument, which is
        a string which contains the reason why a JWT could not be found, and
        returns a *Flask response*.
        """
        self._unauthorized_callback = callback
        return callback

    def needs_fresh_token_loader(self, callback):
        """
        This decorator sets the callback function that will be called if a
        valid and non-fresh token attempts to access an endpoint protected with
        the :func:`~flask_jwt_extended.fresh_jwt_required` decorator. The
        default implementation will return a 401 status code with the JSON:

        {"msg": "Fresh token required"}

        *HINT*: The callback must be a function that takes **no** arguments, and returns
        a *Flask response*.
        """
        self._needs_fresh_token_callback = callback
        return callback

    def revoked_token_loader(self, callback):
        """
        This decorator sets the callback function that will be called if a
        revoked token attempts to access a protected endpoint. The default
        implementation will return a 401 status code with the JSON:

        {"msg": "Token has been revoked"}

        *HINT*: The callback must be a function that takes **no** arguments, and returns
        a *Flask response*.
        """
        self._revoked_token_callback = callback
        return callback

    def user_loader_callback_loader(self, callback):
        """
        This decorator sets the callback function that will be called to
        automatically load an object when a protected endpoint is accessed.
        By default this is not used.

        *HINT*: The callback must take **one** argument which is the identity of
        the JWT accessing the protected endpoint, and it must return any object
        (which can then be accessed via the
        :attr:`~flask_jwt_extended.current_user` LocalProxy in the protected
        endpoint), or `None` in the case of a user not being able to be loaded
        for any reason. If this callback function returns `None`, the callback
        set with :meth:`~flask_jwt_extended.JWTManager.user_loader_error_loader`
        will be called.
        """
        self._user_loader_callback = callback
        return callback

    def user_loader_error_loader(self, callback):
        """
        This decorator sets the callback function that will be called if `None`
        is returned from the
        :meth:`~flask_jwt_extended.JWTManager.user_loader_callback_loader`
        callback function. The default implementation will return
        a 401 status code with the JSON:

        {"msg": "Error loading the user <identity>"}

        *HINT*: The callback must be a function that takes **one** argument, which is the
        identity of the user who failed to load, and must return a *Flask response*.
        """
        self._user_loader_error_callback = callback
        return callback

    def token_in_blacklist_loader(self, callback):
        """
        This decorator sets the callback function that will be called when
        a protected endpoint is accessed and will check if the JWT has been
        revoked. By default, this callback is not used.

        *HINT*: The callback must be a function that takes **one** argument, which is the
        decoded JWT (python dictionary), and returns *`True`* if the token
        has been blacklisted (or is otherwise considered revoked), or *`False`*
        otherwise.
        """
        self._token_in_blacklist_callback = callback
        return callback

    def claims_verification_loader(self, callback):
        """
        This decorator sets the callback function that will be called when
        a protected endpoint is accessed, and will check if the custom claims
        in the JWT are valid. By default, this callback is not used. The
        error returned if the claims are invalid can be controlled via the
        :meth:`~flask_jwt_extended.JWTManager.claims_verification_failed_loader`
        decorator.

        *HINT*: This callback must be a function that takes **one** argument, which is the
        custom claims (python dict) present in the JWT, and returns *`True`* if the
        claims are valid, or *`False`* otherwise.
        """
        self._claims_verification_callback = callback
        return callback

    def claims_verification_failed_loader(self, callback):
        """
        This decorator sets the callback function that will be called if
        the :meth:`~flask_jwt_extended.JWTManager.claims_verification_loader`
        callback returns False, indicating that the user claims are not valid.
        The default implementation will return a 400 status code with the JSON:

        {"msg": "User claims verification failed"}

        *HINT*: This callback must be a function that takes **no** arguments, and returns
        a *Flask response*.
        """
        self._verify_claims_failed_callback = callback
        return callback

    def decode_key_loader(self, callback):
        """
        This decorator sets the callback function for getting the JWT decode key and
        can be used to dynamically choose the appropriate decode key based on token
        contents.

        The default implementation returns the decode key specified by
        `JWT_SECRET_KEY` or `JWT_PUBLIC_KEY`, depending on the signing algorithm.

        *HINT*: The callback function should be a function that takes
        **two** arguments, which are the unverified claims and headers of the jwt
        (dictionaries). The function must return the key used to verify the
        token (for asymmetric algorithms, a *string* with the public key in
        PEM format).
        """
        self._decode_key_callback = callback
        return callback

    def encode_key_loader(self, callback):
        """
        This decorator sets the callback function for getting the JWT encode key and
        can be used to dynamically choose the appropriate encode key based on the
        token identity.

        The default implementation returns the encode key specified by
        `JWT_SECRET_KEY` or `JWT_PRIVATE_KEY`, depending on the signing algorithm.

        *HINT*: The callback function must be a function that takes only **one**
        argument, which is the identity as passed into the create_access_token
        or create_refresh_token functions, and must return a *string* which is
        the encode key used to sign the token.
        """
        self._encode_key_callback = callback
        return callback

    def additional_headers_loader(self, callback):
        """
        This decorator sets the callback function for adding custom headers to
        the JWT header when :func:`~flask_jwt_extended.create_access_token` or
        :func:`~flask_jwt_extended.create_refresh_token` is called and no
        explicit ``headers`` are given. Independently of this callback, the
        encoded header carries the type of the token (JWT) and the signing
        algorithm being used, such as HMAC SHA256 or RSA.

        *HINT*: The callback function must be a function that takes only **one**
        argument, which is the identity object passed into
        :func:`~flask_jwt_extended.create_access_token` or
        :func:`~flask_jwt_extended.create_refresh_token`, and returns a
        dictionary of the additional headers you want included in the token.
        The returned headers must be *JSON serializable*.
        """
        self._jwt_additional_header_callback = callback
        return callback

    def _create_refresh_token(self, identity, expires_delta=None, user_claims=None,
                              headers=None):
        """
        Encode a refresh token for ``identity`` using this manager's callbacks.

        :param identity: The object passed to create_refresh_token; it is passed
            through the user identity callback before being encoded.
        :param expires_delta: Token lifetime; defaults to ``config.refresh_expires``.
        :param user_claims: Custom claims. When None, claims are generated by the
            user claims callback only if ``config.user_claims_in_refresh_token``
            is enabled; otherwise None is passed on.
        :param headers: Additional JWT headers; when None, the additional headers
            callback is called with ``identity``.
        :return: Whatever ``encode_refresh_token`` returns (the encoded token).
        """
        if expires_delta is None:
            expires_delta = config.refresh_expires

        if user_claims is None and config.user_claims_in_refresh_token:
            user_claims = self._user_claims_callback(identity)

        if headers is None:
            headers = self._jwt_additional_header_callback(identity)

        # NOTE(review): unlike _create_access_token, no ``issuer`` is passed,
        # so JWT_ENCODE_ISSUER is not applied to refresh tokens. Confirm whether
        # encode_refresh_token accepts ``issuer`` before adding it here.
        refresh_token = encode_refresh_token(
            identity=self._user_identity_callback(identity),
            secret=self._encode_key_callback(identity),
            algorithm=config.algorithm,
            expires_delta=expires_delta,
            user_claims=user_claims,
            csrf=config.csrf_protect,
            identity_claim_key=config.identity_claim_key,
            user_claims_key=config.user_claims_key,
            json_encoder=config.json_encoder,
            headers=headers
        )
        return refresh_token

    def _create_access_token(self, identity, fresh=False, expires_delta=None,
                             user_claims=None, headers=None):
        """
        Encode an access token for ``identity`` using this manager's callbacks.

        :param identity: The object passed to create_access_token; it is passed
            through the user identity callback before being encoded.
        :param fresh: Value forwarded as the token's ``fresh`` flag.
        :param expires_delta: Token lifetime; defaults to ``config.access_expires``.
        :param user_claims: Custom claims; when None, the user claims callback
            is called with ``identity``.
        :param headers: Additional JWT headers; when None, the additional headers
            callback is called with ``identity``.
        :return: Whatever ``encode_access_token`` returns (the encoded token).
        """
        if expires_delta is None:
            expires_delta = config.access_expires

        if user_claims is None:
            user_claims = self._user_claims_callback(identity)

        if headers is None:
            headers = self._jwt_additional_header_callback(identity)

        access_token = encode_access_token(
            identity=self._user_identity_callback(identity),
            secret=self._encode_key_callback(identity),
            algorithm=config.algorithm,
            expires_delta=expires_delta,
            fresh=fresh,
            user_claims=user_claims,
            csrf=config.csrf_protect,
            identity_claim_key=config.identity_claim_key,
            user_claims_key=config.user_claims_key,
            json_encoder=config.json_encoder,
            headers=headers,
            issuer=config.encode_issuer,
        )
        return access_token