Browse Source

Move more tests to patch_user_acl

rafalp 6 years ago
parent
commit
0115195d17

+ 1 - 1
misago/categories/api.py

@@ -7,5 +7,5 @@ from .utils import get_categories_tree
 
 class CategoryViewSet(viewsets.ViewSet):
     def list(self, request):
-        categories_tree = get_categories_tree(request.user, join_posters=True)
+        categories_tree = get_categories_tree(request.user, request.user_acl, join_posters=True)
         return Response(CategorySerializer(categories_tree, many=True).data)

+ 4 - 4
misago/categories/utils.py

@@ -4,7 +4,7 @@ from misago.readtracker import categoriestracker
 from .models import Category
 
 
-def get_categories_tree(user, parent=None, join_posters=False):
+def get_categories_tree(user, user_acl, parent=None, join_posters=False):
     if not user.acl_cache['visible_categories']:
         return []
 
@@ -13,7 +13,7 @@ def get_categories_tree(user, parent=None, join_posters=False):
     else:
         queryset = Category.objects.all_categories()
 
-    queryset_with_acl = queryset.filter(id__in=user.acl_cache['visible_categories'])
+    queryset_with_acl = queryset.filter(id__in=user_acl['visible_categories'])
     if join_posters:
         queryset_with_acl = queryset_with_acl.select_related('last_poster')
 
@@ -32,8 +32,8 @@ def get_categories_tree(user, parent=None, join_posters=False):
         if category.parent_id and category.level > parent_level:
             categories_dict[category.parent_id].subcategories.append(category)
 
-    add_acl(user, categories_list)
-    categoriestracker.make_read_aware(user, categories_list)
+    add_acl(user_acl, categories_list)
+    categoriestracker.make_read_aware(user, user_acl, categories_list)
 
     for category in reversed(visible_categories):
         if category.acl['can_browse']:

+ 1 - 1
misago/categories/views/categorieslist.py

@@ -6,7 +6,7 @@ from misago.categories.utils import get_categories_tree
 
 
 def categories(request):
