Browse Source

#893: moved missing validation to serializers in posting endpoints

Rafał Pitoń 7 years ago
parent
commit
8c918c25b1

+ 3 - 0
misago/threads/api/postingendpoint/__init__.py

@@ -1,4 +1,7 @@
+from rest_framework import serializers
+
 from django.core.exceptions import PermissionDenied
+from django.http import QueryDict
 from django.utils import timezone
 from django.utils.module_loading import import_string
 

+ 11 - 7
misago/threads/api/postingendpoint/close.py

@@ -1,3 +1,5 @@
+from rest_framework import serializers
+
 from misago.threads import moderation
 
 from . import PostingEndpoint, PostingMiddleware
@@ -5,14 +7,16 @@ from . import PostingEndpoint, PostingMiddleware
 
 class CloseMiddleware(PostingMiddleware):
     def use_this_middleware(self):
-        return self.mode == PostingEndpoint.START and 'close' in self.request.data
+        return self.mode == PostingEndpoint.START
+
+    def get_serializer(self):
+        return CloseSerializer(data=self.request.data)
 
     def post_save(self, serializer):
         if self.thread.category.acl['can_close_threads']:
-            try:
-                close = bool(self.request.data['close'])
-            except (TypeError, ValueError):
-                close = False
-
-            if close:
+            if serializer.validated_data.get('close'):
                 moderation.close_thread(self.request, self.thread)
+
+
+class CloseSerializer(serializers.Serializer):
+    close = serializers.BooleanField(required=False, default=False)

+ 11 - 7
misago/threads/api/postingendpoint/hide.py

@@ -1,3 +1,5 @@
+from rest_framework import serializers
+
 from misago.threads import moderation
 
 from . import PostingEndpoint, PostingMiddleware
@@ -5,19 +7,21 @@ from . import PostingEndpoint, PostingMiddleware
 
 class HideMiddleware(PostingMiddleware):
     def use_this_middleware(self):
-        return self.mode == PostingEndpoint.START and 'hide' in self.request.data
+        return self.mode == PostingEndpoint.START
+
+    def get_serializer(self):
+        return HideSerializer(data=self.request.data)
 
     def post_save(self, serializer):
         if self.thread.category.acl['can_hide_threads']:
-            try:
-                hide = bool(self.request.data['hide'])
-            except (TypeError, ValueError):
-                hide = False
-
-            if hide:
+            if serializer.validated_data.get('hide'):
                 moderation.hide_thread(self.request, self.thread)
                 self.thread.update_all = True
                 self.thread.save(update_fields=['is_hidden'])
 
                 self.thread.category.synchronize()
                 self.thread.category.update_all = True
+
+
+class HideSerializer(serializers.Serializer):
+    hide = serializers.BooleanField(required=False, default=False)

+ 11 - 5
misago/threads/api/postingendpoint/pin.py

@@ -1,3 +1,5 @@
+from rest_framework import serializers
+
 from misago.threads import moderation
 from misago.threads.models import Thread
 
@@ -6,18 +8,22 @@ from . import PostingEndpoint, PostingMiddleware
 
 class PinMiddleware(PostingMiddleware):
     def use_this_middleware(self):
-        return self.mode == PostingEndpoint.START and 'pin' in self.request.data
+        return self.mode == PostingEndpoint.START
+
+    def get_serializer(self):
+        return PinSerializer(data=self.request.data)
 
     def post_save(self, serializer):
         allowed_pin = self.thread.category.acl['can_pin_threads']
         if allowed_pin > 0:
-            try:
-                pin = int(self.request.data['pin'])
-            except (TypeError, ValueError):
-                pin = 0
+            pin = serializer.validated_data['pin']
 
             if pin <= allowed_pin:
                 if pin == Thread.WEIGHT_GLOBAL:
                     moderation.pin_thread_globally(self.request, self.thread)
                 elif pin == Thread.WEIGHT_PINNED:
                     moderation.pin_thread_locally(self.request, self.thread)
+
+
+class PinSerializer(serializers.Serializer):
+    pin = serializers.IntegerField(required=False, default=0)

+ 11 - 2
misago/threads/api/postingendpoint/protect.py

@@ -1,14 +1,23 @@
+from rest_framework import serializers
+
 from . import PostingEndpoint, PostingMiddleware
 
 
 class ProtectMiddleware(PostingMiddleware):
     def use_this_middleware(self):
-        return self.mode == PostingEndpoint.EDIT and 'protect' in self.request.data
+        return self.mode == PostingEndpoint.EDIT
+
+    def get_serializer(self):
+        return ProtectSerializer(data=self.request.data)
 
     def post_save(self, serializer):
         if self.thread.category.acl['can_protect_posts']:
             try:
-                self.post.is_protected = bool(self.request.data['protect'])
+                self.post.is_protected = serializer.validated_data.get('protect', False)
                 self.post.update_fields.append('is_protected')
             except (TypeError, ValueError):
                 pass
+
+
+class ProtectSerializer(serializers.Serializer):
+    protect = serializers.BooleanField(required=False, default=False)

+ 13 - 4
misago/threads/tests/test_thread_editreply_api.py

@@ -164,10 +164,19 @@ class EditReplyTests(AuthenticatedUserTestCase):
 
         response = self.put(self.api_link, data={})
 
-        self.assertEqual(response.status_code, 400)
-        self.assertEqual(response.json(), {
-            'post': ["You have to enter a message."],
-        })
+        self.assertContains(response, "You have to enter a message.", status_code=400)
+
+    def test_invalid_data(self):
+        """api errors for invalid request data"""
+        self.override_acl()
+
+        response = self.client.put(
+            self.api_link,
+            'false',
+            content_type="application/json",
+        )
+
+        self.assertContains(response, "Invalid data.", status_code=400)
 
     def test_edit_event(self):
         """events can't be edited"""

+ 14 - 4
misago/threads/tests/test_thread_reply_api.py

@@ -110,10 +110,20 @@ class ReplyThreadTests(AuthenticatedUserTestCase):
         self.override_acl()
 
         response = self.client.post(self.api_link, data={})
-        self.assertEqual(response.status_code, 400)
-        self.assertEqual(response.json(), {
-            'post': ["You have to enter a message."],
-        })
+
+        self.assertContains(response, "You have to enter a message.", status_code=400)
+
+    def test_invalid_data(self):
+        """api errors for invalid request data"""
+        self.override_acl()
+
+        response = self.client.post(
+            self.api_link,
+            'false',
+            content_type="application/json",
+        )
+
+        self.assertContains(response, "Invalid data.", status_code=400)
 
     def test_post_is_validated(self):
         """post is validated"""

+ 12 - 0
misago/threads/tests/test_thread_start_api.py

@@ -116,6 +116,18 @@ class StartThreadTests(AuthenticatedUserTestCase):
             }
         )
 
+    def test_invalid_data(self):
+        """api errors for invalid request data"""
+        self.override_acl()
+
+        response = self.client.post(
+            self.api_link,
+            'false',
+            content_type="application/json",
+        )
+
+        self.assertContains(response, "Invalid data.", status_code=400)
+
     def test_title_is_validated(self):
         """title is validated"""
         self.override_acl()