Просмотр исходного кода

Unwrap current_user when interacting with the database

Alec Nikolas Reiter 7 лет назад
Родитель
Сommit
8eb6b7f8fb
2 измененных файлов с 47 добавлено и 36 удалено
  1. 36 36
      flaskbb/forum/views.py
  2. 11 0
      flaskbb/utils/helpers.py

+ 36 - 36
flaskbb/forum/views.py

@@ -21,7 +21,7 @@ from flaskbb.extensions import db, allows
 from flaskbb.utils.settings import flaskbb_config
 from flaskbb.utils.helpers import (get_online_users, time_diff, time_utcnow,
                                    format_quote, render_template,
-                                   do_topic_action)
+                                   do_topic_action, real)
 from flaskbb.utils.requirements import (CanAccessForum, CanAccessTopic,
                                         CanDeletePost, CanDeleteTopic,
                                         CanEditPost, CanPostReply,
@@ -38,7 +38,7 @@ forum = Blueprint("forum", __name__)
 
 @forum.route("/")
 def index():
-    categories = Category.get_all(user=current_user)
+    categories = Category.get_all(user=real(current_user))
 
     # Fetch a few stats about the forum
     user_count = User.query.count()
@@ -71,7 +71,7 @@ def index():
 @forum.route("/category/<int:category_id>-<slug>")
 def view_category(category_id, slug=None):
     category, forums = Category.\
-        get_forums(category_id=category_id, user=current_user)
+        get_forums(category_id=category_id, user=real(current_user))
 
     return render_template("forum/category.html", forums=forums,
                            category=category)
@@ -84,14 +84,14 @@ def view_forum(forum_id, slug=None):
     page = request.args.get('page', 1, type=int)
 
     forum_instance, forumsread = Forum.get_forum(
-        forum_id=forum_id, user=current_user
+        forum_id=forum_id, user=real(current_user)
     )
 
     if forum_instance.external:
         return redirect(forum_instance.external)
 
     topics = Forum.get_topics(
-        forum_id=forum_instance.id, user=current_user, page=page,
+        forum_id=forum_instance.id, user=real(current_user), page=page,
         per_page=flaskbb_config["TOPICS_PER_PAGE"]
     )
 
@@ -108,7 +108,7 @@ def view_topic(topic_id, slug=None):
     page = request.args.get('page', 1, type=int)
 
     # Fetch some information about the topic
-    topic = Topic.get_topic(topic_id=topic_id, user=current_user)
+    topic = Topic.get_topic(topic_id=topic_id, user=real(current_user))
 
     # Count the topic views
     topic.views += 1
@@ -130,16 +130,16 @@ def view_topic(topic_id, slug=None):
     forumsread = None
     if current_user.is_authenticated:
         forumsread = ForumsRead.query.\
-            filter_by(user_id=current_user.id,
+            filter_by(user_id=real(current_user).id,
                       forum_id=topic.forum.id).first()
 
-    topic.update_read(current_user, topic.forum, forumsread)
+    topic.update_read(real(current_user), topic.forum, forumsread)
 
     form = None
     if Permission(CanPostReply):
         form = QuickreplyForm()
         if form.validate_on_submit():
-            post = form.save(current_user, topic)
+            post = form.save(real(current_user), topic)
             return view_post(post.id)
 
     return render_template("forum/topic.html", topic=topic, posts=posts,
@@ -179,7 +179,7 @@ def new_topic(forum_id, slug=None):
                 form=form, preview=form.content.data
             )
         if "submit" in request.form and form.validate():
-            topic = form.save(current_user, forum_instance)
+            topic = form.save(real(current_user), forum_instance)
             # redirect to the new topic
             return redirect(url_for('forum.view_topic', topic_id=topic.id))
 
@@ -277,7 +277,7 @@ def manage_forum(forum_id, slug=None):
     page = request.args.get('page', 1, type=int)
 
     forum_instance, forumsread = Forum.get_forum(forum_id=forum_id,
-                                                 user=current_user)
+                                                 user=real(current_user))
 
     # remove the current forum from the select field (move).
     available_forums = Forum.query.order_by(Forum.position).all()
@@ -292,7 +292,7 @@ def manage_forum(forum_id, slug=None):
         return redirect(forum_instance.external)
 
     topics = Forum.get_topics(
-        forum_id=forum_instance.id, user=current_user, page=page,
+        forum_id=forum_instance.id, user=real(current_user), page=page,
         per_page=flaskbb_config["TOPICS_PER_PAGE"]
     )
 
@@ -312,34 +312,34 @@ def manage_forum(forum_id, slug=None):
 
         # locking/unlocking
         if "lock" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="locked", reverse=False)
 
             flash(_("%(count)s topics locked.", count=changed), "success")
             return redirect(mod_forum_url)
 
         elif "unlock" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="locked", reverse=True)
             flash(_("%(count)s topics unlocked.", count=changed), "success")
             return redirect(mod_forum_url)
 
         # highlighting/trivializing
         elif "highlight" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="important", reverse=False)
             flash(_("%(count)s topics highlighted.", count=changed), "success")
             return redirect(mod_forum_url)
 
         elif "trivialize" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="important", reverse=True)
             flash(_("%(count)s topics trivialized.", count=changed), "success")
             return redirect(mod_forum_url)
 
         # deleting
         elif "delete" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="delete", reverse=False)
             flash(_("%(count)s topics deleted.", count=changed), "success")
             return redirect(mod_forum_url)
@@ -393,7 +393,7 @@ def new_post(topic_id, slug=None):
                 form=form, preview=form.content.data
             )
         else:
