Browse Source

More registration hooks, interfaces, implementations

Alec Nikolas Reiter 7 years ago
parent
commit
ae8a907fa7

+ 54 - 30
flaskbb/auth/plugins.py

@@ -8,39 +8,33 @@
     :license: BSD, see LICENSE for more details
 """
 from flask import flash, redirect, url_for
-from flask_babelplus import gettext as _
-from flask_login import current_user, login_user, logout_user
-
+from flask_login import current_user, logout_user
+from ..extensions import db
 from . import impl
 from ..core.auth.authentication import ForceLogout
 from ..user.models import User
 from ..utils.settings import flaskbb_config
-from .services.authentication import (BlockUnactivatedUser, ClearFailedLogins,
-                                      DefaultFlaskBBAuthProvider,
-                                      MarkFailedLogin)
+from .services.authentication import (
+    BlockUnactivatedUser,
+    ClearFailedLogins,
+    DefaultFlaskBBAuthProvider,
+    MarkFailedLogin,
+)
 from .services.factories import account_activator_factory
-from .services.reauthentication import (ClearFailedLoginsOnReauth,
-                                        DefaultFlaskBBReauthProvider,
-                                        MarkFailedReauth)
-
-
-@impl
-def flaskbb_event_user_registered(username):
-    user = User.query.filter_by(username=username).first()
-
-    if flaskbb_config["ACTIVATE_ACCOUNT"]:
-        service = account_activator_factory()
-        service.initiate_account_activation(user.email)
-        flash(
-            _(
-                "An account activation email has been sent to "
-                "%(email)s",
-                email=user.email
-            ), "success"
-        )
-    else:
-        login_user(user)
-        flash(_("Thanks for registering."), "success")
+from .services.reauthentication import (
+    ClearFailedLoginsOnReauth,
+    DefaultFlaskBBReauthProvider,
+    MarkFailedReauth,
+)
+from .services.registration import (
+    AutoActivateUserPostProcessor,
+    AutologinPostProcessor,
+    EmailUniquenessValidator,
+    SendActivationPostProcessor,
+    UsernameRequirements,
+    UsernameUniquenessValidator,
+    UsernameValidator,
+)
 
 
 @impl(trylast=True)
@@ -83,5 +77,35 @@ def flaskbb_errorhandlers(app):
         if current_user:
             logout_user()
             if error.reason:
-                flash(error.reason, 'danger')
-        return redirect(url_for('forum.index'))
+                flash(error.reason, "danger")
+        return redirect(url_for("forum.index"))
+
+
+@impl
+def flaskbb_gather_registration_validators():
+    blacklist = [
+        w.strip() for w in flaskbb_config["AUTH_USERNAME_BLACKLIST"].split(",")
+    ]
+
+    requirements = UsernameRequirements(
+        min=flaskbb_config["AUTH_USERNAME_MIN_LENGTH"],
+        max=flaskbb_config["AUTH_USERNAME_MAX_LENGTH"],
+        blacklist=blacklist,
+    )
+
+    return [
+        EmailUniquenessValidator(User),
+        UsernameUniquenessValidator(User),
+        UsernameValidator(requirements),
+    ]
+
+
+@impl
+def flaskbb_registration_post_processor(user):
+    if flaskbb_config["ACTIVATE_ACCOUNT"]:
+        service = SendActivationPostProcessor(account_activator_factory())
+    else:
+        service = AutologinPostProcessor()
+
+    service.post_process(user)
+    AutoActivateUserPostProcessor(db, flaskbb_config).post_process(user)

+ 2 - 23
flaskbb/auth/services/factories.py

@@ -17,36 +17,15 @@ from ...extensions import db
 from ...tokens import FlaskBBTokenSerializer
 from ...tokens.verifiers import EmailMatchesUserToken
 from ...user.models import User
-from ...user.repo import UserRepository
-from ...utils.settings import flaskbb_config
 from .activation import AccountActivator
 from .authentication import PluginAuthenticationManager
 from .password import ResetPasswordService
 from .reauthentication import PluginReauthenticationManager
-from .registration import (EmailUniquenessValidator, RegistrationService,
-                           UsernameRequirements, UsernameUniquenessValidator,
-                           UsernameValidator)
+from .registration import RegistrationService
 
 
 def registration_service_factory():
-    blacklist = [
-        w.strip()
-        for w in flaskbb_config["AUTH_USERNAME_BLACKLIST"].split(",")
-    ]
-
-    requirements = UsernameRequirements(
-        min=flaskbb_config["AUTH_USERNAME_MIN_LENGTH"],
-        max=flaskbb_config["AUTH_USERNAME_MAX_LENGTH"],
-        blacklist=blacklist
-    )
-
-    validators = [
-        EmailUniquenessValidator(User),
-        UsernameUniquenessValidator(User),
-        UsernameValidator(requirements)
-    ]
-
-    return RegistrationService(validators, UserRepository(db))
+    return RegistrationService(current_app.pluggy, User, db)
 
 
 def reset_service_factory():

+ 125 - 25
flaskbb/auth/services/registration.py

@@ -9,16 +9,37 @@
     :license: BSD, see LICENSE for more details
 """
 
