Browse Source

Add new `oauth2_validators` hook (#1473)

* Add new oauth2_validators hook

* Rename raw_data to response_json
Rafał Pitoń 2 years ago
parent
commit
575aac6029

+ 2 - 0
misago/hooks.py

@@ -3,7 +3,9 @@ urlpatterns = []
 context_processors = []
 
 new_registrations_validators = []
+
 oauth2_user_data_filters = []
+oauth2_validators = []
 
 post_search_filters = []
 post_validators = []

+ 3 - 1
misago/oauth2/client.py

@@ -129,11 +129,13 @@ def get_user_data(request, access_token):
     except (ValueError, TypeError):
         raise exceptions.OAuth2UserDataJSONError()
 
-    return {
+    clean_data = {
         key: get_value_from_json(getattr(request.settings, setting), response_json)
         for key, setting in JSON_MAPPING.items()
     }
 
+    return clean_data, response_json
+
 
 def get_redirect_uri(request):
     return request.build_absolute_uri(reverse("misago:oauth2-complete"))

+ 55 - 46
misago/oauth2/tests/test_get_user_data.py

@@ -45,7 +45,7 @@ def test_user_data_is_returned_using_get_request_with_token_in_query_string(
     )
 
     with patch("requests.get", get_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         get_mock.assert_called_once_with(
             f"https://example.com/oauth2/user?atoken={ACCESS_TOKEN}",
@@ -82,7 +82,7 @@ def test_user_data_is_returned_using_get_request_with_token_in_header(mock_reque
     )
 
     with patch("requests.get", get_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         get_mock.assert_called_once_with(
             f"https://example.com/oauth2/user",
@@ -121,7 +121,7 @@ def test_user_data_is_returned_using_get_request_with_bearer_token_in_header(
     )
 
     with patch("requests.get", get_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         get_mock.assert_called_once_with(
             f"https://example.com/oauth2/user",
@@ -160,7 +160,7 @@ def test_user_data_is_returned_using_post_request_with_token_in_query_string(
     )
 
     with patch("requests.post", post_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         post_mock.assert_called_once_with(
             f"https://example.com/oauth2/user?atoken={ACCESS_TOKEN}",
@@ -197,7 +197,7 @@ def test_user_data_is_returned_using_post_request_with_token_in_header(mock_requ
     )
 
     with patch("requests.post", post_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         post_mock.assert_called_once_with(
             f"https://example.com/oauth2/user",
@@ -236,7 +236,7 @@ def test_user_data_is_returned_using_post_request_with_bearer_token_in_header(
     )
 
     with patch("requests.post", post_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         post_mock.assert_called_once_with(
             f"https://example.com/oauth2/user",
@@ -276,7 +276,7 @@ def test_user_data_is_returned_using_post_request_with_extra_headers(
     )
 
     with patch("requests.post", post_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         post_mock.assert_called_once_with(
             f"https://example.com/oauth2/user",
@@ -319,7 +319,7 @@ def test_user_data_request_with_token_in_url_respects_existing_querystring(
     )
 
     with patch("requests.get", get_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, user_data)
 
         get_mock.assert_called_once_with(
             f"https://example.com/oauth2/data?type=user&atoken={ACCESS_TOKEN}",
@@ -346,26 +346,28 @@ def test_user_data_json_values_are_mapped_to_result(mock_request):
         "avatar": "https://example.com/avatar.png",
     }
 
+    json_response = {
+        "id": user_data["id"],
+        "user": {
+            "profile": {
+                "name": user_data["name"],
+                "email": user_data["email"],
+                "avatar": user_data["avatar"],
+            }
+        },
+    }
+
     get_mock = Mock(
         return_value=Mock(
             status_code=200,
             json=Mock(
-                return_value={
-                    "id": user_data["id"],
-                    "user": {
-                        "profile": {
-                            "name": user_data["name"],
-                            "email": user_data["email"],
-                            "avatar": user_data["avatar"],
-                        }
-                    },
-                },
+                return_value=json_response,
             ),
         ),
     )
 
     with patch("requests.get", get_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, json_response)
 
         get_mock.assert_called_once_with(
             f"https://example.com/oauth2/data?type=user&atoken={ACCESS_TOKEN}",
@@ -392,31 +394,36 @@ def test_user_data_json_values_are_none_when_not_found(mock_request):
         "avatar": "https://example.com/avatar.png",
     }
 
+    json_response = {
+        "profile_id": user_data["id"],
+        "user": {
+            "data": {
+                "name": user_data["name"],
+                "email": user_data["email"],
+                "avatar": user_data["avatar"],
+            }
+        },
+    }
+
     get_mock = Mock(
         return_value=Mock(
             status_code=200,
             json=Mock(
-                return_value={
-                    "profile_id": user_data["id"],
-                    "user": {
-                        "data": {
-                            "name": user_data["name"],
-                            "email": user_data["email"],
-                            "avatar": user_data["avatar"],
-                        }
-                    },
-                },
+                return_value=json_response,
             ),
         ),
     )
 
     with patch("requests.get", get_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == {
-            "id": None,
-            "name": None,
-            "email": None,
-            "avatar": None,
-        }
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (
+            {
+                "id": None,
+                "name": None,
+                "email": None,
+                "avatar": None,
+            },
+            json_response,
+        )
 
         get_mock.assert_called_once_with(
             f"https://example.com/oauth2/data?type=user&atoken={ACCESS_TOKEN}",
@@ -443,26 +450,28 @@ def test_user_data_skips_avatar_if_path_is_not_set(mock_request):
         "avatar": None,
     }
 
+    json_response = {
+        "id": user_data["id"],
+        "user": {
+            "profile": {
+                "name": user_data["name"],
+                "email": user_data["email"],
+                "avatar": "https://example.com/avatar.png",
+            }
+        },
+    }
+
     get_mock = Mock(
         return_value=Mock(
             status_code=200,
             json=Mock(
-                return_value={
-                    "id": user_data["id"],
-                    "user": {
-                        "profile": {
-                            "name": user_data["name"],
-                            "email": user_data["email"],
-                            "avatar": "https://example.com/avatar.png",
-                        }
-                    },
-                },
+                return_value=json_response,
             ),
         ),
     )
 
     with patch("requests.get", get_mock):
-        assert get_user_data(mock_request, ACCESS_TOKEN) == user_data
+        assert get_user_data(mock_request, ACCESS_TOKEN) == (user_data, json_response)
 
         get_mock.assert_called_once_with(
             f"https://example.com/oauth2/data?type=user&atoken={ACCESS_TOKEN}",

+ 6 - 0
misago/oauth2/tests/test_user_creation_from_data.py

@@ -20,6 +20,7 @@ def test_activated_user_is_created_from_valid_data(db, dynamic_settings):
             "email": "user@example.com",
             "avatar": None,
         },
+        {},
     )
 
     assert created
@@ -45,6 +46,7 @@ def test_user_subject_is_created_from_valid_data(db, dynamic_settings):
             "email": "user@example.com",
             "avatar": None,
         },
+        {},
     )
 
     assert created
@@ -63,6 +65,7 @@ def test_user_is_created_with_avatar_from_valid_data(db, dynamic_settings):
             "email": "user@example.com",
             "avatar": "https://placekitten.com/600/500",
         },
+        {},
     )
 
     assert created
@@ -81,6 +84,7 @@ def test_user_is_created_with_admin_activation_from_valid_data(db, dynamic_setti
             "email": "user@example.com",
             "avatar": None,
         },
+        {},
     )
 
     assert created
@@ -109,6 +113,7 @@ def test_user_name_conflict_during_creation_from_valid_data_is_handled(
                     "email": "test@example.com",
                     "avatar": None,
                 },
+                {},
             )
 
     assert excinfo.value.error_list == ["This username is not available."]
@@ -126,6 +131,7 @@ def test_user_email_conflict_during_creation_from_valid_data_is_handled(
                 "email": user.email,
                 "avatar": None,
             },
+            {},
         )
 
     assert excinfo.value.error_list == ["This e-mail address is not available."]

+ 57 - 0
misago/oauth2/tests/test_user_data_validation.py

@@ -1,6 +1,7 @@
 from unittest.mock import Mock, patch
 
 import pytest
+from django.core.exceptions import ValidationError
 
 from ..exceptions import OAuth2UserDataValidationError
 from ..validation import validate_user_data
@@ -16,6 +17,7 @@ def test_new_user_valid_data_is_validated(db, dynamic_settings):
             "email": "user@example.com",
             "avatar": None,
         },
+        {},
     )
 
     assert valid_data == {
@@ -36,6 +38,7 @@ def test_existing_user_valid_data_is_validated(user, dynamic_settings):
             "email": user.email,
             "avatar": None,
         },
+        {},
     )
 
     assert valid_data == {
@@ -66,6 +69,7 @@ def test_error_was_raised_for_user_data_with_without_name(db, dynamic_settings):
                     "email": "user@example.com",
                     "avatar": None,
                 },
+                {},
             )
 
     assert excinfo.value.error_list == [
@@ -89,6 +93,7 @@ def test_error_was_raised_for_user_data_with_invalid_name(db, dynamic_settings):
                     "email": "user@example.com",
                     "avatar": None,
                 },
+                {},
             )
 
     assert excinfo.value.error_list == [
@@ -112,6 +117,7 @@ def test_error_was_raised_for_user_data_with_too_long_name(db, dynamic_settings)
                     "email": "user@example.com",
                     "avatar": None,
                 },
+                {},
             )
 
     assert excinfo.value.error_list == [
@@ -130,6 +136,7 @@ def test_error_was_raised_for_user_data_without_email(db, dynamic_settings):
                 "email": "",
                 "avatar": None,
             },
+            {},
         )
 
     assert excinfo.value.error_list == ["Enter a valid email address."]
@@ -146,6 +153,56 @@ def test_error_was_raised_for_user_data_with_invalid_email(db, dynamic_settings)
                 "email": "userexample.com",
                 "avatar": None,
             },
+            {},
         )
 
     assert excinfo.value.error_list == ["Enter a valid email address."]