-            post = form.save(current_user, topic)
+            post = form.save(real(current_user), topic)
             return view_post(post.id)
 
     return render_template("forum/new_post.html", topic=topic, form=form)
@@ -420,7 +420,7 @@ def reply_post(topic_id, post_id):
                 form=form, preview=form.content.data
             )
         else:
-            post = form.save(current_user, topic)
+            post = form.save(real(current_user), topic)
             return view_post(post.id)
     else:
         form.content.data = format_quote(post.username, post.content)
@@ -452,7 +452,7 @@ def edit_post(post_id):
         else:
             form.populate_obj(post)
             post.date_modified = time_utcnow()
-            post.modified_by = current_user.username
+            post.modified_by = real(current_user).username
             post.save()
 
             if post.first_post:
@@ -500,7 +500,7 @@ def report_post(post_id):
 
     form = ReportForm()
     if form.validate_on_submit():
-        form.save(current_user, post)
+        form.save(real(current_user), post)
         flash(_("Thanks for reporting."), "success")
 
     return render_template("forum/report_post.html", form=form)
@@ -521,15 +521,15 @@ def markread(forum_id=None, slug=None):
     if forum_id:
         forum_instance = Forum.query.filter_by(id=forum_id).first_or_404()
         forumsread = ForumsRead.query.filter_by(
-            user_id=current_user.id, forum_id=forum_instance.id
+            user_id=real(current_user).id, forum_id=forum_instance.id
         ).first()
-        TopicsRead.query.filter_by(user_id=current_user.id,
+        TopicsRead.query.filter_by(user_id=real(current_user).id,
                                    forum_id=forum_instance.id).delete()
 
         if not forumsread:
             forumsread = ForumsRead()
-            forumsread.user_id = current_user.id
-            forumsread.forum_id = forum_instance.id
+            forumsread.user = real(current_user)
+            forumsread.forum = forum_instance
 
         forumsread.last_read = time_utcnow()
         forumsread.cleared = time_utcnow()
@@ -543,14 +543,14 @@ def markread(forum_id=None, slug=None):
         return redirect(forum_instance.url)
 
     # Mark all forums as read
-    ForumsRead.query.filter_by(user_id=current_user.id).delete()
-    TopicsRead.query.filter_by(user_id=current_user.id).delete()
+    ForumsRead.query.filter_by(user_id=real(current_user).id).delete()
+    TopicsRead.query.filter_by(user_id=real(current_user).id).delete()
 
     forums = Forum.query.all()
     forumsread_list = []
     for forum_instance in forums:
         forumsread = ForumsRead()
-        forumsread.user_id = current_user.id
+        forumsread.user_id = real(current_user).id
         forumsread.forum_id = forum_instance.id
         forumsread.last_read = time_utcnow()
         forumsread.cleared = time_utcnow()
@@ -611,10 +611,10 @@ def memberlist():
 @login_required
 def topictracker():
     page = request.args.get("page", 1, type=int)
-    topics = current_user.tracked_topics.\
+    topics = real(current_user).tracked_topics.\
         outerjoin(TopicsRead,
                   db.and_(TopicsRead.topic_id == Topic.id,
-                          TopicsRead.user_id == current_user.id)).\
+                          TopicsRead.user_id == real(current_user).id)).\
         add_entity(TopicsRead).\
         order_by(Topic.last_updated.desc()).\
         paginate(page, flaskbb_config['TOPICS_PER_PAGE'], True)
@@ -625,8 +625,8 @@ def topictracker():
         tmp_topics = Topic.query.filter(Topic.id.in_(topic_ids)).all()
 
         for topic in tmp_topics:
-            current_user.untrack_topic(topic)
-        current_user.save()
+            real(current_user).untrack_topic(topic)
+        real(current_user).save()
 
         flash(_("%(topic_count)s topics untracked.",
                 topic_count=len(tmp_topics)), "success")
@@ -640,8 +640,8 @@ def topictracker():
 @login_required
 def track_topic(topic_id, slug=None):
     topic = Topic.query.filter_by(id=topic_id).first_or_404()
-    current_user.track_topic(topic)
-    current_user.save()
+    real(current_user).track_topic(topic)
+    real(current_user).save()
     return redirect(topic.url)
 
 
@@ -650,8 +650,8 @@ def track_topic(topic_id, slug=None):
 @login_required
 def untrack_topic(topic_id, slug=None):
     topic = Topic.query.filter_by(id=topic_id).first_or_404()
-    current_user.untrack_topic(topic)
-    current_user.save()
+    real(current_user).untrack_topic(topic)
+    real(current_user).save()
     return redirect(topic.url)
 
 

+ 11 - 0
flaskbb/utils/helpers.py

@@ -29,6 +29,8 @@ from flask_babelplus import lazy_gettext as _
 from flask_themes2 import render_theme_template, get_themes_list
 from flask_login import current_user
 
+from werkzeug.local import LocalProxy
+
 from flaskbb._compat import range_method, text_type, iteritems
 from flaskbb.extensions import redis_store, babel
 from flaskbb.utils.settings import flaskbb_config
@@ -614,3 +616,12 @@ class ReverseProxyPathFix(object):
             environ['wsgi.url_scheme'] = 'https'
 
         return self.app(environ, start_response)
+
+
+def real(obj):
+    """
+    Unwraps a werkzeug.local.LocalProxy object if given one, else returns the object
+    """
+    if isinstance(obj, LocalProxy):
+        return obj._get_current_object()
+    return obj