+from datetime import datetime
+from itertools import chain
+
 import attr
+from flask import flash
 from flask_babelplus import gettext as _
+from flask_login import login_user
+from pytz import UTC
 from sqlalchemy import func
 
-from ...core.auth.registration import UserRegistrationService, UserValidator
-from ...core.exceptions import StopValidation, ValidationError
+from ...core.auth.registration import (
+    RegistrationPostProcessor,
+    UserRegistrationService,
+    UserValidator,
+)
+from ...core.exceptions import (
+    PersistenceError,
+    StopValidation,
+    ValidationError,
+)
+from ...user.models import User
 
 __all__ = (
-    "UsernameRequirements", "UsernameValidator", "EmailUniquenessValidator",
-    "UsernameUniquenessValidator"
+    "AutoActivateUserPostProcessor",
+    "AutologinPostProcessor",
+    "EmailUniquenessValidator",
+    "RegistrationService",
+    "SendActivationPostProcessor",
+    "UsernameRequirements",
+    "UsernameUniquenessValidator",
+    "UsernameValidator",
 )
 
 
@@ -43,25 +64,28 @@ class UsernameValidator(UserValidator):
         self._requirements = requirements
 
     def validate(self, user_info):
-        if not (self._requirements.min <= len(user_info.username) <=
-                self._requirements.max):
+        if not (
+            self._requirements.min
+            <= len(user_info.username)
+            <= self._requirements.max
+        ):
             raise ValidationError(
-                'username',
+                "username",
                 _(
-                    'Username must be between %(min)s and %(max)s characters long',  # noqa
+                    "Username must be between %(min)s and %(max)s characters long",  # noqa
                     min=self._requirements.min,
-                    max=self._requirements.max
-                )
+                    max=self._requirements.max,
+                ),
             )
 
         is_blacklisted = user_info.username in self._requirements.blacklist
         if is_blacklisted:  # pragma: no branch
             raise ValidationError(
-                'username',
+                "username",
                 _(
-                    '%(username)s is a forbidden username',
-                    username=user_info.username
-                )
+                    "%(username)s is a forbidden username",
+                    username=user_info.username,
+                ),
             )
 
 
@@ -79,11 +103,11 @@ class UsernameUniquenessValidator(UserValidator):
         ).count()
         if count != 0:  # pragma: no branch
             raise ValidationError(
-                'username',
+                "username",
                 _(
-                    '%(username)s is already registered',
-                    username=user_info.username
-                )
+                    "%(username)s is already registered",
+                    username=user_info.username,
+                ),
             )
 
 
@@ -101,11 +125,50 @@ class EmailUniquenessValidator(UserValidator):
         ).count()
         if count != 0:  # pragma: no branch
             raise ValidationError(
-                'email',
-                _('%(email)s is already registered', email=user_info.email)
+                "email",
+                _("%(email)s is already registered", email=user_info.email),
             )
 
 
+class SendActivationPostProcessor(RegistrationPostProcessor):
+
+    def __init__(self, account_activator):
+        self.account_activator = account_activator
+
+    def post_process(self, user):
+        self.account_activator.initiate_account_activation(user.email)
+        flash(
+            _(
+                "An account activation email has been sent to " "%(email)s",
+                email=user.email,
+            ),
+            "success",
+        )
+
+
+class AutologinPostProcessor(RegistrationPostProcessor):
+
+    def post_process(self, user):
+        login_user(user)
+        flash(_("Thanks for registering."), "success")
+
+
+class AutoActivateUserPostProcessor(RegistrationPostProcessor):
+    """
+    Automatically marks the user as activated if activation isn't required
+    for the forum.
+    """
+
+    def __init__(self, db, config):
+        self.db = db
+        self.config = config
+
+    def post_process(self, user):
+        if not self.config['ACTIVATE_ACCOUNT']:
+            user.activated = True
+            self.db.session.commit()
+
+
 class RegistrationService(UserRegistrationService):
     """
     Default registration service for FlaskBB, runs the registration information
@@ -119,18 +182,55 @@ class RegistrationService(UserRegistrationService):
     reasons why the registration was prevented.
     """
 
