Browse Source

Merge remote-tracking branch 'origin' into 36-Shadow-Delete

Alec Nikolas Reiter 7 years ago
parent
commit
b95d261daa
4 changed files with 109 additions and 89 deletions
  1. 1 1
      flaskbb/cli/main.py
  2. 55 0
      flaskbb/forum/locals.py
  3. 22 70
      flaskbb/utils/requirements.py
  4. 31 18
      tests/unit/test_requirements.py

+ 1 - 1
flaskbb/cli/main.py

@@ -509,7 +509,7 @@ def generate_config(development, output, force):
                 "For more options see the SQLAlchemy docs:\n"
                 "    http://docs.sqlalchemy.org/en/latest/core/engines.html",
                 fg="cyan")
-    default_conf["database_url"] = click.prompt(
+    default_conf["database_uri"] = click.prompt(
         click.style("Database URI", fg="magenta"),
         default=default_conf.get("database_uri"))
 

+ 55 - 0
flaskbb/forum/locals.py

@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+"""
+    flaskbb.forum.locals
+    ~~~~~~~~~~~~~~~~~~~~
+    Thread local helpers for FlaskBB
+
+    :copyright: 2017, the FlaskBB Team
+    :license: BSD, see license for more details
+"""
+
+from flask import _request_ctx_stack, has_request_context, request
+from werkzeug.local import LocalProxy
+
+from .models import Category, Forum, Post, Topic
+
+
+@LocalProxy
+def current_post():
+    return _get_item(Post, 'post_id', 'post')
+
+
+@LocalProxy
+def current_topic():
+    if current_post:
+        return current_post.topic
+    return _get_item(Topic, 'topic_id', 'topic')
+
+
+@LocalProxy
+def current_forum():
+    if current_topic:
+        return current_topic.forum
+    return _get_item(Forum, 'forum_id', 'forum')
+
+
+@LocalProxy
+def current_category():
+    if current_forum:
+        return current_forum.category
+    return _get_item(Category, 'category_id', 'category')
+
+
+def _get_item(model, view_arg, name):
+    if (
+        has_request_context() and
+        not getattr(_request_ctx_stack.top, name, None) and
+        view_arg in request.view_args
+    ):
+        setattr(
+            _request_ctx_stack.top,
+            name,
+            model.query.filter_by(id=request.view_args[view_arg]).first()
+        )
+
+    return getattr(_request_ctx_stack.top, name, None)

+ 22 - 70
flaskbb/utils/requirements.py

@@ -7,11 +7,11 @@
     :copyright: (c) 2015 by the FlaskBB Team.
     :license: BSD, see LICENSE for more details
 """
-from flask_allows import Requirement, Or, And
+from flask_allows import And, Or, Requirement
 
 from flaskbb.exceptions import FlaskBBError
-from flaskbb.forum.models import Post, Topic, Forum
-from flaskbb.user.models import Group
+from flaskbb.forum.locals import current_forum, current_post, current_topic
+from flaskbb.forum.models import Forum, Post, Topic
 
 
 class Has(Requirement):
@@ -57,15 +57,9 @@ class IsModeratorInForum(IsAuthed):
         return Forum.query.get(self.forum_id)
 
     def _get_forum_from_request(self, request):
-        view_args = request.view_args
-        if 'post_id' in view_args:
-            return Post.query.get(view_args['post_id']).topic.forum
-        elif 'topic_id' in view_args:
-            return Topic.query.get(view_args['topic_id']).forum
-        elif 'forum_id' in view_args:
-            return Forum.query.get(view_args['forum_id'])
-        else:
-            raise FlaskBBError
+        if not current_forum:
+            raise FlaskBBError('Could not load forum data')
+        return current_forum
 
 
 class IsSameUser(IsAuthed):
@@ -82,11 +76,10 @@ class IsSameUser(IsAuthed):
         return self._get_user_id_from_post(request)
 
     def _get_user_id_from_post(self, request):
-        view_args = request.view_args
-        if 'post_id' in view_args:
-            return Post.query.get(view_args['post_id']).user_id
-        elif 'topic_id' in view_args:
-            return Topic.query.get(view_args['topic_id']).user_id
+        if current_post:
+            return current_post.user_id
+        elif current_topic:
+            return current_topic.user_id
         else:
             raise FlaskBBError
 
@@ -125,22 +118,8 @@ class TopicNotLocked(Requirement):
             return self._get_topic_from_request(request)
 
     def _get_topic_from_request(self, request):
-        view_args = request.view_args
-        if 'post_id' in view_args:
-            return (
-                Topic.query.join(Post, Post.topic_id == Topic.id)
-                .join(Forum, Forum.id == Topic.forum_id)
-                .filter(Post.id == view_args['post_id'])
-                .with_entities(Topic.locked, Forum.locked)
-                .first()
-            )
-        elif 'topic_id' in view_args:
-            return (
-                Topic.query.join(Forum, Forum.id == Topic.forum_id)
-                .filter(Topic.id == view_args['topic_id'])
-                .with_entities(Topic.locked, Forum.locked)
-                .first()
-            )
+        if current_topic:
+            return current_topic.locked, current_forum.locked
         else:
             raise FlaskBBError("How did you get this to happen?")
 
@@ -166,59 +145,32 @@ class ForumNotLocked(Requirement):
             return self._get_forum_from_request(request)
 
     def _get_forum_from_request(self, request):
-        view_args = request.view_args
-
-        # These queries look big and nasty, but they really aren't that bad
-        # Basically, find the forum this post or topic belongs to
-        # with_entities returns a KeyedTuple with only the locked status
-
-        if 'post_id' in view_args:
-            return (
-                Forum.query.join(Topic, Topic.forum_id == Forum.id)
-                .join(Post, Post.topic_id == Topic.id)
-                .filter(Post.id == view_args['post_id'])
-                .with_entities(Forum.locked)
-                .first()
-            )
-
-        elif 'topic_id' in view_args:
-            return (
-                Forum.query.join(Topic, Topic.forum_id == Forum.id)
-                .filter(Topic.id == view_args['topic_id'])
-                .with_entities(Forum.locked)
-                .first()
-            )
-
-        elif 'forum_id' in view_args:
-            return Forum.query.get(view_args['forum_id'])
+        if current_forum:
+            return current_forum.locked
+        raise FlaskBBError
 
 
 class CanAccessForum(Requirement):
     def fulfill(self, user, request):
-        forum_id = request.view_args['forum_id']
-        group_ids = [g.id for g in user.groups]
+        if not current_forum:
+            raise FlaskBBError('Could not load forum data')
 
-        return Forum.query.filter(
-            Forum.id == forum_id,
-            Forum.groups.any(Group.id.in_(group_ids))
-        ).count()
+        return set([g.id for g in current_forum.groups]) & set([g.id for g in user.groups])
 
 
 class CanAccessTopic(Requirement):
     def fulfill(self, user, request):
-        topic_id = request.view_args['topic_id']
-        group_ids = [g.id for g in user.groups]
+        if not current_forum:
+            raise FlaskBBError('Could not load topic data')
 
-        return Forum.query.join(Topic, Topic.forum_id == Forum.id).filter(
-            Topic.id == topic_id,
-            Forum.groups.any(Group.id.in_(group_ids))
-        ).count()
+        return set([g.id for g in current_forum.groups]) & set([g.id for g in user.groups])
 
 
 def IsAtleastModeratorInForum(forum_id=None, forum=None):
     return Or(IsAtleastSuperModerator, IsModeratorInForum(forum_id=forum_id,
                                                           forum=forum))
 
+
 IsMod = And(IsAuthed(), Has('mod'))
 IsSuperMod = And(IsAuthed(), Has('super_mod'))
 IsAdmin = And(IsAuthed(), Has('admin'))

+ 31 - 18
tests/unit/test_requirements.py

@@ -1,5 +1,18 @@
+import pytest
+from flask import _request_ctx_stack, request
+
 from flaskbb.utils import requirements as r
-from flaskbb.utils.datastructures import SimpleNamespace
+
+
+def push_onto_request_context(**kw):
+    for name, value in kw.items():
+        setattr(_request_ctx_stack.top, name, value)
+
+
+@pytest.yield_fixture
+def request_context(application):
+    with application.test_request_context():
+        yield
 
 
 def test_Fred_IsNotAdmin(Fred):
@@ -42,44 +55,44 @@ def test_Fred_CannotBanUser(Fred):
     assert not r.CanBanUser(Fred, None)
 
 
-def test_CanEditTopic_with_member(user, topic):
-    request = SimpleNamespace(view_args={'topic_id': topic.id})
+def test_CanEditTopic_with_member(user, topic, request_context):
+    push_onto_request_context(topic=topic)
     assert r.CanEditPost(user, request)
 
 
-def test_Fred_cannot_edit_other_members_post(user, Fred, topic):
-    request = SimpleNamespace(view_args={'topic_id': topic.id})
+def test_Fred_cannot_edit_other_members_post(user, Fred, topic, request_context):
+    push_onto_request_context(topic=topic)
     assert not r.CanEditPost(Fred, request)
 
 
-def test_Fred_CannotEditLockedTopic(Fred, topic_locked):
-    request = SimpleNamespace(view_args={'topic_id': topic_locked.id})
+def test_Fred_CannotEditLockedTopic(Fred, topic_locked, request_context):
+    push_onto_request_context(topic=topic_locked)
     assert not r.CanEditPost(Fred, request)
 
 
-def test_Moderator_in_Forum_CanEditLockedTopic(moderator_user, topic_locked):
-    request = SimpleNamespace(view_args={'topic_id': topic_locked.id})
+def test_Moderator_in_Forum_CanEditLockedTopic(moderator_user, topic_locked, request_context):
+    push_onto_request_context(topic=topic_locked)
     assert r.CanEditPost(moderator_user, request)
 
 
-def test_FredIsAMod_but_still_cant_edit_topic_in_locked_forum(
-        Fred, topic_locked, default_groups):
+def test_FredIsAMod_but_still_cant_edit_topic_in_locked_forum(Fred, topic_locked, default_groups, request_context):
 
-    request = SimpleNamespace(view_args={'topic_id': topic_locked.id})
     Fred.primary_group = default_groups[2]
+
+    push_onto_request_context(topic=topic_locked)
     assert not r.CanEditPost(Fred, request)
 
 
-def test_Fred_cannot_reply_to_locked_topic(Fred, topic_locked):
-    request = SimpleNamespace(view_args={'topic_id': topic_locked.id})
+def test_Fred_cannot_reply_to_locked_topic(Fred, topic_locked, request_context):
+    push_onto_request_context(topic=topic_locked)
     assert not r.CanPostReply(Fred, request)
 
 
-def test_Fred_cannot_delete_others_post(Fred, topic):
-    request = SimpleNamespace(view_args={'post_id': topic.first_post.id})
+def test_Fred_cannot_delete_others_post(Fred, topic, request_context):
+    push_onto_request_context(post=topic.first_post)
     assert not r.CanDeletePost(Fred, request)
 
 
-def test_Mod_can_delete_others_post(moderator_user, topic):
-    request = SimpleNamespace(view_args={'post_id': topic.first_post.id})
+def test_Mod_can_delete_others_post(moderator_user, topic, request_context):
+    push_onto_request_context(post=topic.first_post)
     assert r.CanDeletePost(moderator_user, request)