Browse Source

Merge pull request #472 from justanr/polish-registration-hooks

Polish registration hooks
Alec Nikolas Reiter 7 years ago
parent
commit
bf6be2f587

+ 9 - 0
docs/development/api/registration.rst

@@ -19,6 +19,12 @@ Registration Interfaces
 .. autoclass:: UserRegistrationService
     :members:
 
+.. autoclass:: RegistrationFailureHandler
+    :members:
+
+.. autoclass:: RegistrationPostProcessor
+    :members:
+
 
 Registration Provided Implementations
 -------------------------------------
@@ -29,4 +35,7 @@ Registration Provided Implementations
 .. autoclass:: UsernameValidator
 .. autoclass:: UsernameUniquenessValidator
 .. autoclass:: EmailUniquenessValidator
+.. autoclass:: SendActivationPostProcessor
+.. autoclass:: AutologinPostProcessor
+.. autoclass:: AutoActivateUserPostProcessor
 .. autoclass:: RegistrationService

+ 15 - 0
docs/development/hooks/event.rst

@@ -5,11 +5,26 @@
 FlaskBB Event Hooks
 ===================
 
+Post and Topic Events
+---------------------
+
 .. autofunction:: flaskbb_event_post_save_before
 .. autofunction:: flaskbb_event_post_save_after
 .. autofunction:: flaskbb_event_topic_save_before
 .. autofunction:: flaskbb_event_topic_save_after
+
+Registration Events
+-------------------
+
 .. autofunction:: flaskbb_event_user_registered
+.. autofunction:: flaskbb_gather_registration_validators
+.. autofunction:: flaskbb_registration_post_processor
+.. autofunction:: flaskbb_registration_failure_handler
+
+
+Authentication Events
+---------------------
+
 .. autofunction:: flaskbb_authenticate
 .. autofunction:: flaskbb_post_authenticate
 .. autofunction:: flaskbb_authentication_failed

+ 65 - 31
flaskbb/auth/plugins.py