-    def __init__(self, validators, user_repo):
-        self.validators = validators
-        self.user_repo = user_repo
+    def __init__(self, plugins, users, db):
+        self.plugins = plugins
+        self.users = users
+        self.db = db
 
     def register(self, user_info):
+        try:
+            self._validate_registration(user_info)
+        except StopValidation as e:
+            self._handle_failure(user_info, e.reasons)
+            raise
+
+        user = self._store_user(user_info)
+        self._post_process(user)
+        return user
+
+    def _validate_registration(self, user_info):
         failures = []
+        validators = self.plugins.hook.flaskbb_gather_registration_validators()
 
-        for v in self.validators:
+        for v in chain.from_iterable(validators):
             try:
                 v(user_info)
             except ValidationError as e:
                 failures.append((e.attribute, e.reason))
         if failures:
             raise StopValidation(failures)
-        self.user_repo.add(user_info)
+
+    def _handle_failure(self, user_info, failures):
+        self.plugins.hook.flaskbb_registration_failure_handler(
+            user_info=user_info, failures=failures
+        )
+
+    def _store_user(self, user_info):
+        try:
+            user = User(
+                username=user_info.username,
+                email=user_info.email,
+                password=user_info.password,
+                language=user_info.language,
+                primary_group_id=user_info.group,
+                date_joined=datetime.now(UTC),
+            )
+            self.db.session.add(user)
+            self.db.session.commit()
+            return user
+        except Exception:
+            self.db.session.rollback()
+            raise PersistenceError("Could not persist user")
+
+    def _post_process(self, user):
+        self.plugins.hook.flaskbb_registration_post_processor(user=user)

+ 5 - 10
flaskbb/auth/views.py

@@ -30,7 +30,7 @@ from flaskbb.utils.settings import flaskbb_config
 
 from ..core.auth.authentication import StopAuthentication
 from ..core.auth.registration import UserRegistrationInfo
