Browse Source

Replace itsdangerous' JWT with PyJWT

Peter Justin 3 years ago
parent
commit
1ad94a51ad
4 changed files with 39 additions and 25 deletions
  1. 24 18
      flaskbb/tokens/serializer.py
  2. 1 0
      requirements.txt
  3. 1 0
      setup.py
  4. 13 7
      tests/unit/tokens/test_serializer.py

+ 24 - 18
flaskbb/tokens/serializer.py

@@ -7,10 +7,9 @@
     :license: BSD, see LICENSE for more details
     :license: BSD, see LICENSE for more details
 """
 """
 
 
-from datetime import timedelta
+from datetime import datetime, timedelta
 
 
-from itsdangerous import (BadData, BadSignature, SignatureExpired,
-                          TimedJSONWebSignatureSerializer)
+import jwt
 
 
 from ..core import tokens
 from ..core import tokens
 
 
@@ -37,9 +36,15 @@ class FlaskBBTokenSerializer(tokens.TokenSerializer):
     """
     """
 
 
     def __init__(self, secret_key, expiry=_DEFAULT_EXPIRY):
     def __init__(self, secret_key, expiry=_DEFAULT_EXPIRY):
-        self._serializer = TimedJSONWebSignatureSerializer(
-            secret_key, int(expiry.total_seconds())
-        )
+        self.secret_key = secret_key
+        self.algorithm = "HS256"
+
+        if isinstance(expiry, timedelta):
+            self.expiry = datetime.utcnow() + expiry
+        elif isinstance(expiry, datetime):
+            self.expiry = expiry
+        else:
+            raise TypeError("'expiry' must be of type timedelta or datetime")
 
 
     def dumps(self, token):
     def dumps(self, token):
         """
         """
@@ -49,11 +54,10 @@ class FlaskBBTokenSerializer(tokens.TokenSerializer):
         :flaskbb.core.tokens.Token token: Token to transformed into a JWT
         :flaskbb.core.tokens.Token token: Token to transformed into a JWT
         :returns str: A fully serialized token
         :returns str: A fully serialized token
         """
         """
-        return self._serializer.dumps(
-            {
-                'id': token.user_id,
-                'op': token.operation,
-            }
+        return jwt.encode(
+            payload={"id": token.user_id, "op": token.operation, "exp": self.expiry},
+            key=self.secret_key,
+            algorithm=self.algorithm,
         )
         )
 
 
     def loads(self, raw_token):
     def loads(self, raw_token):
@@ -69,16 +73,18 @@ class FlaskBBTokenSerializer(tokens.TokenSerializer):
         :returns flaskbb.core.tokens.Token: Parsed token
         :returns flaskbb.core.tokens.Token: Parsed token
         """
         """
         try:
         try:
-            parsed = self._serializer.loads(raw_token)
-        except SignatureExpired:
+            parsed = jwt.decode(
+                raw_token, key=self.secret_key, algorithms=[self.algorithm]
+            )
+        except jwt.ExpiredSignatureError:
             raise tokens.TokenError.expired()
             raise tokens.TokenError.expired()
-        except BadSignature:  # pragma: no branch
+        except jwt.DecodeError:  # pragma: no branch
             raise tokens.TokenError.invalid()
             raise tokens.TokenError.invalid()
-        # ideally we never end up here as BadSignature should
+        # ideally we never end up here as DecodeError should
         # catch everything else, however since this is the root
         # catch everything else, however since this is the root
-        # exception for itsdangerous we'll catch it down and
+        # exception for PyJWT we'll catch it down and
         # and re-raise our own
         # and re-raise our own
-        except BadData:  # pragma: no cover
+        except jwt.InvalidTokenError:  # pragma: no cover
             raise tokens.TokenError.bad()
             raise tokens.TokenError.bad()
         else:
         else:
-            return tokens.Token(user_id=parsed['id'], operation=parsed['op'])
+            return tokens.Token(user_id=parsed["id"], operation=parsed["op"])

+ 1 - 0
requirements.txt