@@ -8,39 +8,34 @@
     :license: BSD, see LICENSE for more details
 """
 from flask import flash, redirect, url_for
-from flask_babelplus import gettext as _
-from flask_login import current_user, login_user, logout_user
+from flask_login import current_user, logout_user
 
 from . import impl
 from ..core.auth.authentication import ForceLogout
+from ..extensions import db
 from ..user.models import User
 from ..utils.settings import flaskbb_config
-from .services.authentication import (BlockUnactivatedUser, ClearFailedLogins,
-                                      DefaultFlaskBBAuthProvider,
-                                      MarkFailedLogin)
+from .services.authentication import (
+    BlockUnactivatedUser,
+    ClearFailedLogins,
+    DefaultFlaskBBAuthProvider,
+    MarkFailedLogin,
+)
 from .services.factories import account_activator_factory
-from .services.reauthentication import (ClearFailedLoginsOnReauth,
-                                        DefaultFlaskBBReauthProvider,
-                                        MarkFailedReauth)
-
-
-@impl
-def flaskbb_event_user_registered(username):
-    user = User.query.filter_by(username=username).first()
-
-    if flaskbb_config["ACTIVATE_ACCOUNT"]:
-        service = account_activator_factory()
-        service.initiate_account_activation(user.email)
-        flash(
-            _(
-                "An account activation email has been sent to "
-                "%(email)s",
-                email=user.email
-            ), "success"
-        )
-    else:
-        login_user(user)
-        flash(_("Thanks for registering."), "success")
+from .services.reauthentication import (
+    ClearFailedLoginsOnReauth,
+    DefaultFlaskBBReauthProvider,
+    MarkFailedReauth,
+)
+from .services.registration import (
+    AutoActivateUserPostProcessor,
+    AutologinPostProcessor,
+    EmailUniquenessValidator,
+    SendActivationPostProcessor,
+    UsernameRequirements,
+    UsernameUniquenessValidator,
+    UsernameValidator,
+)
 
 
 @impl(trylast=True)
@@ -50,9 +45,13 @@ def flaskbb_authenticate(identifier, secret):
 
 @impl(tryfirst=True)
 def flaskbb_post_authenticate(user):
-    ClearFailedLogins().handle_post_auth(user)
+    handlers = [ClearFailedLogins()]
+
     if flaskbb_config["ACTIVATE_ACCOUNT"]:
-        BlockUnactivatedUser().handle_post_auth(user)
+        handlers.append(BlockUnactivatedUser())
+
+    for handler in handlers:
+        handler.handle_post_auth(user)
 
 
 @impl
@@ -83,5 +82,40 @@ def flaskbb_errorhandlers(app):
         if current_user:
             logout_user()
             if error.reason:
-                flash(error.reason, 'danger')
-        return redirect(url_for('forum.index'))
+                flash(error.reason, "danger")
+        return redirect(url_for("forum.index"))
+
+
+@impl
+def flaskbb_gather_registration_validators():
+    blacklist = [
+        w.strip() for w in flaskbb_config["AUTH_USERNAME_BLACKLIST"].split(",")
+    ]
+
+    requirements = UsernameRequirements(
+        min=flaskbb_config["AUTH_USERNAME_MIN_LENGTH"],
+        max=flaskbb_config["AUTH_USERNAME_MAX_LENGTH"],
+        blacklist=blacklist,
+    )
+
+    return [
+        EmailUniquenessValidator(User),
+        UsernameUniquenessValidator(User),
+        UsernameValidator(requirements),
+    ]
+
+
+@impl
+def flaskbb_registration_post_processor(user):
+    handlers = []
+
+    if flaskbb_config["ACTIVATE_ACCOUNT"]:
+        handlers.append(
+            SendActivationPostProcessor(account_activator_factory())
+        )
+    else:
+        handlers.append(AutologinPostProcessor())
+        handlers.append(AutoActivateUserPostProcessor(db, flaskbb_config))
+
+    for handler in handlers:
+        handler.post_process(user)

+ 2 - 23
flaskbb/auth/services/factories.py

@@ -17,36 +17,15 @@ from ...extensions import db
 from ...tokens import FlaskBBTokenSerializer
 from ...tokens.verifiers import EmailMatchesUserToken
 from ...user.models import User
-from ...user.repo import UserRepository
-from ...utils.settings import flaskbb_config
 from .activation import AccountActivator
 from .authentication import PluginAuthenticationManager
 from .password import ResetPasswordService
 from .reauthentication import PluginReauthenticationManager
-from .registration import (EmailUniquenessValidator, RegistrationService,
-                           UsernameRequirements, UsernameUniquenessValidator,
-                           UsernameValidator)
+from .registration import RegistrationService
 
 
 def registration_service_factory():
-    blacklist = [
-        w.strip()
-        for w in flaskbb_config["AUTH_USERNAME_BLACKLIST"].split(",")
-    ]
-
-    requirements = UsernameRequirements(
-        min=flaskbb_config["AUTH_USERNAME_MIN_LENGTH"],
-        max=flaskbb_config["AUTH_USERNAME_MAX_LENGTH"],
-        blacklist=blacklist
-    )
-
-    validators = [
-        EmailUniquenessValidator(User),
-        UsernameUniquenessValidator(User),
-        UsernameValidator(requirements)
-    ]
-
-    return RegistrationService(validators, UserRepository(db))
+    return RegistrationService(current_app.pluggy, User, db)
 
 
 def reset_service_factory():

+ 137 - 25
flaskbb/auth/services/registration.py

@@ -9,16 +9,37 @@
     :license: BSD, see LICENSE for more details
 """
 
+from datetime import datetime
+from itertools import chain
+
 import attr
+from flask import flash
 from flask_babelplus import gettext as _
+from flask_login import login_user
+from pytz import UTC
 from sqlalchemy import func
 
-from ...core.auth.registration import UserRegistrationService, UserValidator
-from ...core.exceptions import StopValidation, ValidationError
+from ...core.auth.registration import (
+    RegistrationPostProcessor,
+    UserRegistrationService,
+    UserValidator,
+)
+from ...core.exceptions import (
+    PersistenceError,
+    StopValidation,
+    ValidationError,
+)
+from ...user.models import User
 
 __all__ = (
-    "UsernameRequirements", "UsernameValidator", "EmailUniquenessValidator",
-    "UsernameUniquenessValidator"
+    "AutoActivateUserPostProcessor",
+    "AutologinPostProcessor",
+    "EmailUniquenessValidator",
+    "RegistrationService",
+    "SendActivationPostProcessor",
+    "UsernameRequirements",
+    "UsernameUniquenessValidator",
+    "UsernameValidator",
 )
 
 