-    categories_tree = get_categories_tree(request.user, join_posters=True)
+    categories_tree = get_categories_tree(request.user, request.user_acl, join_posters=True)
 
     request.frontend_context.update({
         'CATEGORIES': CategorySerializer(categories_tree, many=True).data,

+ 124 - 288
misago/threads/tests/test_threadslists.py

@@ -4,7 +4,7 @@ from django.urls import reverse
 from django.utils import timezone
 from django.utils.encoding import smart_str
 
-from misago.acl.testutils import override_acl
+from misago.acl.test import patch_user_acl
 from misago.categories.models import Category
 from misago.conf import settings
 from misago.readtracker import poststracker
@@ -12,10 +12,48 @@ from misago.threads import testutils
 from misago.users.models import AnonymousUser
 from misago.users.testutils import AuthenticatedUserTestCase
 
-
 LISTS_URLS = ('', 'my/', 'new/', 'unread/', 'subscribed/', )
 
 
+def patch_categories_acl(category_acl=None, base_acl=None):
+    def patch_acl(_, user_acl):
+        first_category = Category.objects.get(slug='first-category')
+        first_category_acl = user_acl['categories'][first_category.id].copy()
+
+        user_acl.update({
+            'categories': {},
+            'visible_categories': [],
+            'browseable_categories': [],
+            'can_approve_content': [],
+        })
+
+        # copy first category's acl to other categories to make base for overrides
+        for category in Category.objects.all_categories():
+            user_acl['categories'][category.id] = first_category_acl
+
+        if base_acl:
+            user_acl.update(base_acl)
+
+        for category in Category.objects.all_categories():
+            user_acl['visible_categories'].append(category.id)
+            user_acl['browseable_categories'].append(category.id)
+            user_acl['categories'][category.id].update({
+                'can_see': 1,
+                'can_browse': 1,
+                'can_see_all_threads': 1,
+                'can_see_own_threads': 0,
+                'can_hide_threads': 0,
+                'can_approve_content': 0,
+            })
+
+            if category_acl:
+                user_acl['categories'][category.id].update(category_acl)
+                if category_acl.get('can_approve_content'):
+                    user_acl['can_approve_content'].append(category.id)
+
+    return patch_user_acl(patch_acl)
+
+
 class ThreadsListTestCase(AuthenticatedUserTestCase):
     def setUp(self):
         """
@@ -120,46 +158,6 @@ class ThreadsListTestCase(AuthenticatedUserTestCase):
         self.category_e = Category.objects.get(slug='category-e')
         self.category_f = Category.objects.get(slug='category-f')
 
-        self.access_all_categories()
-
-    def access_all_categories(self, category_acl=None, base_acl=None):
-        self.clear_state()
-
-        categories_acl = {
-            'categories': {},
-            'visible_categories': [],
-            'browseable_categories': [],
-            'can_approve_content': [],
-        }
-
-        # copy first category's acl to other categories to make base for overrides
-        first_category_acl = self.user.acl_cache['categories'][self.first_category.pk].copy()
-        for category in Category.objects.all_categories():
-            categories_acl['categories'][category.pk] = first_category_acl
-
-        if base_acl:
-            categories_acl.update(base_acl)
-
-        for category in Category.objects.all_categories():
-            categories_acl['visible_categories'].append(category.pk)
-            categories_acl['browseable_categories'].append(category.pk)
-            categories_acl['categories'][category.pk].update({
-                'can_see': 1,
-                'can_browse': 1,
-                'can_see_all_threads': 1,
-                'can_see_own_threads': 0,
-                'can_hide_threads': 0,
-                'can_approve_content': 0,
-            })
-
-            if category_acl:
-                categories_acl['categories'][category.pk].update(category_acl)
-                if category_acl.get('can_approve_content'):
-                    categories_acl['can_approve_content'].append(category.pk)
-
-        override_acl(self.user, categories_acl)
-        return categories_acl
-
     def assertContainsThread(self, response, thread):
         self.assertContains(response, ' href="%s"' % thread.get_absolute_url())
 
@@ -185,11 +183,10 @@ class ApiTests(ThreadsListTestCase):
 
 
 class AllThreadsListTests(ThreadsListTestCase):
+    @patch_categories_acl()
     def test_list_renders_empty(self):
         """empty threads list renders"""
         for url in LISTS_URLS:
-            self.access_all_categories()
-
             response = self.client.get('/' + url)
             self.assertEqual(response.status_code, 200)
             self.assertContains(response, "empty-message")
@@ -198,8 +195,6 @@ class AllThreadsListTests(ThreadsListTestCase):
             else:
                 self.assertContains(response, "There are no threads on this forum")
 
-            self.access_all_categories()
-
             response = self.client.get(self.category_b.get_absolute_url() + url)
             self.assertEqual(response.status_code, 200)
             self.assertContains(response, self.category_b.name)
@@ -209,8 +204,6 @@ class AllThreadsListTests(ThreadsListTestCase):
             else:
                 self.assertContains(response, "There are no threads in this category")
 
-            self.access_all_categories()
-
             response = self.client.get('%s?list=%s' % (self.api_link, url.strip('/') or 'all'))
             self.assertEqual(response.status_code, 200)
 
@@ -221,46 +214,34 @@ class AllThreadsListTests(ThreadsListTestCase):
         self.logout_user()
         self.user = self.get_anonymous_user()
 
-        self.access_all_categories()
-
         response = self.client.get('/')
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, "empty-message")
         self.assertContains(response, "There are no threads on this forum")
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_b.get_absolute_url())
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, self.category_b.name)
         self.assertContains(response, "empty-message")
         self.assertContains(response, "There are no threads in this category")
 
-        self.access_all_categories()
-
         response = self.client.get('%s?list=all' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_authenticated_only_views(self):
         """authenticated only views return 403 for guests"""
         for url in LISTS_URLS:
-            self.access_all_categories()
-
             response = self.client.get('/' + url)
             self.assertEqual(response.status_code, 200)
 
-            self.access_all_categories()
-
             response = self.client.get(self.category_b.get_absolute_url() + url)
             self.assertEqual(response.status_code, 200)
             self.assertContains(response, self.category_b.name)
 
-            self.access_all_categories()
-
-            self.access_all_categories()
             response = self.client.get(
                 '%s?category=%s&list=%s' %
                 (self.api_link, self.category_b.pk, url.strip('/') or 'all', )
@@ -270,22 +251,19 @@ class AllThreadsListTests(ThreadsListTestCase):
         self.logout_user()
         self.user = self.get_anonymous_user()
         for url in LISTS_URLS[1:]:
-            self.access_all_categories()
-
             response = self.client.get('/' + url)
             self.assertEqual(response.status_code, 403)
 
-            self.access_all_categories()
             response = self.client.get(self.category_b.get_absolute_url() + url)
             self.assertEqual(response.status_code, 403)
 
-            self.access_all_categories()
             response = self.client.get(
                 '%s?category=%s&list=%s' %
                 (self.api_link, self.category_b.pk, url.strip('/') or 'all', )
             )
             self.assertEqual(response.status_code, 403)
 
+    @patch_categories_acl()
     def test_list_renders_categories_picker(self):
         """categories picker renders valid categories"""
         Category(
@@ -316,7 +294,6 @@ class AllThreadsListTests(ThreadsListTestCase):
         # hidden category
         self.assertNotContains(response, 'subcategory-%s' % test_category.css_class)
 
-        self.access_all_categories()
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -325,11 +302,8 @@ class AllThreadsListTests(ThreadsListTestCase):
         self.assertNotIn(self.category_b.pk, response_json['subcategories'])
 
         # test category view
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url())
         self.assertEqual(response.status_code, 200)
-
         self.assertContains(response, 'subcategory-%s' % self.category_b.css_class)
 
         # readable categories, but non-accessible directly
@@ -337,7 +311,6 @@ class AllThreadsListTests(ThreadsListTestCase):
         self.assertNotContains(response, 'subcategory-%s' % self.category_d.css_class)
         self.assertNotContains(response, 'subcategory-%s' % self.category_f.css_class)
 
-        self.access_all_categories()
         response = self.client.get('%s?category=%s' % (self.api_link, self.category_a.pk))
         self.assertEqual(response.status_code, 200)
 
@@ -462,7 +435,7 @@ class CategoryThreadsListTests(ThreadsListTestCase):
             response = self.client.get(test_category.get_absolute_url() + url)
             self.assertEqual(response.status_code, 404)
 
-            response = self.client.get('%s?category=%s' % (self.api_link, test_category.pk))
+            response = self.client.get('%s?category=%s' % (self.api_link, test_category.id))
             self.assertEqual(response.status_code, 404)
 
     def test_access_protected_category(self):
@@ -478,37 +451,23 @@ class CategoryThreadsListTests(ThreadsListTestCase):
         test_category = Category.objects.get(slug='hidden-category')
 
         for url in LISTS_URLS:
-            override_acl(
-                self.user, {
-                    'visible_categories': [test_category.pk],
-                    'browseable_categories': [],
-                    'categories': {
-                        test_category.pk: {
-                            'can_see': 1,
-                            'can_browse': 0,
-                        },
+            with patch_user_acl({
+                'visible_categories': [test_category.id],
+                'browseable_categories': [],
+                'categories': {
+                    test_category.id: {
+                        'can_see': 1,
+                        'can_browse': 0,
                     },
-                }
-            )
-            response = self.client.get(test_category.get_absolute_url() + url)
-            self.assertEqual(response.status_code, 403)
+                },
+            }):
+                response = self.client.get(test_category.get_absolute_url() + url)
+                self.assertEqual(response.status_code, 403)
 
-            override_acl(
-                self.user, {
-                    'visible_categories': [test_category.pk],
-                    'browseable_categories': [],
-                    'categories': {
-                        test_category.pk: {
-                            'can_see': 1,
-                            'can_browse': 0,
-                        },
-                    },
-                }
-            )
-            response = self.client.get(
-                '%s?category=%s&list=%s' % (self.api_link, test_category.pk, url.strip('/'), )
-            )
-            self.assertEqual(response.status_code, 403)
+                response = self.client.get(
+                    '%s?category=%s&list=%s' % (self.api_link, test_category.id, url.strip('/'))
+                )
+                self.assertEqual(response.status_code, 403)
 
     def test_display_pinned_threads(self):
         """
@@ -550,7 +509,7 @@ class CategoryThreadsListTests(ThreadsListTestCase):
         self.assertTrue(positions['s'] > positions['g'])
 
         # API behaviour is identic
-        response = self.client.get('/api/threads/?category=%s' % self.first_category.pk)
+        response = self.client.get('/api/threads/?category=%s' % self.first_category.id)
         self.assertEqual(response.status_code, 200)
 
         content = smart_str(response.content)
@@ -574,6 +533,7 @@ class CategoryThreadsListTests(ThreadsListTestCase):
 
 
 class ThreadsVisibilityTests(ThreadsListTestCase):
+    @patch_categories_acl()
     def test_list_renders_test_thread(self):
         """list renders test thread with valid top category"""
         test_thread = testutils.post_thread(
@@ -592,7 +552,6 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         self.assertContains(response, 'thread-detail-category-%s' % self.category_c.css_class)
 
         # api displays same data
-        self.access_all_categories()
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -602,7 +561,6 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         self.assertIn(self.category_a.pk, response_json['subcategories'])
 
         # test category view
-        self.access_all_categories()
         response = self.client.get(self.category_b.get_absolute_url())
         self.assertEqual(response.status_code, 200)
 
@@ -613,7 +571,6 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         self.assertContains(response, 'thread-detail-category-%s' % self.category_c.css_class)
 
         # api displays same data
-        self.access_all_categories()
         response = self.client.get('%s?category=%s' % (self.api_link, self.category_b.pk))
         self.assertEqual(response.status_code, 200)
 
@@ -664,6 +621,7 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_user_see_own_unapproved_thread(self):
         """list renders unapproved thread that belongs to viewer"""
         test_thread = testutils.post_thread(
@@ -677,13 +635,13 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
+    @patch_categories_acl()
     def test_list_user_cant_see_unapproved_thread(self):
         """list hides unapproved thread that belongs to other user"""
         test_thread = testutils.post_thread(
@@ -696,13 +654,13 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_user_cant_see_hidden_thread(self):
         """list hides hidden thread that belongs to other user"""
         test_thread = testutils.post_thread(
@@ -715,13 +673,13 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_user_cant_see_own_hidden_thread(self):
         """list hides hidden thread that belongs to viewer"""
         test_thread = testutils.post_thread(
@@ -735,13 +693,13 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl({'can_hide_threads': 1})
     def test_list_user_can_see_own_hidden_thread(self):
         """list shows hidden thread that belongs to viewer due to permission"""
         test_thread = testutils.post_thread(
@@ -750,21 +708,18 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
             is_hidden=True,
         )
 
-        self.access_all_categories({'can_hide_threads': 1})
-
         response = self.client.get('/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories({'can_hide_threads': 1})
-
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
+    @patch_categories_acl({'can_hide_threads': 1})
     def test_list_user_can_see_hidden_thread(self):
         """list shows hidden thread that belongs to other user due to permission"""
         test_thread = testutils.post_thread(
@@ -772,21 +727,18 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
             is_hidden=True,
         )
 
-        self.access_all_categories({'can_hide_threads': 1})
-
         response = self.client.get('/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories({'can_hide_threads': 1})
-
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
+    @patch_categories_acl({'can_approve_content': 1})
     def test_list_user_can_see_unapproved_thread(self):
         """list shows hidden thread that belongs to other user due to permission"""
         test_thread = testutils.post_thread(
@@ -794,15 +746,11 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
             is_unapproved=True,
         )
 
-        self.access_all_categories({'can_approve_content': 1})
-
         response = self.client.get('/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories({'can_approve_content': 1})
-
         response = self.client.get(self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -811,34 +759,30 @@ class ThreadsVisibilityTests(ThreadsListTestCase):
 
 
 class MyThreadsListTests(ThreadsListTestCase):
+    @patch_categories_acl()
     def test_list_renders_empty(self):
         """list renders empty"""
-        self.access_all_categories()
-
         response = self.client.get('/my/')
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, "empty-message")
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'my/')
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, "empty-message")
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=my' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=my&category=%s' % (self.api_link, self.category_a.pk))
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_renders_test_thread(self):
         """list renders only threads posted by user"""
         test_thread = testutils.post_thread(
@@ -848,22 +792,17 @@ class MyThreadsListTests(ThreadsListTestCase):
 
         other_thread = testutils.post_thread(category=self.category_a)
 
-        self.access_all_categories()
-
         response = self.client.get('/my/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
         self.assertNotContainsThread(response, other_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'my/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
         self.assertNotContainsThread(response, other_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=my' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -871,7 +810,6 @@ class MyThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=my&category=%s' % (self.api_link, self.category_a.pk))
         self.assertEqual(response.status_code, 200)
 
@@ -881,52 +819,43 @@ class MyThreadsListTests(ThreadsListTestCase):
 
 
 class NewThreadsListTests(ThreadsListTestCase):
+    @patch_categories_acl()
     def test_list_renders_empty(self):
         """list renders empty"""
-        self.access_all_categories()
-
         response = self.client.get('/new/')
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, "empty-message")
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'new/')
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, "empty-message")
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=new' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=new&category=%s' % (self.api_link, self.category_a.pk))
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_renders_new_thread(self):
         """list renders new thread"""
         test_thread = testutils.post_thread(category=self.category_a)
 
-        self.access_all_categories()
-
         response = self.client.get('/new/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'new/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=new' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -934,7 +863,6 @@ class NewThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=new&category=%s' % (self.api_link, self.category_a.pk))
         self.assertEqual(response.status_code, 200)
 
@@ -942,6 +870,7 @@ class NewThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
+    @patch_categories_acl()
     def test_list_renders_thread_bumped_after_user_cutoff(self):
         """list renders new thread bumped after user cutoff"""
         self.user.joined_on = timezone.now() - timedelta(days=10)
@@ -957,20 +886,15 @@ class NewThreadsListTests(ThreadsListTestCase):
             posted_on=self.user.joined_on + timedelta(days=4),
         )
 
-        self.access_all_categories()
-
         response = self.client.get('/new/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'new/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=new' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -978,7 +902,6 @@ class NewThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=new&category=%s' % (self.api_link, self.category_a.pk))
         self.assertEqual(response.status_code, 200)
 
@@ -986,6 +909,7 @@ class NewThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
+    @patch_categories_acl()
     def test_list_hides_global_cutoff_thread(self):
         """list hides thread started before global cutoff"""
         self.user.joined_on = timezone.now() - timedelta(days=10)
@@ -996,33 +920,28 @@ class NewThreadsListTests(ThreadsListTestCase):
             started_on=timezone.now() - timedelta(days=settings.MISAGO_READTRACKER_CUTOFF + 1),
         )
 
-        self.access_all_categories()
-
         response = self.client.get('/new/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'new/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=new' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=new&category=%s' % (self.api_link, self.category_a.pk))
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_hides_user_cutoff_thread(self):
         """list hides thread started before users cutoff"""
         self.user.joined_on = timezone.now() - timedelta(days=5)
@@ -1033,63 +952,51 @@ class NewThreadsListTests(ThreadsListTestCase):
             started_on=self.user.joined_on - timedelta(minutes=1),
         )
 
-        self.access_all_categories()
-
         response = self.client.get('/new/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'new/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=new' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=new&category=%s' % (self.api_link, self.category_a.pk))
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_hides_user_read_thread(self):
         """list hides thread already read by user"""
         self.user.joined_on = timezone.now() - timedelta(days=5)
         self.user.save()
 
         test_thread = testutils.post_thread(category=self.category_a)
-
         poststracker.save_read(self.user, test_thread.first_post)
 
-        self.access_all_categories()
-
         response = self.client.get('/new/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'new/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=new' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get('%s?list=new&category=%s' % (self.api_link, self.category_a.pk))
         self.assertEqual(response.status_code, 200)
 
@@ -1098,29 +1005,24 @@ class NewThreadsListTests(ThreadsListTestCase):
 
 
 class UnreadThreadsListTests(ThreadsListTestCase):
+    @patch_categories_acl()
     def test_list_renders_empty(self):
         """list renders empty"""
-        self.access_all_categories()
-
         response = self.client.get('/unread/')
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, "empty-message")
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'unread/')
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, "empty-message")
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=unread' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=unread&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1129,31 +1031,25 @@ class UnreadThreadsListTests(ThreadsListTestCase):
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_renders_unread_thread(self):
         """list renders thread with unread posts"""
         self.user.joined_on = timezone.now() - timedelta(days=5)
         self.user.save()
 
         test_thread = testutils.post_thread(category=self.category_a)
-
         poststracker.save_read(self.user, test_thread.first_post)
-
         testutils.reply_thread(test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get('/unread/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'unread/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=unread' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -1161,7 +1057,6 @@ class UnreadThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=unread&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1171,6 +1066,7 @@ class UnreadThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertEqual(response_json['results'][0]['id'], test_thread.pk)
 
+    @patch_categories_acl()
     def test_list_hides_never_read_thread(self):
         """list hides never read thread"""
         self.user.joined_on = timezone.now() - timedelta(days=5)
@@ -1178,27 +1074,21 @@ class UnreadThreadsListTests(ThreadsListTestCase):
 
         test_thread = testutils.post_thread(category=self.category_a)
 
-        self.access_all_categories()
-
         response = self.client.get('/unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=unread' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=unread&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1207,36 +1097,30 @@ class UnreadThreadsListTests(ThreadsListTestCase):
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_hides_read_thread(self):
         """list hides read thread"""
         self.user.joined_on = timezone.now() - timedelta(days=5)
         self.user.save()
 
         test_thread = testutils.post_thread(category=self.category_a)
-
         poststracker.save_read(self.user, test_thread.first_post)
 
-        self.access_all_categories()
-
         response = self.client.get('/unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=unread' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=unread&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1245,6 +1129,7 @@ class UnreadThreadsListTests(ThreadsListTestCase):
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_hides_global_cutoff_thread(self):
         """list hides thread replied before global cutoff"""
         self.user.joined_on = timezone.now() - timedelta(days=10)
@@ -1256,30 +1141,23 @@ class UnreadThreadsListTests(ThreadsListTestCase):
         )
 
         poststracker.save_read(self.user, test_thread.first_post)
-
         testutils.reply_thread(test_thread, posted_on=test_thread.started_on + timedelta(days=1))
 
-        self.access_all_categories()
-
         response = self.client.get('/unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=unread' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=unread&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1288,6 +1166,7 @@ class UnreadThreadsListTests(ThreadsListTestCase):
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
+    @patch_categories_acl()
     def test_list_hides_user_cutoff_thread(self):
         """list hides thread replied before user cutoff"""
         self.user.joined_on = timezone.now() - timedelta(days=10)
@@ -1305,27 +1184,21 @@ class UnreadThreadsListTests(ThreadsListTestCase):
             posted_on=test_thread.started_on + timedelta(days=1),
         )
 
-        self.access_all_categories()
-
         response = self.client.get('/unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'unread/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=unread' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
         response_json = response.json()
         self.assertEqual(len(response_json['results']), 0)
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=unread&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1336,6 +1209,7 @@ class UnreadThreadsListTests(ThreadsListTestCase):
 
 
 class SubscribedThreadsListTests(ThreadsListTestCase):
+    @patch_categories_acl()
     def test_list_shows_subscribed_thread(self):
         """list shows subscribed thread"""
         test_thread = testutils.post_thread(category=self.category_a)
@@ -1345,20 +1219,15 @@ class SubscribedThreadsListTests(ThreadsListTestCase):
             last_read_on=test_thread.last_post_on,
         )
 
-        self.access_all_categories()
-
         response = self.client.get('/subscribed/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'subscribed/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=subscribed' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -1366,7 +1235,6 @@ class SubscribedThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertContains(response, test_thread.get_absolute_url())
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=subscribed&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1376,24 +1244,20 @@ class SubscribedThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 1)
         self.assertContains(response, test_thread.get_absolute_url())
 
+    @patch_categories_acl()
     def test_list_hides_unsubscribed_thread(self):
         """list shows subscribed thread"""
         test_thread = testutils.post_thread(category=self.category_a)
 
-        self.access_all_categories()
-
         response = self.client.get('/subscribed/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
-
         response = self.client.get(self.category_a.get_absolute_url() + 'subscribed/')
         self.assertEqual(response.status_code, 200)
         self.assertNotContainsThread(response, test_thread)
 
         # test api
-        self.access_all_categories()
         response = self.client.get('%s?list=subscribed' % self.api_link)
         self.assertEqual(response.status_code, 200)
 
@@ -1401,7 +1265,6 @@ class SubscribedThreadsListTests(ThreadsListTestCase):
         self.assertEqual(len(response_json['results']), 0)
         self.assertNotContainsThread(response, test_thread)
 
-        self.access_all_categories()
         response = self.client.get(
             '%s?list=subscribed&category=%s' % (self.api_link, self.category_a.pk)
         )
@@ -1420,29 +1283,29 @@ class UnapprovedListTests(ThreadsListTestCase):
             '%s?list=unapproved' % self.api_link,
         )
 
-        for test_url in TEST_URLS:
-            self.access_all_categories()
-            response = self.client.get(test_url)
-            self.assertEqual(response.status_code, 403)
+        with patch_categories_acl():
+            for test_url in TEST_URLS:
+                response = self.client.get(test_url)
+                self.assertEqual(response.status_code, 403)
 
         # approval perm has no influence on visibility
-        for test_url in TEST_URLS:
-            self.access_all_categories({'can_approve_content': True})
-
-            self.access_all_categories()
-            response = self.client.get(test_url)
-            self.assertEqual(response.status_code, 403)
+        with patch_categories_acl({'can_approve_content': True}):
+            for test_url in TEST_URLS:
+                response = self.client.get(test_url)
+                self.assertEqual(response.status_code, 403)
 
         # approval perm has no influence on visibility
-        for test_url in TEST_URLS:
-            self.access_all_categories(base_acl={
-                'can_see_unapproved_content_lists': True,
-            })
-
-            self.access_all_categories()
-            response = self.client.get(test_url)
-            self.assertEqual(response.status_code, 200)
-
+        with patch_categories_acl(base_acl={
+            'can_see_unapproved_content_lists': True,
+        }):
+            for test_url in TEST_URLS:
+                response = self.client.get(test_url)
+                self.assertEqual(response.status_code, 200)
+
+    @patch_categories_acl(
+        {'can_approve_content': True},
+        {'can_see_unapproved_content_lists': True},
+    )
     def test_list_shows_all_threads_for_approving_user(self):
         """list shows all threads with unapproved posts when user has perm"""
         visible_thread = testutils.post_thread(
@@ -1455,40 +1318,23 @@ class UnapprovedListTests(ThreadsListTestCase):
             is_unapproved=False,
         )
 
-        self.access_all_categories({
-            'can_approve_content': True,
-        }, {
-            'can_see_unapproved_content_lists': True,
-        })
-
         response = self.client.get('/unapproved/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, visible_thread)
         self.assertNotContainsThread(response, hidden_thread)
 
-        self.access_all_categories({
-            'can_approve_content': True
-        }, {
-            'can_see_unapproved_content_lists': True,
-        })
-
         response = self.client.get(self.category_a.get_absolute_url() + 'unapproved/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, visible_thread)
         self.assertNotContainsThread(response, hidden_thread)
 
         # test api
-        self.access_all_categories({
-            'can_approve_content': True
-        }, {
-            'can_see_unapproved_content_lists': True,
-        })
-
         response = self.client.get('%s?list=unapproved' % self.api_link)
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, visible_thread.get_absolute_url())
         self.assertNotContains(response, hidden_thread.get_absolute_url())
 
+    @patch_categories_acl(base_acl={'can_see_unapproved_content_lists': True})
     def test_list_shows_owned_threads_for_unapproving_user(self):
         """list shows owned threads with unapproved posts for user without perm"""
         visible_thread = testutils.post_thread(
@@ -1502,49 +1348,41 @@ class UnapprovedListTests(ThreadsListTestCase):
             is_unapproved=True,
         )
 
-        self.access_all_categories(base_acl={
-            'can_see_unapproved_content_lists': True,
-        })
         response = self.client.get('/unapproved/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, visible_thread)
         self.assertNotContainsThread(response, hidden_thread)
 
-        self.access_all_categories(base_acl={
-            'can_see_unapproved_content_lists': True,
-        })
         response = self.client.get(self.category_a.get_absolute_url() + 'unapproved/')
         self.assertEqual(response.status_code, 200)
         self.assertContainsThread(response, visible_thread)
         self.assertNotContainsThread(response, hidden_thread)
 
         # test api
-        self.access_all_categories(base_acl={
-            'can_see_unapproved_content_lists': True,
-        })
         response = self.client.get('%s?list=unapproved' % self.api_link)
         self.assertEqual(response.status_code, 200)
         self.assertContains(response, visible_thread.get_absolute_url())
         self.assertNotContains(response, hidden_thread.get_absolute_url())
 
 
+def patch_category_see_all_threads_acl():
+    def patch_acl(_, user_acl):
+        category = Category.objects.get(slug='first-category')
+        category_acl = user_acl['categories'][category.id].copy()
+        category_acl.update({'can_see_all_threads': 0})
+        user_acl['categories'][category.id] = category_acl
+
+    return patch_user_acl(patch_acl)
+
+
 class OwnerOnlyThreadsVisibilityTests(AuthenticatedUserTestCase):
     def setUp(self):
         super().setUp()
 
         self.category = Category.objects.get(slug='first-category')
 
-    def override_acl(self, user):
-        category_acl = user.acl_cache['categories'][self.category.pk].copy()
-        category_acl.update({'can_see_all_threads': 0})
-        user.acl_cache['categories'][self.category.pk] = category_acl
-
-        override_acl(user, user.acl_cache)
-
     def test_owned_threads_visibility(self):
         """only user-posted threads are visible in category"""
-        self.override_acl(self.user)
-
         visible_thread = testutils.post_thread(
             poster=self.user,
             category=self.category,
@@ -1556,18 +1394,16 @@ class OwnerOnlyThreadsVisibilityTests(AuthenticatedUserTestCase):
             is_unapproved=True,
         )
 
-        response = self.client.get(self.category.get_absolute_url())
-
-        self.assertEqual(response.status_code, 200)
-        self.assertContains(response, visible_thread.get_absolute_url())
-        self.assertNotContains(response, hidden_thread.get_absolute_url())
+        with patch_category_see_all_threads_acl():
+            response = self.client.get(self.category.get_absolute_url())
+            self.assertEqual(response.status_code, 200)
+            self.assertContains(response, visible_thread.get_absolute_url())
+            self.assertNotContains(response, hidden_thread.get_absolute_url())
 
     def test_owned_threads_visibility_anonymous(self):
         """anons can't see any threads in limited visibility category"""
         self.logout_user()
 
-        self.override_acl(AnonymousUser())
-
         user_thread = testutils.post_thread(
             poster=self.user,
             category=self.category,
@@ -1579,8 +1415,8 @@ class OwnerOnlyThreadsVisibilityTests(AuthenticatedUserTestCase):
             is_unapproved=True,
         )
 
-        response = self.client.get(self.category.get_absolute_url())
-
-        self.assertEqual(response.status_code, 200)
-        self.assertNotContains(response, user_thread.get_absolute_url())
-        self.assertNotContains(response, guest_thread.get_absolute_url())
+        with patch_category_see_all_threads_acl():
+            response = self.client.get(self.category.get_absolute_url())
+            self.assertEqual(response.status_code, 200)
+            self.assertNotContains(response, user_thread.get_absolute_url())
+            self.assertNotContains(response, guest_thread.get_absolute_url())

+ 3 - 3
misago/threads/viewmodels/category.py

@@ -66,8 +66,8 @@ class ThreadsCategory(ThreadsRootCategory):
             if category.pk == int(kwargs['pk']):
                 if not category.special_role:
                     # check permissions for non-special categories
-                    allow_see_category(request.user, category)
-                    allow_browse_category(request.user, category)
+                    allow_see_category(request.user_acl, category)
+                    allow_browse_category(request.user_acl, category)
 
                 if 'slug' in kwargs:
                     validate_slug(category, kwargs['slug'])
@@ -81,7 +81,7 @@ class PrivateThreadsCategory(ViewModel):
         return [Category.objects.private_threads()]
 
     def get_category(self, request, categories, **kwargs):
-        allow_use_private_threads(request.user)
+        allow_use_private_threads(request.user_acl)
 
         return categories[0]
 

+ 10 - 8
misago/threads/viewmodels/threads.py

@@ -77,7 +77,7 @@ class ViewModel(object):
                 thread.is_read = False
                 thread.is_new = True
         else:
-            threadstracker.make_read_aware(request.user, threads)
+            threadstracker.make_read_aware(request.user, request.user_acl, threads)
 
         self.filter_threads(request, threads)
 
@@ -188,29 +188,31 @@ def get_threads_queryset(request, categories, list_type):
     if list_type == 'all':
         return queryset
     else:
-        return filter_threads_queryset(request.user, categories, list_type, queryset)
+        return filter_threads_queryset(request, categories, list_type, queryset)
 
 
-def filter_threads_queryset(user, categories, list_type, queryset):
+def filter_threads_queryset(request, categories, list_type, queryset):
     if list_type == 'my':
-        return queryset.filter(starter=user)
+        return queryset.filter(starter=request.user)
     elif list_type == 'subscribed':
-        subscribed_threads = user.subscription_set.values('thread_id')
+        subscribed_threads = request.user.subscription_set.values('thread_id')
         return queryset.filter(id__in=subscribed_threads)
     elif list_type == 'unapproved':
         return queryset.filter(has_unapproved_posts=True)
     elif list_type in ('new', 'unread'):
-        return filter_read_threads_queryset(user, categories, list_type, queryset)
+        return filter_read_threads_queryset(request, categories, list_type, queryset)
     else:
         return queryset
 
 
-def filter_read_threads_queryset(user, categories, list_type, queryset):
+def filter_read_threads_queryset(request, categories, list_type, queryset):
     # grab cutoffs for categories
+    user = request.user
+
     cutoff_date = get_cutoff_date(user)
 
     visible_posts = Post.objects.filter(posted_on__gt=cutoff_date)
-    visible_posts = exclude_invisible_posts(user, categories, visible_posts)
+    visible_posts = exclude_invisible_posts(request.user_acl, categories, visible_posts)
 
     queryset = queryset.filter(id__in=visible_posts.distinct().values('thread'))