Browse Source

#893: fix build on py3k, moved pollvotes validation to serializer

Rafał Pitoń 7 years ago
parent
commit
860bf4bfd5

+ 1 - 1
misago/markup/tests/test_api.py

@@ -36,7 +36,7 @@ class ParseMarkupApiTests(AuthenticatedUserTestCase):
         self.assertContains(response, "Invalid data. Expected a dictionary", status_code=400)
         self.assertContains(response, "Invalid data. Expected a dictionary", status_code=400)
 
 
         response = self.client.post(self.api_link, 'malformed', content_type="application/json")
         response = self.client.post(self.api_link, 'malformed', content_type="application/json")
-        self.assertContains(response, "JSON parse error - No JSON object could be decoded", status_code=400)
+        self.assertContains(response, "JSON parse error", status_code=400)
 
 
     def test_empty_post(self):
     def test_empty_post(self):
         """api handles empty post"""
         """api handles empty post"""

+ 21 - 41
misago/threads/api/pollvotecreateendpoint.py

@@ -2,14 +2,9 @@ from copy import deepcopy
 
 
 from rest_framework.response import Response
 from rest_framework.response import Response
 
 
-from django.core.exceptions import ValidationError
-from django.utils import six
-from django.utils.translation import ugettext as _
-from django.utils.translation import ungettext
-
 from misago.acl import add_acl
 from misago.acl import add_acl
 from misago.threads.permissions import allow_vote_poll
 from misago.threads.permissions import allow_vote_poll
-from misago.threads.serializers import PollSerializer
+from misago.threads.serializers import PollSerializer, NewVoteSerializer
 
 
 
 
 def poll_vote_create(request, thread, poll):
 def poll_vote_create(request, thread, poll):
@@ -17,13 +12,26 @@ def poll_vote_create(request, thread, poll):
 
 
     allow_vote_poll(request.user, poll)
     allow_vote_poll(request.user, poll)
 
 
-    try:
-        clean_votes = validate_votes(poll, request.data)
-    except ValidationError as e:
-        return Response({'detail': six.text_type(e)}, status=400)
-
-    remove_user_votes(request.user, poll, clean_votes)
-    set_new_votes(request, poll, clean_votes)
+    serializer = NewVoteSerializer(
+        data={
+            'choices': request.data,
+        },
+        context={
+            'allowed_choices': poll.allowed_choices,
+            'choices': poll.choices,
+        },
+    )
+
+    if not serializer.is_valid():
+        return Response(
+            {
+                'detail': serializer.errors['choices'][0],
+            },
+            status=400,
+        )
+
+    remove_user_votes(request.user, poll, serializer.data['choices'])
+    set_new_votes(request, poll, serializer.data['choices'])
 
 
     add_acl(request.user, poll)
     add_acl(request.user, poll)
     serialized_poll = PollSerializer(poll).data
     serialized_poll = PollSerializer(poll).data
@@ -39,34 +47,6 @@ def presave_clean_choice(choice):
     return choice
     return choice
 
 
 
 
-def validate_votes(poll, votes):
-    try:
-        votes_len = len(votes)
-        if votes_len > poll.allowed_choices:
-            message = ungettext(
-                "This poll disallows voting for more than %(choices)s choice.",
-                "This poll disallows voting for more than %(choices)s choices.",
-                poll.allowed_choices,
-            )
-            raise ValidationError(message % {'choices': poll.allowed_choices})
-    except TypeError:
-        raise ValidationError(_("One or more of poll choices were invalid."))
-
-    valid_choices = [c['hash'] for c in poll.choices]
-    clean_votes = []
-
-    for vote in votes:
-        if vote in valid_choices:
-            clean_votes.append(vote)
-
-    if len(clean_votes) != len(votes):
-        raise ValidationError(_("One or more of poll choices were invalid."))
-    if not len(votes):
-        raise ValidationError(_("You have to make a choice."))
-
-    return clean_votes
-
-
 def remove_user_votes(user, poll, final_votes):
 def remove_user_votes(user, poll, final_votes):
     removed_votes = []
     removed_votes = []
     for choice in poll.choices:
     for choice in poll.choices:

+ 41 - 1
misago/threads/serializers/pollvote.py

@@ -1,9 +1,49 @@
 from rest_framework import serializers
 from rest_framework import serializers
 
 
 from django.urls import reverse
 from django.urls import reverse