@@ -43,25 +64,28 @@ class UsernameValidator(UserValidator):
         self._requirements = requirements
 
     def validate(self, user_info):
-        if not (self._requirements.min <= len(user_info.username) <=
-                self._requirements.max):
+        if not (
+            self._requirements.min
+            <= len(user_info.username)
+            <= self._requirements.max
+        ):
             raise ValidationError(
-                'username',
+                "username",
                 _(
-                    'Username must be between %(min)s and %(max)s characters long',  # noqa
+                    "Username must be between %(min)s and %(max)s characters long",  # noqa
                     min=self._requirements.min,
-                    max=self._requirements.max
-                )
+                    max=self._requirements.max,
+                ),
             )
 
         is_blacklisted = user_info.username in self._requirements.blacklist
         if is_blacklisted:  # pragma: no branch
             raise ValidationError(
-                'username',
+                "username",
                 _(
-                    '%(username)s is a forbidden username',
-                    username=user_info.username
-                )
+                    "%(username)s is a forbidden username",
+                    username=user_info.username,
+                ),
             )
 
 
@@ -79,11 +103,11 @@ class UsernameUniquenessValidator(UserValidator):
         ).count()
         if count != 0:  # pragma: no branch
             raise ValidationError(
-                'username',
+                "username",
                 _(
-                    '%(username)s is already registered',
-                    username=user_info.username
-                )
+                    "%(username)s is already registered",
+                    username=user_info.username,
+                ),
             )
 
 
@@ -101,11 +125,62 @@ class EmailUniquenessValidator(UserValidator):
         ).count()
         if count != 0:  # pragma: no branch
             raise ValidationError(
-                'email',
-                _('%(email)s is already registered', email=user_info.email)
+                "email",
+                _("%(email)s is already registered", email=user_info.email),
             )
 
 
+class SendActivationPostProcessor(RegistrationPostProcessor):
+    """
+    Sends an activation request after registration
+
+    :param account_activator:
+    :type account_activator: :class:`~flaskbb.core.auth.activation.AccountActivator`
+    """  # noqa
+
+    def __init__(self, account_activator):
+        self.account_activator = account_activator
+
+    def post_process(self, user):
+        self.account_activator.initiate_account_activation(user.email)
+        flash(
+            _(
+                "An account activation email has been sent to %(email)s",
+                email=user.email,
+            ),
+            "success",
+        )
+
+
+class AutologinPostProcessor(RegistrationPostProcessor):
+    """
+    Automatically logs a user in after registration
+    """
+
+    def post_process(self, user):
+        login_user(user)
+        flash(_("Thanks for registering."), "success")
+
+
+class AutoActivateUserPostProcessor(RegistrationPostProcessor):
+    """
+    Automatically marks the user as activated if activation isn't required
+    for the forum.
+
+    :param db: Configured Flask-SQLAlchemy extension object
+    :param config: Current flaskbb configuration object
+    """
+
+    def __init__(self, db, config):
+        self.db = db
+        self.config = config
+
+    def post_process(self, user):
+        if not self.config['ACTIVATE_ACCOUNT']:
+            user.activated = True
+            self.db.session.commit()
+
+
 class RegistrationService(UserRegistrationService):
     """
     Default registration service for FlaskBB, runs the registration information
@@ -119,18 +194,55 @@ class RegistrationService(UserRegistrationService):
     reasons why the registration was prevented.
     """
 
-    def __init__(self, validators, user_repo):
-        self.validators = validators
-        self.user_repo = user_repo
+    def __init__(self, plugins, users, db):
+        self.plugins = plugins
+        self.users = users
+        self.db = db
 
     def register(self, user_info):
+        try:
+            self._validate_registration(user_info)
+        except StopValidation as e:
+            self._handle_failure(user_info, e.reasons)
+            raise
+
+        user = self._store_user(user_info)
+        self._post_process(user)
+        return user
+
+    def _validate_registration(self, user_info):
         failures = []
+        validators = self.plugins.hook.flaskbb_gather_registration_validators()
 
