Просмотр исходного кода

#893: merge multiple threads data validator

Rafał Pitoń 7 лет назад
Родитель
Сommit
1240fd5fb9

+ 34 - 76
misago/threads/api/threadendpoints/merge.py

@@ -1,30 +1,17 @@
 from rest_framework.response import Response
 
 from django.core.exceptions import PermissionDenied
-from django.http import Http404
 from django.utils.six import text_type
 from django.utils.translation import ugettext as _
-from django.utils.translation import ungettext
 
 from misago.acl import add_acl
-from misago.categories import THREADS_ROOT_NAME
-from misago.core.utils import clean_ids_list
 from misago.threads.events import record_event
 from misago.threads.models import Thread
 from misago.threads.moderation import threads as moderation
-from misago.threads.permissions import allow_merge_thread, can_reply_thread, can_see_thread
+from misago.threads.permissions import allow_merge_thread
 from misago.threads.pollmergehandler import PollMergeHandler
-from misago.threads.serializers import MergeThreadSerializer, NewThreadSerializer, ThreadsListSerializer
-from misago.threads.threadtypes import trees_map
-from misago.threads.utils import get_thread_id_from_url
-
-
-MERGE_LIMIT = 20  # no more than 20 threads can be merged in single action
-
-
-class MergeError(Exception):
-    def __init__(self, msg):
-        self.msg = msg
+from misago.threads.serializers import (
+    MergeThreadSerializer, MergeThreadsSerializer, ThreadsListSerializer)
 
 
 def thread_merge_endpoint(request, thread, viewmodel):
@@ -86,12 +73,26 @@ def thread_merge_endpoint(request, thread, viewmodel):
 
 
 def threads_merge_endpoint(request):
-    try:
-        threads = clean_threads_for_merge(request)
-    except MergeError as e:
-        return Response({'detail': e.msg}, status=403)
+    serializer = MergeThreadsSerializer(
+        data=request.data,
+        context={
+            'user': request.user
+        },
+    )
+
+    if not serializer.is_valid():
+        if 'threads' in serializer.errors:
+            errors = {'detail': serializer.errors['threads'][0]}
+            return Response(errors, status=403)
+        elif 'non_field_errors' in serializer.errors:
+            errors = {'detail': serializer.errors['non_field_errors'][0]}
+            return Response(errors, status=403)
+        else:
+            return Response(serializer.errors, status=400)
 
+    threads = serializer.validated_data['threads']
     invalid_threads = []
+
     for thread in threads:
         try:
             allow_merge_thread(request.user, thread)
@@ -105,66 +106,23 @@ def threads_merge_endpoint(request):
     if invalid_threads:
         return Response(invalid_threads, status=403)
 
-    serializer = NewThreadSerializer(
-        data=request.data,
-        context={'user': request.user},
-    )
-
-    if serializer.is_valid():
-        polls_handler = PollMergeHandler(threads)
-        if len(polls_handler.polls) == 1:
-            poll = polls_handler.polls[0]
-        elif polls_handler.is_merge_conflict():
-            if 'poll' in request.data:
-                polls_handler.set_resolution(request.data.get('poll'))
-                if polls_handler.is_valid():
-                    poll = polls_handler.get_resolution()
-                else:
-                    return Response({'detail': _("Invalid choice.")}, status=400)
+    polls_handler = PollMergeHandler(threads)
+    if len(polls_handler.polls) == 1:
+        poll = polls_handler.polls[0]
+    elif polls_handler.is_merge_conflict():
+        if 'poll' in request.data:
+            polls_handler.set_resolution(request.data.get('poll'))
+            if polls_handler.is_valid():
+                poll = polls_handler.get_resolution()
             else:
-                return Response({'polls': polls_handler.get_available_resolutions()}, status=400)
+                return Response({'detail': _("Invalid choice.")}, status=400)
         else:
-            poll = None
-
-        new_thread = merge_threads(request, serializer.validated_data, threads, poll)
-        return Response(ThreadsListSerializer(new_thread).data)
+            return Response({'polls': polls_handler.get_available_resolutions()}, status=400)
     else:
-        return Response(serializer.errors, status=400)
-
-
-def clean_threads_for_merge(request):
-    threads_ids = clean_ids_list(
-        request.data.get('threads', []),
-        _("One or more thread ids received were invalid."),
-    )
-
-    if len(threads_ids) < 2:
-        raise MergeError(_("You have to select at least two threads to merge."))
-    elif len(threads_ids) > MERGE_LIMIT:
-        message = ungettext(
-            "No more than %(limit)s thread can be merged at single time.",
-            "No more than %(limit)s threads can be merged at single time.",
-            MERGE_LIMIT,
-        )
-        raise MergeError(message % {'limit': MERGE_LIMIT})
-
-    threads_tree_id = trees_map.get_tree_id_for_root(THREADS_ROOT_NAME)
-
-    threads_queryset = Thread.objects.filter(
-        id__in=threads_ids,
-        category__tree_id=threads_tree_id,
-    ).select_related('category').order_by('-id')
-
-    threads = []
-    for thread in threads_queryset:
-        add_acl(request.user, thread)
-        if can_see_thread(request.user, thread):
-            threads.append(thread)
-
-    if len(threads) != len(threads_ids):
-        raise MergeError(_("One or more threads to merge could not be found."))
+        poll = None
 
-    return threads
+    new_thread = merge_threads(request, serializer.validated_data, threads, poll)
+    return Response(ThreadsListSerializer(new_thread).data)
 
 
 def merge_threads(request, validated_data, threads, poll):

+ 53 - 0
misago/threads/serializers/moderation.py

@@ -5,6 +5,7 @@ from django.http import Http404
 from django.utils.translation import ugettext as _, ugettext_lazy, ungettext
 
 from misago.acl import add_acl
+from misago.categories import THREADS_ROOT_NAME
 from misago.conf import settings
 from misago.threads.models import Thread
 from misago.threads.permissions import (
@@ -12,17 +13,20 @@ from misago.threads.permissions import (
     allow_move_post, allow_split_post, can_reply_thread, can_see_thread,
     can_start_thread, exclude_invisible_posts)
 from misago.threads.pollmergehandler import PollMergeHandler
+from misago.threads.threadtypes import trees_map
 from misago.threads.utils import get_thread_id_from_url
 from misago.threads.validators import validate_category, validate_title
 
 
 POSTS_LIMIT = settings.MISAGO_POSTS_PER_PAGE + settings.MISAGO_POSTS_TAIL
+THREADS_LIMIT = 20
 
 
 __all__ = [
     'DeletePostsSerializer',
     'MergePostsSerializer',
     'MergeThreadSerializer',
+    'MergeThreadsSerializer',
     'MovePostsSerializer',
     'NewThreadSerializer',
     'SplitPostsSerializer',
@@ -404,3 +408,52 @@ class MergeThreadSerializer(serializers.Serializer):
         self.polls_handler = polls_handler
 
         return data
+
+
+class MergeThreadsSerializer(NewThreadSerializer):
+    error_empty_or_required = ugettext_lazy("You have to select at least two threads to merge.")
+
+    threads = serializers.ListField(
+        allow_empty=False,
+        min_length=2,
+        child=serializers.IntegerField(
+            error_messages={
+                'invalid': ugettext_lazy("One or more thread ids received were invalid."),
+            },
+        ),
+        error_messages={
+            'empty': error_empty_or_required,
+            'null': error_empty_or_required,
+            'required': error_empty_or_required,
+            'min_length': error_empty_or_required,
+        },
+    )
+
+    def validate_threads(self, data):
+        if len(data) > THREADS_LIMIT:
+            message = ungettext(
+                "No more than %(limit)s thread can be merged at single time.",
+                "No more than %(limit)s threads can be merged at single time.",
+                POSTS_LIMIT,
+            )
+            raise ValidationError(message % {'limit': THREADS_LIMIT})
+
+        threads_tree_id = trees_map.get_tree_id_for_root(THREADS_ROOT_NAME)
+
+        threads_queryset = Thread.objects.filter(
+            id__in=data,
+            category__tree_id=threads_tree_id,
+        ).select_related('category').order_by('-id')
+
+        user = self.context['user']
+
+        threads = []
+        for thread in threads_queryset:
+            add_acl(user, thread)
+            if can_see_thread(user, thread):
+                threads.append(thread)
+
+        if len(threads) != len(data):
+            raise ValidationError(_("One or more threads to merge could not be found."))
+
+        return threads

+ 39 - 11
misago/threads/tests/test_threads_merge_api.py

@@ -3,9 +3,10 @@ import json
 from django.urls import reverse
 
 from misago.acl import add_acl
+from misago.acl.testutils import override_acl
 from misago.categories.models import Category
 from misago.threads import testutils
-from misago.threads.api.threadendpoints.merge import MERGE_LIMIT
+from misago.threads.serializers.moderation import THREADS_LIMIT
 from misago.threads.models import Poll, PollVote, Post, Thread
 from misago.threads.serializers import ThreadsListSerializer
 
@@ -27,6 +28,32 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
         )
         self.category_b = Category.objects.get(slug='category-b')
 
+    def override_other_category(self):
+        categories =  self.user.acl_cache['categories']
+
+        visible_categories = self.user.acl_cache['visible_categories']
+        browseable_categories = self.user.acl_cache['browseable_categories']
+
+        visible_categories.append(self.category_b.pk)
+        browseable_categories.append(self.category_b.pk)
+
+        override_acl(
+            self.user, {
+                'visible_categories': visible_categories,
+                'browseable_categories': browseable_categories,
+                'categories': {
+                    self.category.pk: categories[self.category.pk],
+                    self.category_b.pk: {
+                        'can_see': 1,
+                        'can_browse': 1,
+                        'can_see_all_threads': 1,
+                        'can_see_own_threads': 0,
+                        'can_start_threads': 2,
+                    },
+                },
+            }
+        )
+
     def test_merge_no_threads(self):
         """api validates if we are trying to merge no threads"""
         response = self.client.post(self.api_link, content_type="application/json")
@@ -66,14 +93,7 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
             }),
             content_type="application/json",
         )