@@ -44,6 +44,7 @@ Pillow==8.3.2
 pluggy==1.0.0
 pluggy==1.0.0
 prompt-toolkit==3.0.20
 prompt-toolkit==3.0.20
 Pygments==2.10.0
 Pygments==2.10.0
+PyJWT==2.1.0
 python-dateutil==2.8.2
 python-dateutil==2.8.2
 python-editor==1.0.4
 python-editor==1.0.4
 pytz==2021.1
 pytz==2021.1

+ 1 - 0
setup.py

@@ -51,6 +51,7 @@ install_requires = [
     "pluggy>=0.13.1",
     "pluggy>=0.13.1",
     "prompt-toolkit>=3.0.19",
     "prompt-toolkit>=3.0.19",
     "Pygments>=2.9.0",
     "Pygments>=2.9.0",
+    "PyJWT>=2.1.0",
     "python-dateutil>=2.8.2",
     "python-dateutil>=2.8.2",
     "python-editor>=1.0.4",
     "python-editor>=1.0.4",
     "pytz>=2021.1",
     "pytz>=2021.1",

+ 13 - 7
tests/unit/tokens/test_serializer.py

@@ -6,12 +6,12 @@ from freezegun import freeze_time
 from flaskbb import tokens
 from flaskbb import tokens
 from flaskbb.core.tokens import Token, TokenActions, TokenError
 from flaskbb.core.tokens import Token, TokenActions, TokenError
 
 
-pytestmark = pytest.mark.usefixtures('default_settings')
+pytestmark = pytest.mark.usefixtures("default_settings")
 
 
 
 
 def test_can_round_trip_token():
 def test_can_round_trip_token():
     serializer = tokens.FlaskBBTokenSerializer(
     serializer = tokens.FlaskBBTokenSerializer(
-        'hello i am secret', timedelta(seconds=100)
+        "hello i am secret", timedelta(seconds=100)
     )
     )
     token = Token(user_id=1, operation=TokenActions.RESET_PASSWORD)
     token = Token(user_id=1, operation=TokenActions.RESET_PASSWORD)
     roundtrip = serializer.loads(serializer.dumps(token))
     roundtrip = serializer.loads(serializer.dumps(token))
@@ -21,17 +21,17 @@ def test_can_round_trip_token():
 
 
 def test_raises_token_error_with_bad_data():
 def test_raises_token_error_with_bad_data():
     serializer = tokens.FlaskBBTokenSerializer(
     serializer = tokens.FlaskBBTokenSerializer(
-        'hello i am also secret', timedelta(seconds=100)
+        "hello i am also secret", timedelta(seconds=100)
     )
     )
 
 
     with pytest.raises(TokenError) as excinfo:
     with pytest.raises(TokenError) as excinfo:
-        serializer.loads('not actually a token')
-    assert 'invalid' in str(excinfo.value)
+        serializer.loads("not actually a token")
+    assert "invalid" in str(excinfo.value)
 
 
 
 
 def test_expired_token_raises():
 def test_expired_token_raises():
     serializer = tokens.FlaskBBTokenSerializer(
     serializer = tokens.FlaskBBTokenSerializer(
-        'i am a secret not', expiry=timedelta(seconds=1)
+        "i am a secret not", expiry=datetime.utcnow() + timedelta(seconds=1)
     )
     )
     dumped_token = serializer.dumps(
     dumped_token = serializer.dumps(
         Token(user_id=1, operation=TokenActions.RESET_PASSWORD)
         Token(user_id=1, operation=TokenActions.RESET_PASSWORD)
@@ -41,4 +41,10 @@ def test_expired_token_raises():
         with pytest.raises(TokenError) as excinfo:
         with pytest.raises(TokenError) as excinfo:
             serializer.loads(dumped_token)
             serializer.loads(dumped_token)
 
 
-    assert 'expired' in str(excinfo.value)
+    assert "expired" in str(excinfo.value)
+
+
+def test_raises_typeerror_expiry_args():
+    with pytest.raises(TypeError) as excinfo:
+        tokens.FlaskBBTokenSerializer("hello i am also secret", 100)
+        assert "timedelta or datetime" in excinfo.value