Browse Source

Extend redirect_or_next to also take into consideration the referrer when building the redirect url

Peter Justin 4 years ago
parent
commit
e184522b3a
2 changed files with 71 additions and 43 deletions
  1. 9 8
      flaskbb/management/views.py
  2. 62 35
      flaskbb/utils/helpers.py

+ 9 - 8
flaskbb/management/views.py

@@ -34,8 +34,8 @@ from flaskbb.plugins.utils import validate_plugin
 from flaskbb.user.models import Group, Guest, User
 from flaskbb.utils.forms import populate_settings_dict, populate_settings_form
 from flaskbb.utils.helpers import (get_online_users, register_view,
-                                   render_template, time_diff, time_utcnow,
-                                   FlashAndRedirect)
+                                   render_template, redirect_or_next,
+                                   time_diff, time_utcnow, FlashAndRedirect)
 from flaskbb.utils.requirements import (CanBanUser, CanEditUser, IsAdmin,
                                         IsAtleastModerator,
                                         IsAtleastSuperModerator)
@@ -478,7 +478,8 @@ class BanUser(MethodView):
             flash(_("User is now banned."), "success")
         else:
             flash(_("Could not ban user."), "danger")
-        return redirect(url_for("management.banned_users"))
+
+        return redirect_or_next(url_for("management.banned_users"))
 
 
 class UnbanUser(MethodView):
@@ -541,7 +542,7 @@ class UnbanUser(MethodView):
         else:
             flash(_("Could not unban user."), "danger")
 
-        return redirect(url_for("management.banned_users"))
+        return redirect_or_next(url_for("management.users"))
 
 
 class Groups(MethodView):
@@ -1040,13 +1041,13 @@ class MarkReportRead(MethodView):
                     _("Report %(id)s is already marked as read.", id=report.id),
                     "success"
                 )
-                return redirect(url_for("management.reports"))
+                return redirect_or_next(url_for("management.reports"))
 
             report.zapped_by = current_user.id
             report.zapped = time_utcnow()
             report.save()
             flash(_("Report %(id)s marked as read.", id=report.id), "success")
-            return redirect(url_for("management.reports"))
+            return redirect_or_next(url_for("management.reports"))
 
         # mark all as read
         reports = Report.query.filter(Report.zapped == None).all()
@@ -1060,7 +1061,7 @@ class MarkReportRead(MethodView):
         db.session.commit()
 
         flash(_("All reports were marked as read."), "success")
-        return redirect(url_for("management.reports"))
+        return redirect_or_next(url_for("management.reports"))
 
 
 class DeleteReport(MethodView):
@@ -1108,7 +1109,7 @@ class DeleteReport(MethodView):
         report = Report.query.filter_by(id=report_id).first_or_404()
         report.delete()
         flash(_("Report deleted."), "success")
-        return redirect(url_for("management.reports"))
+        return redirect_or_next(url_for("management.reports"))
 
 
 class CeleryStatus(MethodView):

+ 62 - 35
flaskbb/utils/helpers.py

@@ -19,6 +19,7 @@ import warnings
 from datetime import datetime, timedelta
 from email import message_from_string
 from functools import wraps
+from urllib.parse import urlparse, urljoin
 
 import pkg_resources
 import requests
@@ -63,10 +64,10 @@ def to_unicode(input_bytes, encoding="utf-8"):
 
 def slugify(text, delim=u"-"):
     """Generates an slightly worse ASCII-only slug.
-    Taken from the Flask Snippets page.
+     Taken from the Flask Snippets page.
 
-   :param text: The text which should be slugified
-   :param delim: Default "-". The delimeter for whitespace
+    :param text: The text which should be slugified
+    :param delim: Default "-". The delimeter for whitespace
     """
     text = unidecode.unidecode(text)
     result = []
@@ -76,13 +77,40 @@ def slugify(text, delim=u"-"):
     return str(delim.join(result))
 
 