+
+
+def custom_oauth2_validator(request, user, user_data, raw_data):
+    if "bad" in user_data["name"].lower():
+        raise ValidationError("Custom validation error!")
+
+
+def test_custom_oauth2_validator_passes_valid_data(db, dynamic_settings):
+    user_data = {
+        "id": "1234",
+        "name": "UserName",
+        "email": "user@example.com",
+        "avatar": None,
+    }
+
+    with patch(
+        "misago.oauth2.validation.oauth2_validators",
+        [custom_oauth2_validator],
+    ):
+        assert (
+            validate_user_data(
+                Mock(settings=dynamic_settings),
+                None,
+                user_data,
+                {},
+            )
+            == user_data
+        )
+
+
+def test_custom_oauth2_validator_raises_error_for_invalid_data(db, dynamic_settings):
+    with pytest.raises(OAuth2UserDataValidationError) as excinfo:
+        with patch(
+            "misago.oauth2.validation.oauth2_validators",
+            [custom_oauth2_validator],
+        ):
+            validate_user_data(
+                Mock(settings=dynamic_settings),
+                None,
+                {
+                    "id": "1234",
+                    "name": "UserNameBad",
+                    "email": "user@example.com",
+                    "avatar": None,
+                },
+                {},
+            )
+
+    assert excinfo.value.error_list == ["Custom validation error!"]