-        self.assertEqual(response.status_code, 403)
-
-        response_json = response.json()
-        self.assertEqual(
-            response_json, {
-                'detail': "One or more thread ids received were invalid.",
-            }
-        )
+        self.assertContains(response, "Expected a list of items", status_code=403)
 
         response = self.client.post(
             self.api_link,
@@ -154,6 +174,8 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
         response = self.client.post(
             self.api_link,
             json.dumps({
+                'category': self.category.pk,
+                'title': 'Lorem ipsum dolor',
                 'threads': [self.thread.id, thread.id],
             }),
             content_type="application/json",
@@ -182,6 +204,7 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
             'can_merge_threads': 1,
             'can_close_threads': 0,
         })
+        self.override_other_category()
 
         other_thread = testutils.post_thread(self.category)
 
@@ -191,6 +214,8 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
         response = self.client.post(
             self.api_link,
             json.dumps({
+                'category': self.category_b.pk,
+                'title': 'Lorem ipsum dolor',
                 'threads': [self.thread.id, other_thread.id],
             }),
             content_type="application/json",
@@ -207,6 +232,7 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
             'can_merge_threads': 1,
             'can_close_threads': 0,
         })
+        self.override_other_category()
 
         other_thread = testutils.post_thread(self.category)
 
@@ -216,6 +242,8 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
         response = self.client.post(
             self.api_link,
             json.dumps({
+                'category': self.category_b.pk,
+                'title': 'Lorem ipsum dolor',
                 'threads': [self.thread.id, other_thread.id],
             }),
             content_type="application/json",
@@ -229,7 +257,7 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
     def test_merge_too_many_threads(self):
         """api rejects too many threads to merge"""
         threads = []
-        for _ in range(MERGE_LIMIT + 1):
+        for _ in range(THREADS_LIMIT + 1):
             threads.append(testutils.post_thread(category=self.category).pk)
 
         self.override_acl({
@@ -251,7 +279,7 @@ class ThreadsMergeApiTests(ThreadsApiTestCase):
         response_json = response.json()
         self.assertEqual(
             response_json, {
-                'detail': "No more than %s threads can be merged at single time." % MERGE_LIMIT,
+                'detail': "No more than %s threads can be merged at single time." % THREADS_LIMIT,
             }
         )