-def redirect_or_next(endpoint, **kwargs):
+def is_safe_url(target):
+    """Check if target will lead to the same server
+    Ref: https://web.archive.org/web/20190128010142/http://flask.pocoo.org/snippets/62/
+
+    :param target: The redirect target
+    """
+    ref_url = urlparse(request.host_url)
+    test_url = urlparse(urljoin(request.host_url, target))
+    return (
+        test_url.scheme in ("http", "https")
+        and ref_url.netloc == test_url.netloc
+    )
+
+
+def redirect_url(endpoint, use_referrer=True):
+    """Generates a redirect url based on the referrer or endpoint."""
+    targets = [endpoint]
+    if use_referrer:
+        targets.insert(0, request.referrer)
+    for target in targets:
+        if target and is_safe_url(target):
+            return target
+
+
+def redirect_or_next(endpoint, use_referrer=True):
     """Redirects the user back to the page they were viewing or to a specified
     endpoint. Wraps Flasks :func:`Flask.redirect` function.
 
     :param endpoint: The fallback endpoint.
     """
-    return redirect(request.args.get("next") or endpoint, **kwargs)
+    return redirect(
+        request.args.get("next")
+        or redirect_url(endpoint, use_referrer)
+    )
 
 
 def render_template(template, **context):  # pragma: no cover
@@ -179,34 +207,34 @@ def do_topic_action(topics, user, action, reverse):  # noqa: C901
 
 def get_categories_and_forums(query_result, user):
     """Returns a list with categories. Every category has a list for all
-    their associated forums.
-
-    The structure looks like this::
-        [(<Category 1>,
-          [(<Forum 1>, None),
-           (<Forum 2>, <flaskbb.forum.models.ForumsRead at 0x38fdb50>)]),
-         (<Category 2>,
-          [(<Forum 3>, None),
-          (<Forum 4>, None)])]
-
-    and to unpack the values you can do this::
-        In [110]: for category, forums in x:
-           .....:     print category
-           .....:     for forum, forumsread in forums:
-           .....:         print "\t", forum, forumsread
-
-   This will print something like this:
-        <Category 1>
-            <Forum 1> None
-            <Forum 2> <flaskbb.forum.models.ForumsRead object at 0x38fdb50>
-        <Category 2>
-            <Forum 3> None
-            <Forum 4> None
-
-    :param query_result: A tuple (KeyedTuple) with all categories and forums
-
-    :param user: The user object is needed because a signed out user does not
-                 have the ForumsRead relation joined.
+     their associated forums.
+
+     The structure looks like this::
+         [(<Category 1>,
+           [(<Forum 1>, None),
+            (<Forum 2>, <flaskbb.forum.models.ForumsRead at 0x38fdb50>)]),
+          (<Category 2>,
+           [(<Forum 3>, None),
+           (<Forum 4>, None)])]
+
+     and to unpack the values you can do this::
+         In [110]: for category, forums in x:
+            .....:     print category
+            .....:     for forum, forumsread in forums:
+            .....:         print "\t", forum, forumsread
+
+    This will print something like this:
+         <Category 1>
+             <Forum 1> None
+             <Forum 2> <flaskbb.forum.models.ForumsRead object at 0x38fdb50>
+         <Category 2>
+             <Forum 3> None
+             <Forum 4> None
+
+     :param query_result: A tuple (KeyedTuple) with all categories and forums
+
+     :param user: The user object is needed because a signed out user does not
+                  have the ForumsRead relation joined.
     """
     it = itertools.groupby(query_result, operator.itemgetter(0))
 
@@ -611,8 +639,7 @@ def get_alembic_locations(plugin_dirs):
     the unique identifier of the plugin.
     """
     branches_dirs = [
-        tuple([os.path.basename(os.path.dirname(p)), p])
-        for p in plugin_dirs
+        tuple([os.path.basename(os.path.dirname(p)), p]) for p in plugin_dirs
     ]
 
     return branches_dirs