+from django.utils.translation import ugettext as _
+from django.utils.translation import ungettext
 
 
 
 
-__all__ = ['PollVoteSerializer']
+__all__ = [
+    'NewVoteSerializer',
+    'PollVoteSerializer',
+]
+
+
+class NewVoteSerializer(serializers.Serializer):
+    choices = serializers.ListField(
+        child=serializers.CharField(),
+    )
+
+    def validate_choices(self, data):
+        if len(data) > self.context['allowed_choices']:
+            message = ungettext(
+                "This poll disallows voting for more than %(choices)s choice.",
+                "This poll disallows voting for more than %(choices)s choices.",
+                self.context['allowed_choices']
+            )
+            raise serializers.ValidationError(
+                message % {'choices': self.context['allowed_choices']},
+            )
+
+        valid_choices = [c['hash'] for c in self.context['choices']]
+        clean_choices = []
+
+        for choice in data:
+            if choice in valid_choices and choice not in clean_choices:
+                clean_choices.append(choice)
+
+        if len(clean_choices) != len(data):
+            raise serializers.ValidationError(
+                _("One or more of poll choices were invalid."),
+            )
+        if not len(clean_choices):
+            raise serializers.ValidationError(
+                _("You have to make a choice."),
+            )
+
+        return clean_choices
 
 
 
 
 class PollVoteSerializer(serializers.Serializer):
 class PollVoteSerializer(serializers.Serializer):

+ 14 - 8
misago/threads/tests/test_thread_pollvotes_api.py

@@ -188,14 +188,20 @@ class ThreadPostVotesTests(ThreadPollApiTestCase):
         self.delete_user_votes()
         self.delete_user_votes()
 
 
         response = self.post(self.api_link)
         response = self.post(self.api_link)
-        self.assertContains(response, "You have to make a choice.", status_code=400)
+        self.assertContains(response, "Expected a list of items", status_code=400)
 
 
-    def test_noninterable_vote(self):
-        """api validates if vote that user has made was iterable"""
+    def test_malformed_vote(self):
+        """api validates if vote that user has made was correctly structured"""
         self.delete_user_votes()
         self.delete_user_votes()
 
 
+        response = self.post(self.api_link, data={})
+        self.assertContains(response, "Expected a list of items", status_code=400)
+
+        response = self.post(self.api_link, data='hello')
+        self.assertContains(response, "Expected a list of items", status_code=400)
+
         response = self.post(self.api_link, data=123)
         response = self.post(self.api_link, data=123)
-        self.assertContains(response, "One or more of poll choices were invalid.", status_code=400)
+        self.assertContains(response, "Expected a list of items", status_code=400)
 
 
     def test_invalid_choices(self):
     def test_invalid_choices(self):
         """api validates if vote that user has made overlaps with allowed votes"""
         """api validates if vote that user has made overlaps with allowed votes"""
@@ -223,7 +229,7 @@ class ThreadPostVotesTests(ThreadPollApiTestCase):
         self.delete_user_votes()
         self.delete_user_votes()
 
 
         response = self.post(self.api_link)
         response = self.post(self.api_link)
-        self.assertContains(response, "You have to make a choice.", status_code=400)
+        self.assertContains(response, "Expected a list of items", status_code=400)
 
 
     def test_vote_in_closed_thread(self):
     def test_vote_in_closed_thread(self):
         """api validates is user has permission to vote poll in closed thread"""
         """api validates is user has permission to vote poll in closed thread"""
@@ -240,7 +246,7 @@ class ThreadPostVotesTests(ThreadPollApiTestCase):
         self.override_acl(category={'can_close_threads': 1})
         self.override_acl(category={'can_close_threads': 1})
 
 
         response = self.post(self.api_link)
         response = self.post(self.api_link)
-        self.assertContains(response, "You have to make a choice.", status_code=400)
+        self.assertContains(response, "Expected a list of items", status_code=400)
 
 
     def test_vote_in_closed_category(self):
     def test_vote_in_closed_category(self):
         """api validates is user has permission to vote poll in closed category"""
         """api validates is user has permission to vote poll in closed category"""
@@ -257,7 +263,7 @@ class ThreadPostVotesTests(ThreadPollApiTestCase):
         self.override_acl(category={'can_close_threads': 1})
         self.override_acl(category={'can_close_threads': 1})
 
 
         response = self.post(self.api_link)
         response = self.post(self.api_link)
-        self.assertContains(response, "You have to make a choice.", status_code=400)
+        self.assertContains(response, "Expected a list of items", status_code=400)
 
 
     def test_vote_in_finished_poll(self):
     def test_vote_in_finished_poll(self):
         """api valdiates if poll has finished before letting user to vote in it"""
         """api valdiates if poll has finished before letting user to vote in it"""
@@ -274,7 +280,7 @@ class ThreadPostVotesTests(ThreadPollApiTestCase):
         self.poll.save()
         self.poll.save()
 
 
         response = self.post(self.api_link)
         response = self.post(self.api_link)
-        self.assertContains(response, "You have to make a choice.", status_code=400)
+        self.assertContains(response, "Expected a list of items", status_code=400)
 
 
     def test_fresh_vote(self):
     def test_fresh_vote(self):
         """api handles first vote in poll"""
         """api handles first vote in poll"""