+ 4 - 0
misago/oauth2/tests/test_user_update_with_data.py

@@ -21,6 +21,7 @@ def test_user_is_updated_with_valid_data(user, dynamic_settings):
             "email": "updated@example.com",
             "avatar": None,
         },
+        {},
     )
 
     assert created is False
@@ -51,6 +52,7 @@ def test_user_is_not_updated_with_unchanged_valid_data(user, dynamic_settings):
             "email": user.email,
             "avatar": None,
         },
+        {},
     )
 
     assert created is False
@@ -93,6 +95,7 @@ def test_user_name_conflict_during_update_with_valid_data_is_handled(
                     "email": "test@example.com",
                     "avatar": None,
                 },
+                {},
             )
 
     assert excinfo.value.error_list == ["This username is not available."]
@@ -112,6 +115,7 @@ def test_user_email_conflict_during_update_with_valid_data_is_handled(
                 "email": other_user.email,
                 "avatar": None,
             },
+            {},
         )
 
     assert excinfo.value.error_list == ["This e-mail address is not available."]

+ 2 - 2
misago/oauth2/user.py

@@ -11,7 +11,7 @@ from .validation import validate_user_data
 User = get_user_model()
 
 
-def get_user_from_data(request, user_data):
+def get_user_from_data(request, user_data, raw_data):
     if not user_data["id"]:
         raise OAuth2UserIdNotProvidedError()
 
@@ -22,7 +22,7 @@ def get_user_from_data(request, user_data):
 
     created = not bool(user)
 
-    cleaned_data = validate_user_data(request, user, user_data)
+    cleaned_data = validate_user_data(request, user, user_data, raw_data)
 
     try:
         with transaction.atomic():

+ 4 - 15
misago/oauth2/validation.py

@@ -6,9 +6,8 @@ from django.forms import ValidationError
 from django.utils.crypto import get_random_string
 from unidecode import unidecode
 
-from ..hooks import oauth2_user_data_filters
+from ..hooks import oauth2_validators, oauth2_user_data_filters
 from ..users.validators import (
-    validate_new_registration,
     validate_username_content,
     validate_username_length,
 )
@@ -22,7 +21,7 @@ class UsernameSettings:
     username_length_min: int = 1
 
 
-def validate_user_data(request, user, user_data):
+def validate_user_data(request, user, user_data, response_json):
     filtered_data = filter_user_data(request, user, user_data)
 
     try:
@@ -30,21 +29,11 @@ def validate_user_data(request, user, user_data):
         validate_username_length(UsernameSettings, filtered_data["name"])
         validate_email(filtered_data["email"])
 
-        errors_list = []
-
-        def add_error(_field_unused: str | None, error: str | ValidationError):
-            if isinstance(error, ValidationError):
-                error = error.message
-
-            errors_list.append(str(error))
-
-        validate_new_registration(request, filtered_data, add_error)
+        for plugin_oauth2_validator in oauth2_validators:
+            plugin_oauth2_validator(request, user, user_data, response_json)
     except ValidationError as exc:
         raise OAuth2UserDataValidationError(error_list=[str(exc.message)])
 
-    if errors_list:
-        raise OAuth2UserDataValidationError(error_list=errors_list)
-
     return filtered_data
 
 

+ 2 - 2
misago/oauth2/views.py

@@ -54,8 +54,8 @@ def oauth2_complete(request):
     try:
         code_grant = get_code_grant(request)
         token = get_access_token(request, code_grant)
-        user_data = get_user_data(request, token)
-        user, created = get_user_from_data(request, user_data)
+        user_data, raw_data = get_user_data(request, token)
+        user, created = get_user_from_data(request, user_data, raw_data)
 
         if not user.is_active:
             raise OAuth2UserAccountDeactivatedError()