-from ..core.exceptions import StopValidation, ValidationError
+from ..core.exceptions import StopValidation, ValidationError, PersistenceError
 from ..core.tokens import TokenError
 from .plugins import impl
 from .services import (account_activator_factory,
@@ -150,13 +150,8 @@ class Register(MethodView):
             except StopValidation as e:
                 form.populate_errors(e.reasons)
                 return render_template("auth/register.html", form=form)
-
-            else:
-                try:
-                    db.session.commit()
-                except Exception:  # noqa
-                    logger.exception("Database error while resetting password")
-                    db.session.rollback()
+            except PersistenceError:
+                    logger.exception("Database error while persisting user")
                     flash(
                         _(
                             "Could not process registration due"
@@ -189,8 +184,8 @@ class ForgotPassword(MethodView):
         if form.validate_on_submit():
 
             try:
-                self.password_reset_service_factory(
-                ).initiate_password_reset(form.email.data)
+                service = self.password_reset_service_factory()
+                service.initiate_password_reset(form.email.data)
             except ValidationError:
                 flash(
                     _(

+ 28 - 0
flaskbb/core/auth/registration.py

@@ -49,6 +49,34 @@ class UserValidator(ABC):
         return self.validate(user_info)
 
 
+class RegistrationFailureHandler(ABC):
+    """
+    Used to handle failures in the registration process.
+    """
+
+    @abstractmethod
+    def handle_failure(self, user_info, failures):
+        pass
+
+    def __call__(self, user_info, failures):
+        self.handle_failure(user_info, failures)
+
+
+class RegistrationPostProcessor(ABC):
+    """
+    Used to post proccess successful registrations by the time this
+    interface is called, the user has already been persisted into the
+    database.
+    """
+
+    @abstractmethod
+    def post_process(self, user):
+        pass
+
+    def __call__(self, user):
+        self.post_process(user)
+
+
 class UserRegistrationService(ABC):
     """
     Used to manage the registration process. A default implementation is

+ 14 - 0
flaskbb/core/exceptions.py

@@ -52,3 +52,17 @@ class StopValidation(BaseFlaskBBError):
     def __init__(self, reasons):
         self.reasons = reasons
         super(StopValidation, self).__init__(reasons)
+
+
+class PersistenceError(BaseFlaskBBError):
+    """
+    Used to catch down errors when persisting models to the database instead
+    of letting all issues percolate up, this should be raised from those
+    exceptions without smashing their tracebacks. Example::
+
+        try:
+            db.session.add(new_user)
+            db.session.commit()
+        except Exception:
+            raise PersistenceError("Couldn't save user account")
+    """

+ 52 - 0
flaskbb/plugins/spec.py

@@ -205,6 +205,58 @@ def flaskbb_event_user_registered(username):
     """
 
 
+@spec
+def flaskbb_gather_registration_validators():
+    """
+    Hook for gathering user registration validators, implementers must return
+    a callable that accepts a
+    :class:`~flaskbb.core.auth.registration.UserRegistrationInfo` and raises
+    a :class:`~flaskbb.core.exceptions.ValidationError` if the registration
+    is invalid or :class:`~flaskbb.core.exceptions.StopValidation` if
+    validation of the registration should end immediatey.
+
+    Example::
+
+        def cannot_be_named_fred(user_info):
+            if user_info.username.lower() == 'fred':
+                raise ValidationError(('username', 'Cannot name user fred'))
+
+        @impl
+        def flaskbb_gather_validate_user_registration():
+            return cannot_be_named_fred
+
+    .. note::
+
+        This is implemented as a hook that returns callables since the
+        callables are designed to raise exceptions.
+    """
+
+
+@spec
+def flaskbb_registration_failure_handler(user_info, failures):
+    """
+    Hook for dealing with user registration failures, receives the info
+    that user attempted to register with as well as the errors that failed
+    the registration.
+    """
+
+
+@spec
+def flaskbb_registration_post_processor(user):
+    """
+    Hook for handling actions after a user has successfully registered.
+
+    Example::
+
+        def greet_user(user):
+            flash(_("Thanks for registering {}".format(user.username)))
+
+        @impl
+        def flaskbb_registration_post_processor(user):
+            greet_user(user)
+    """
+
+
 @spec(firstresult=True)
 def flaskbb_authenticate(identifier, secret):
     """Hook for authenticating users in FlaskBB.

+ 4 - 1
tests/fixtures/plugin.py

@@ -1,7 +1,10 @@
 import pytest
 from flaskbb.plugins.manager import FlaskBBPluginManager
+from flaskbb.plugins import spec
 
 
 @pytest.fixture
 def plugin_manager():
-    return FlaskBBPluginManager("flaskbb")
+    pluggy = FlaskBBPluginManager("flaskbb")
+    pluggy.add_hookspecs(spec)
+    return pluggy

+ 0 - 2
tests/unit/auth/test_authentication.py

@@ -6,7 +6,6 @@ from flaskbb.core.auth.authentication import (AuthenticationFailureHandler,
                                               AuthenticationProvider,
                                               PostAuthenticationHandler,
                                               StopAuthentication)
-from flaskbb.plugins import spec
 from freezegun import freeze_time
 from pluggy import HookimplMarker
 from pytz import UTC
@@ -168,7 +167,6 @@ class TestPluginAuthenticationManager(object):
         db.session.rollback.assert_called_once_with()
 
     def _get_auth_manager(self, plugin_manager, db):
-        plugin_manager.add_hookspecs(spec)
         return auth.PluginAuthenticationManager(
             plugin_manager, session=db.session
         )

+ 0 - 2
tests/unit/auth/test_reauthentication.py

@@ -6,7 +6,6 @@ from flaskbb.core.auth.authentication import (PostReauthenticateHandler,
                                               ReauthenticateFailureHandler,
                                               ReauthenticateProvider,
                                               StopAuthentication)
-from flaskbb.plugins import spec
 from freezegun import freeze_time
 from pluggy import HookimplMarker
 from pytz import UTC
@@ -100,7 +99,6 @@ class TestPluginAuthenticationManager(object):
         db.session.rollback.assert_called_once_with()
 
     def _get_auth_manager(self, plugin_manager, db):
-        plugin_manager.add_hookspecs(spec)
         return reauth.PluginReauthenticationManager(
             plugin_manager, session=db.session
         )

+ 98 - 53
tests/unit/auth/test_registration.py

@@ -1,69 +1,114 @@
 import pytest
+from pluggy import HookimplMarker
 
-from flaskbb.auth.services import registration
-from flaskbb.core.auth.registration import UserRegistrationInfo
-from flaskbb.core.exceptions import StopValidation, ValidationError
-from flaskbb.core.user.repo import UserRepository
+from flaskbb.auth.services.registration import RegistrationService
+from flaskbb.core.auth.registration import (RegistrationFailureHandler,
+                                            RegistrationPostProcessor,
+                                            UserRegistrationInfo,
+                                            UserValidator)
+from flaskbb.core.exceptions import (PersistenceError, StopValidation,
+                                     ValidationError)
+from flaskbb.user.models import User
 
 pytestmark = pytest.mark.usefixtures('default_settings')
 
 
-class RaisingValidator(registration.UserValidator):
+class RaisingValidator(UserValidator):
 
     def validate(self, user_info):
-        raise ValidationError('test', 'just a little whoopsie-diddle')
-
-
-def test_doesnt_register_user_if_validator_fails_with_ValidationError(mocker):
-    repo = mocker.Mock(UserRepository)
-    service = registration.RegistrationService([RaisingValidator()], repo)
-
-    with pytest.raises(StopValidation):
-        service.register(
-            UserRegistrationInfo(
-                username='fred',
-                password='lol',
-                email='fred@fred.fred',
-                language='fredspeak',
-                group=4
-            )
-        )
+        raise ValidationError('username', 'nope')
+
+
+class TestRegistrationService(object):
+    fred = UserRegistrationInfo(
+        username='Fred',
+        password='Fred',
+        email='fred@fred.com',
+        language='fred',
+        group=4
+    )
 
-    repo.add.assert_not_called()
+    def test_raises_stop_validation_if_validators_fail(
+            self, plugin_manager, database
+    ):
+        service = self._get_service(plugin_manager, database)
+        plugin_manager.register(self.impls(validator=RaisingValidator()))
 
+        with pytest.raises(StopValidation) as excinfo:
+            service.register(self.fred)
 
-def test_gathers_up_all_errors_during_registration(mocker):
-    repo = mocker.Mock(UserRepository)
-    service = registration.RegistrationService([
-        RaisingValidator(), RaisingValidator()
-    ], repo)
+        assert ('username', 'nope') in excinfo.value.reasons
 
-    with pytest.raises(StopValidation) as excinfo:
-        service.register(
-            UserRegistrationInfo(
-                username='fred',
-                password='lol',
-                email='fred@fred.fred',
-                language='fredspeak',
-                group=4
-            )
+    def test_calls_failure_handlers_if_validation_fails(
+            self, plugin_manager, database, mocker
+    ):
+        service = self._get_service(plugin_manager, database)
+        failure = mocker.MagicMock(spec=RegistrationFailureHandler)
+        plugin_manager.register(
+            self.impls(validator=RaisingValidator(), failure=failure)
         )
 
-    repo.add.assert_not_called()
-    assert len(excinfo.value.reasons) == 2
-    assert all(('test', 'just a little whoopsie-diddle') == r
-               for r in excinfo.value.reasons)
+        with pytest.raises(StopValidation) as excinfo:
+            service.register(self.fred)
 
+        failure.assert_called_once_with(self.fred, excinfo.value.reasons)
 
-def test_registers_user_if_no_errors_occurs(mocker):
-    repo = mocker.Mock(UserRepository)
-    service = registration.RegistrationService([], repo)
-    user_info = UserRegistrationInfo(
-        username='fred',
-        password='lol',
-        email='fred@fred.fred',
-        language='fredspeak',
-        group=4
-    )
-    service.register(user_info)
-    repo.add.assert_called_with(user_info)
+    def test_registers_user_if_everything_is_good(
+            self, database, plugin_manager
+    ):
+        service = self._get_service(plugin_manager, database)
+
+        service.register(self.fred)
+
+        actual_fred = User.query.filter(User.username == 'Fred').one()
+
+        assert actual_fred.id is not None
+
+    def test_calls_post_processors_if_user_registration_works(
+            self, database, plugin_manager, mocker
+    ):
+        service = self._get_service(plugin_manager, database)
+        post_process = mocker.MagicMock(spec=RegistrationPostProcessor)
+        plugin_manager.register(self.impls(post_process=post_process))
+
+        fred = service.register(self.fred)
+
+        post_process.assert_called_once_with(fred)
+
+    def test_raises_persistenceerror_if_saving_user_goes_wrong(
+            self, database, plugin_manager, Fred
+    ):
+        service = self._get_service(plugin_manager, database)
+
+        with pytest.raises(PersistenceError):
+            service.register(self.fred)
+
+    @staticmethod
+    def _get_service(plugin_manager, db):
+        return RegistrationService(plugins=plugin_manager, users=User, db=db)
+
+    @staticmethod
+    def impls(validator=None, failure=None, post_process=None):
+        impl = HookimplMarker('flaskbb')
+
+        class Impls:
+            if validator is not None:
+
+                @impl
+                def flaskbb_gather_registration_validators(self):
+                    return [validator]
+
+            if failure is not None:
+
+                @impl
+                def flaskbb_registration_failure_handler(
+                        self, user_info, failures
+                ):
+                    failure(user_info, failures)
+
+            if post_process is not None:
+
+                @impl
+                def flaskbb_registration_post_processor(self, user):
+                    post_process(user)
+        return Impls()