-        for v in self.validators:
+        for v in chain.from_iterable(validators):
             try:
                 v(user_info)
             except ValidationError as e:
                 failures.append((e.attribute, e.reason))
         if failures:
             raise StopValidation(failures)
-        self.user_repo.add(user_info)
+
+    def _handle_failure(self, user_info, failures):
+        self.plugins.hook.flaskbb_registration_failure_handler(
+            user_info=user_info, failures=failures
+        )
+
+    def _store_user(self, user_info):
+        try:
+            user = User(
+                username=user_info.username,
+                email=user_info.email,
+                password=user_info.password,
+                language=user_info.language,
+                primary_group_id=user_info.group,
+                date_joined=datetime.now(UTC),
+            )
+            self.db.session.add(user)
+            self.db.session.commit()
+            return user
+        except Exception:
+            self.db.session.rollback()
+            raise PersistenceError("Could not persist user")
+
+    def _post_process(self, user):
+        self.plugins.hook.flaskbb_registration_post_processor(user=user)

+ 6 - 11
flaskbb/auth/views.py

@@ -30,7 +30,7 @@ from flaskbb.utils.settings import flaskbb_config
 
 from ..core.auth.authentication import StopAuthentication
 from ..core.auth.registration import UserRegistrationInfo
-from ..core.exceptions import StopValidation, ValidationError
+from ..core.exceptions import StopValidation, ValidationError, PersistenceError
 from ..core.tokens import TokenError
 from .plugins import impl
 from .services import (account_activator_factory,
@@ -150,13 +150,8 @@ class Register(MethodView):
             except StopValidation as e:
                 form.populate_errors(e.reasons)
                 return render_template("auth/register.html", form=form)
-
-            else:
-                try:
-                    db.session.commit()
-                except Exception:  # noqa
-                    logger.exception("Database error while resetting password")
-                    db.session.rollback()
+            except PersistenceError:
+                    logger.exception("Database error while persisting user")
                     flash(
                         _(
                             "Could not process registration due"
@@ -189,8 +184,8 @@ class ForgotPassword(MethodView):
         if form.validate_on_submit():
 
             try:
-                self.password_reset_service_factory(
-                ).initiate_password_reset(form.email.data)
+                service = self.password_reset_service_factory()
+                service.initiate_password_reset(form.email.data)
             except ValidationError:
                 flash(
                     _(
@@ -279,7 +274,7 @@ class RequestActivationToken(MethodView):
                         "your email address."
                     ), "success"
                 )
-                return redirect(url_for("auth.activate_account"))
+                return redirect(url_for('forum.index'))
 
         return render_template(
             "auth/request_account_activation.html", form=form

+ 41 - 0
flaskbb/core/auth/registration.py

@@ -49,6 +49,47 @@ class UserValidator(ABC):
         return self.validate(user_info)
 
 
+class RegistrationFailureHandler(ABC):
+    """
+    Used to handle failures in the registration process.
+    """
+
+    @abstractmethod
+    def handle_failure(self, user_info, failures):
+        """
+        This method is abstract.
+
+        :param user_info: The provided registration information.
+        :param failures: Tuples of (attribute, message) from the failure
+        :type user_info: :class:`~flaskbb.core.auth.registration.UserRegistrationInfo`
+        """  # noqa
+        pass
+
+    def __call__(self, user_info, failures):
+        self.handle_failure(user_info, failures)
+
+
+class RegistrationPostProcessor(ABC):
+    """
+    Used to post proccess successful registrations by the time this
+    interface is called, the user has already been persisted into the
+    database.
+    """
+
+    @abstractmethod
+    def post_process(self, user):
+        """
+        This method is abstract.
+
+        :param user: The registered, persisted user.
+        :type user: :class:`~flaskbb.user.models.User`
+        """
+        pass
+
+    def __call__(self, user):
+        self.post_process(user)
+
+
 class UserRegistrationService(ABC):
     """
     Used to manage the registration process. A default implementation is

+ 14 - 0
flaskbb/core/exceptions.py

@@ -52,3 +52,17 @@ class StopValidation(BaseFlaskBBError):
     def __init__(self, reasons):
         self.reasons = reasons
         super(StopValidation, self).__init__(reasons)
+
+
+class PersistenceError(BaseFlaskBBError):
+    """
+    Used to catch down errors when persisting models to the database instead
+    of letting all issues percolate up, this should be raised from those
+    exceptions without smashing their tracebacks. Example::
+
+        try:
+            db.session.add(new_user)
+            db.session.commit()
+        except Exception:
+            raise PersistenceError("Couldn't save user account")
+    """

+ 0 - 0
flaskbb/core/user/__init__.py


+ 0 - 31
flaskbb/core/user/repo.py

@@ -1,31 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-    flaskbb.core.user.repo
-    ~~~~~~~~~~~~~~~~~~~~~~
-
-    This module provides an abstracted access to users stored in the database.
-
-    :copyright: (c) 2014-2018 the FlaskbBB Team.
-    :license: BSD, see LICENSE for more details
-"""
-
-from ..._compat import ABC
-from abc import abstractmethod
-
-
-class UserRepository(ABC):
-    @abstractmethod
-    def add(self, user_info):
-        pass
-
-    @abstractmethod
-    def find_by(self, **kwargs):
-        pass
-
-    @abstractmethod
-    def get(self, user_id):
-        pass
-
-    @abstractmethod
-    def find_one_by(self, **kwargs):
-        pass

+ 88 - 0
flaskbb/plugins/spec.py

@@ -197,14 +197,102 @@ def flaskbb_event_topic_save_after(topic, is_new):
     """
 
 
+# TODO(anr): When pluggy 1.0 is released, mark this spec deprecated
 @spec
 def flaskbb_event_user_registered(username):
     """Hook for handling events after a user is registered
 
+    .. warning::
+
+        This hook is deprecated in favor of
+        :func:`~flaskbb.plugins.spec.flaskbb_registration_post_processor`
+
     :param username: The username of the newly registered user.
     """
 
 
+@spec
+def flaskbb_gather_registration_validators():
+    """
+    Hook for gathering user registration validators, implementers must return
+    a callable that accepts a
+    :class:`~flaskbb.core.auth.registration.UserRegistrationInfo` and raises
+    a :class:`~flaskbb.core.exceptions.ValidationError` if the registration
+    is invalid or :class:`~flaskbb.core.exceptions.StopValidation` if
+    validation of the registration should end immediatey.
+
+    Example::
+
+        def cannot_be_named_fred(user_info):
+            if user_info.username.lower() == 'fred':
+                raise ValidationError(('username', 'Cannot name user fred'))
+
+        @impl
+        def flaskbb_gather_registration_validators():
+            return [cannot_be_named_fred]
+
+    .. note::
+
+        This is implemented as a hook that returns callables since the
+        callables are designed to raise exceptions that are aggregated to
+        form the failure message for the registration response.
+
+    See Also: :class:`~flaskbb.core.auth.registration.UserValidator`
+    """
+
+
+@spec
+def flaskbb_registration_failure_handler(user_info, failures):
+    """
+    Hook for dealing with user registration failures, receives the info
+    that user attempted to register with as well as the errors that failed
+    the registration.
+
+    Example::
+
+        from .utils import fuzz_username
+
+        def has_already_registered(failures):
+            return any(
+                attr = "username" and "already registered" in msg
+                for (attr, msg) in failures
+            )
+
+
+        def suggest_alternate_usernames(user_info, failures):
+            if has_already_registered(failures):
+                suggestions = fuzz_username(user_info.username)
+                failures.append(("username", "Try: {}".format(suggestions)))
+
+
+        @impl
+        def flaskbb_registration_failure_handler(user_info, failures):
+            suggest_alternate_usernames(user_info, failures)
+
+    See Also: :class:`~flaskbb.core.auth.registration.RegistrationFailureHandler`
+    """  # noqa
+
+
+@spec
+def flaskbb_registration_post_processor(user):
+    """
+    Hook for handling actions after a user has successfully registered. This
+    spec receives the user object after it has been successfully persisted
+    to the database.
+
+    Example::
+
+        def greet_user(user):
+            flash(_("Thanks for registering {}".format(user.username)))
+
+        @impl
+        def flaskbb_registration_post_processor(user):
+            greet_user(user)
+
+    See Also: :class:`~flaskbb.core.auth.registration.RegistrationPostProcessor`
+    """  # noqa
+
+
 @spec(firstresult=True)
 def flaskbb_authenticate(identifier, secret):
     """Hook for authenticating users in FlaskBB.

+ 4 - 0
flaskbb/templates/layout.html

@@ -119,10 +119,14 @@
                                     </a>
                                     <button type="button" class="btn btn-primary dropdown-toggle" data-toggle="dropdown"><span class="caret"></span></button>
                                     <ul class="dropdown-menu" role="menu">
+                                        {# MAYBE(anr): Move this into a hook?? #}
                                         {% if flaskbb_config["REGISTRATION_ENABLED"] %}
                                         <li><a href="{{ url_for('auth.register') }}"><span class="fa fa-user-plus fa-fw"></span> {% trans %}Register{% endtrans %}</a></li>
                                         {% endif %}
                                         <li><a href="{{ url_for('auth.forgot_password') }}"><span class="fa fa-undo fa-fw"></span> {% trans %}Reset Password{% endtrans %}</a></li>
+                                        {% if flaskbb_config["ACTIVATE_ACCOUNT"] %}
+                                        <li><a href="{{ url_for('auth.request_activation_token') }}"><span class="fa fa-fw fa-ticket"></span> {% trans %}Activate Account{% endtrans %}</a></li>
+                                        {% endif %}
                                     </ul>
                                 </div>
                             </li>

+ 0 - 32
flaskbb/user/repo.py

@@ -1,32 +0,0 @@
-from datetime import datetime
-
-from pytz import UTC
-
-from ..core.user.repo import UserRepository as BaseUserRepository
-from .models import User
-
-
-class UserRepository(BaseUserRepository):
-
-    def __init__(self, db):
-        self.db = db
-
-    def add(self, user_info):
-        user = User(
-            username=user_info.username,
-            email=user_info.email,
-            password=user_info.password,
-            language=user_info.language,
-            primary_group_id=user_info.group,
-            date_joined=datetime.now(UTC)
-        )
-        self.db.session.add(user)
-
-    def get(self, user_id):
-        return User.query.get(user_id)
-
-    def find_by(self, **kwargs):
-        return User.query.filter_by(**kwargs).all()
-
-    def find_one_by(self, **kwargs):
-        return User.query.filter_by(**kwargs).first()

+ 4 - 1
tests/fixtures/plugin.py

@@ -1,7 +1,10 @@
 import pytest
 from flaskbb.plugins.manager import FlaskBBPluginManager
+from flaskbb.plugins import spec
 
 
 @pytest.fixture
 def plugin_manager():
-    return FlaskBBPluginManager("flaskbb")
+    pluggy = FlaskBBPluginManager("flaskbb")
+    pluggy.add_hookspecs(spec)
+    return pluggy

+ 1 - 1
tests/fixtures/user.py

@@ -72,7 +72,7 @@ def unactivated_user(default_groups):
     """
     Creates an unactivated user in the default user group
     """
-    user = User(username='notactive', email='not@active.com',
+    user = User(username='notactive', email='notactive@example.com',
                 password='password', primary_group=default_groups[3],
                 activated=False)
     user.save()

+ 0 - 2
tests/unit/auth/test_authentication.py

@@ -6,7 +6,6 @@ from flaskbb.core.auth.authentication import (AuthenticationFailureHandler,
                                               AuthenticationProvider,
                                               PostAuthenticationHandler,
                                               StopAuthentication)
-from flaskbb.plugins import spec
 from freezegun import freeze_time
 from pluggy import HookimplMarker
 from pytz import UTC
@@ -168,7 +167,6 @@ class TestPluginAuthenticationManager(object):
         db.session.rollback.assert_called_once_with()
 
     def _get_auth_manager(self, plugin_manager, db):
-        plugin_manager.add_hookspecs(spec)
         return auth.PluginAuthenticationManager(
             plugin_manager, session=db.session
         )

+ 0 - 2
tests/unit/auth/test_reauthentication.py

@@ -6,7 +6,6 @@ from flaskbb.core.auth.authentication import (PostReauthenticateHandler,
                                               ReauthenticateFailureHandler,
                                               ReauthenticateProvider,
                                               StopAuthentication)
-from flaskbb.plugins import spec
 from freezegun import freeze_time
 from pluggy import HookimplMarker
 from pytz import UTC
@@ -100,7 +99,6 @@ class TestPluginAuthenticationManager(object):
         db.session.rollback.assert_called_once_with()
 
     def _get_auth_manager(self, plugin_manager, db):
-        plugin_manager.add_hookspecs(spec)
         return reauth.PluginReauthenticationManager(
             plugin_manager, session=db.session
         )

+ 107 - 56
tests/unit/auth/test_registration.py

@@ -1,69 +1,120 @@
 import pytest
+from pluggy import HookimplMarker
 
-from flaskbb.auth.services import registration
-from flaskbb.core.auth.registration import UserRegistrationInfo
-from flaskbb.core.exceptions import StopValidation, ValidationError
-from flaskbb.core.user.repo import UserRepository
+from flaskbb.auth.services.registration import RegistrationService
+from flaskbb.core.auth.registration import (
+    RegistrationFailureHandler,
+    RegistrationPostProcessor,
+    UserRegistrationInfo,
+    UserValidator,
+)
+from flaskbb.core.exceptions import (
+    PersistenceError,
+    StopValidation,
+    ValidationError,
+)
+from flaskbb.user.models import User
 
-pytestmark = pytest.mark.usefixtures('default_settings')
+pytestmark = pytest.mark.usefixtures("default_settings")
 
 
-class RaisingValidator(registration.UserValidator):
+class RaisingValidator(UserValidator):
 
     def validate(self, user_info):
-        raise ValidationError('test', 'just a little whoopsie-diddle')
-
-
-def test_doesnt_register_user_if_validator_fails_with_ValidationError(mocker):
-    repo = mocker.Mock(UserRepository)
-    service = registration.RegistrationService([RaisingValidator()], repo)
-
-    with pytest.raises(StopValidation):
-        service.register(
-            UserRegistrationInfo(
-                username='fred',
-                password='lol',
-                email='fred@fred.fred',
-                language='fredspeak',
-                group=4
-            )
-        )
+        raise ValidationError("username", "nope")
+
+
+class TestRegistrationService(object):
+    fred = UserRegistrationInfo(
+        username="Fred",
+        password="Fred",
+        email="fred@fred.com",
+        language="fred",
+        group=4,
+    )
 
-    repo.add.assert_not_called()
+    def test_raises_stop_validation_if_validators_fail(
+        self, plugin_manager, database
+    ):
+        service = self._get_service(plugin_manager, database)
+        plugin_manager.register(self.impls(validator=RaisingValidator()))
 
+        with pytest.raises(StopValidation) as excinfo:
+            service.register(self.fred)
 
-def test_gathers_up_all_errors_during_registration(mocker):
-    repo = mocker.Mock(UserRepository)
-    service = registration.RegistrationService([
-        RaisingValidator(), RaisingValidator()
-    ], repo)
+        assert ("username", "nope") in excinfo.value.reasons
 
-    with pytest.raises(StopValidation) as excinfo:
-        service.register(
-            UserRegistrationInfo(
-                username='fred',
-                password='lol',
-                email='fred@fred.fred',
-                language='fredspeak',
-                group=4
-            )
+    def test_calls_failure_handlers_if_validation_fails(
+        self, plugin_manager, database, mocker
+    ):
+        service = self._get_service(plugin_manager, database)
+        failure = mocker.MagicMock(spec=RegistrationFailureHandler)
+        plugin_manager.register(
+            self.impls(validator=RaisingValidator(), failure=failure)
         )
 
-    repo.add.assert_not_called()
-    assert len(excinfo.value.reasons) == 2
-    assert all(('test', 'just a little whoopsie-diddle') == r
-               for r in excinfo.value.reasons)
-
-
-def test_registers_user_if_no_errors_occurs(mocker):
-    repo = mocker.Mock(UserRepository)
-    service = registration.RegistrationService([], repo)
-    user_info = UserRegistrationInfo(
-        username='fred',
-        password='lol',
-        email='fred@fred.fred',
-        language='fredspeak',
-        group=4
-    )
-    service.register(user_info)
-    repo.add.assert_called_with(user_info)
+        with pytest.raises(StopValidation) as excinfo:
+            service.register(self.fred)
+
+        failure.assert_called_once_with(self.fred, excinfo.value.reasons)
+
+    def test_registers_user_if_everything_is_good(
+        self, database, plugin_manager
+    ):
+        service = self._get_service(plugin_manager, database)
+
+        service.register(self.fred)
+
+        actual_fred = User.query.filter(User.username == "Fred").one()
+
+        assert actual_fred.id is not None
+
+    def test_calls_post_processors_if_user_registration_works(
+        self, database, plugin_manager, mocker
+    ):
+        service = self._get_service(plugin_manager, database)
+        post_process = mocker.MagicMock(spec=RegistrationPostProcessor)
+        plugin_manager.register(self.impls(post_process=post_process))
+
+        fred = service.register(self.fred)
+
+        post_process.assert_called_once_with(fred)
+
+    def test_raises_persistenceerror_if_saving_user_goes_wrong(
+        self, database, plugin_manager, Fred
+    ):
+        service = self._get_service(plugin_manager, database)
+
+        with pytest.raises(PersistenceError):
+            service.register(self.fred)
+
+    @staticmethod
+    def _get_service(plugin_manager, db):
+        return RegistrationService(plugins=plugin_manager, users=User, db=db)
+
+    @staticmethod
+    def impls(validator=None, failure=None, post_process=None):
+        impl = HookimplMarker("flaskbb")
+
+        class Impls:
+            if validator is not None:
+
+                @impl
+                def flaskbb_gather_registration_validators(self):
+                    return [validator]
+
+            if failure is not None:
+
+                @impl
+                def flaskbb_registration_failure_handler(
+                    self, user_info, failures
+                ):
+                    failure(user_info, failures)
+
+            if post_process is not None:
+
+                @impl
+                def flaskbb_registration_post_processor(self, user):
+                    post_process(user)
+
+        return Impls()

+ 83 - 0
tests/unit/auth/test_registration_processors.py

@@ -0,0 +1,83 @@
+from flask import get_flashed_messages
+from flask_login import current_user
+
+from flaskbb.auth.services.registration import (
+    AutoActivateUserPostProcessor,
+    AutologinPostProcessor,
+    SendActivationPostProcessor,
+)
+from flaskbb.core.auth.activation import AccountActivator
+from flaskbb.utils.settings import flaskbb_config
+
+
+class TestAutoActivateUserPostProcessor(object):
+
+    def test_activates_when_user_activation_isnt_required(
+        self, unactivated_user, database
+    ):
+        config = {"ACTIVATE_ACCOUNT": False}
+        processor = AutoActivateUserPostProcessor(database, config)
+        processor.post_process(unactivated_user)
+
+        assert unactivated_user.activated
+
+    def test_doesnt_activate_when_user_activation_is_required(
+        self, database, unactivated_user
+    ):
+        config = {"ACTIVATE_ACCOUNT": True}
+        processor = AutoActivateUserPostProcessor(database, config)
+        processor.post_process(unactivated_user)
+
+        assert not unactivated_user.activated
+
+
+class TestAutologinPostProcessor(object):
+
+    def test_sets_user_as_current_user(
+        self, Fred, request_context, default_settings
+    ):
+        flaskbb_config["ACTIVATE_ACCOUNT"] = False
+        processor = AutologinPostProcessor()
+
+        processor.post_process(Fred)
+
+        expected_message = ("success", "Thanks for registering.")
+
+        assert current_user.username == Fred.username
+        assert (
+            get_flashed_messages(with_categories=True)[0] == expected_message
+        )
+
+
+class TestSendActivationPostProcessor(object):
+
+    class SpyingActivator(AccountActivator):
+
+        def __init__(self):
+            self.called = False
+            self.user = None
+
+        def initiate_account_activation(self, user):
+            self.called = True
+            self.user = user
+
+        def activate_account(self, token):
+            pass
+
+    def test_sends_activation_notice(
+        self, request_context, unactivated_user, default_settings
+    ):
+        activator = self.SpyingActivator()
+        processor = SendActivationPostProcessor(activator)
+
+        processor.post_process(unactivated_user)
+
+        expected_message = (
+            "success",
+            "An account activation email has been sent to notactive@example.com",  # noqa
+        )
+        assert activator.called
+        assert activator.user == unactivated_user.email
+        assert (
+            get_flashed_messages(with_categories=True)[0] == expected_message
+        )