Browse Source

api for merge conflict resolution for threadview merge

Rafał Pitoń 8 years ago
parent
commit
df6fda792d

+ 48 - 2
misago/threads/api/threadendpoints/merge.py

@@ -14,6 +14,7 @@ from ...permissions import can_reply_thread, can_see_thread
 from ...serializers import NewThreadSerializer, ThreadsListSerializer
 from ...serializers import NewThreadSerializer, ThreadsListSerializer
 from ...threadtypes import trees_map
 from ...threadtypes import trees_map
 from ...utils import add_categories_to_threads, get_thread_id_from_url
 from ...utils import add_categories_to_threads, get_thread_id_from_url
+from .pollmergehandler import PollMergeHandler
 
 
 
 
 MERGE_LIMIT = 20 # no more than 20 threads can be merged in single action
 MERGE_LIMIT = 20 # no more than 20 threads can be merged in single action
@@ -49,6 +50,29 @@ def thread_merge_endpoint(request, thread, viewmodel):
             'detail': _("The thread you have entered link to doesn't exist or you don't have permission to see it.")
             'detail': _("The thread you have entered link to doesn't exist or you don't have permission to see it.")
         }, status=400)
         }, status=400)
 
 
+    polls_handler = PollMergeHandler([thread, other_thread])
+    if len(polls_handler.polls) == 1:
+        poll = polls_handler.polls[0]
+        poll.move(other_thread)
+    elif polls_handler.is_merge_conflict():
+        if 'poll' in request.data:
+            polls_handler.set_resolution(request.data.get('poll'))
+            if polls_handler.is_valid():
+                poll = polls_handler.get_resolution()
+                if poll and poll.thread_id != other_thread.id:
+                    other_thread.poll.delete()
+                    poll.move(other_thread)
+                elif not poll:
+                    other_thread.poll.delete()
+            else:
+                return Response({
+                    'detail': _("Invalid choice.")
+                }, status=400)
+        else:
+            return Response({
+                'polls': polls_handler.get_available_resolutions()
+            }, status=400)
+
     moderation.merge_thread(request, other_thread, thread)
     moderation.merge_thread(request, other_thread, thread)
 
 
     other_thread.synchronize()
     other_thread.synchronize()
@@ -90,7 +114,26 @@ def threads_merge_endpoint(request):
 
 
     serializer = NewThreadSerializer(context=request.user, data=request.data)
     serializer = NewThreadSerializer(context=request.user, data=request.data)
     if serializer.is_valid():
     if serializer.is_valid():
-        new_thread = merge_threads(request, serializer.validated_data, threads)
+        polls_handler = PollMergeHandler(threads)
+        if len(polls_handler.polls) == 1:
+            poll = polls_handler.polls[0]
+        elif polls_handler.is_merge_conflict():
+            if 'poll' in request.data:
+                polls_handler.set_resolution(request.data.get('poll'))
+                if polls_handler.is_valid():
+                    poll = polls_handler.get_resolution()
+                else:
+                    return Response({
+                        'detail': _("Invalid choice.")
+                    }, status=400)
+            else:
+                return Response({
+                    'polls': polls_handler.get_available_resolutions()
+                }, status=400)
+        else:
+            poll = None
+
+        new_thread = merge_threads(request, serializer.validated_data, threads, poll)
         return Response(ThreadsListSerializer(new_thread).data)
         return Response(ThreadsListSerializer(new_thread).data)
     else:
     else:
         return Response(serializer.errors, status=400)
         return Response(serializer.errors, status=400)
@@ -130,7 +173,7 @@ def clean_threads_for_merge(request):
     return threads
     return threads
 
 
 
 
-def merge_threads(request, validated_data, threads):
+def merge_threads(request, validated_data, threads, poll):
     new_thread = Thread(
     new_thread = Thread(
         category=validated_data['category'],
         category=validated_data['category'],
         started_on=threads[0].started_on,
         started_on=threads[0].started_on,
@@ -140,6 +183,9 @@ def merge_threads(request, validated_data, threads):
     new_thread.set_title(validated_data['title'])
     new_thread.set_title(validated_data['title'])
     new_thread.save()
     new_thread.save()
 
 
+    if poll:
+        poll.move(new_thread)
+
     categories = []
     categories = []
     for thread in threads:
     for thread in threads:
         categories.append(thread.category)
         categories.append(thread.category)

+ 50 - 0
misago/threads/api/threadendpoints/pollmergehandler.py

@@ -0,0 +1,50 @@
+from django.utils.translation import gettext as _
+
+from ...models import Poll
+
+
+class PollMergeHandler(object):
+    def __init__(self, threads):
+        self._list = []
+        self._choices = {0: None}
+
+        self._is_valid = False
+        self._resolution = None
+
+        self.threads = threads
+
+        for thread in threads:
+            try:
+                self._list.append(thread.poll)
+                self._choices[thread.poll.pk] = thread.poll
+            except Poll.DoesNotExist:
+                pass
+
+    @property
+    def polls(self):
+        return self._list
+
+    def is_merge_conflict(self):
+        return len(self._list) > 1
+
+    def get_available_resolutions(self):
+        resolutions = [(0, _("Delete all polls"))]
+        for poll in self._list:
+            resolutions.append((poll.pk, poll.question))
+        return resolutions
+
+    def set_resolution(self, resolution):
+        try:
+            resolution_clean = int(resolution)
+        except (TypeError, ValueError):
+            return
+
+        if resolution_clean in self._choices:
+            self._resolution = self._choices[resolution_clean]
+            self._is_valid = True
+
+    def is_valid(self):
+        return self._is_valid
+
+    def get_resolution(self):
+        return self._resolution or None

+ 6 - 0
misago/threads/models/poll.py

@@ -31,6 +31,12 @@ class Poll(models.Model):
     votes = models.PositiveIntegerField(default=0)
     votes = models.PositiveIntegerField(default=0)
     is_public = models.BooleanField(default=False)
     is_public = models.BooleanField(default=False)
 
 
+    def move(self, thread):
+        if self.thread_id != thread.id:
+            self.thread = thread
+            self.pollvote_set.update(thread=thread)
+            self.save()
+
     @property
     @property
     def ends_on(self):
     def ends_on(self):
         if self.length:
         if self.length:

+ 223 - 1
misago/threads/tests/test_thread_merge_api.py

@@ -7,7 +7,7 @@ from misago.acl.testutils import override_acl
 from misago.categories.models import Category
 from misago.categories.models import Category
 
 
 from .. import testutils
 from .. import testutils
-from ..models import Thread
+from ..models import Poll, PollVote, Thread
 from .test_threads_api import ThreadsApiTestCase
 from .test_threads_api import ThreadsApiTestCase
 
 
 class ThreadMergeApiTests(ThreadsApiTestCase):
 class ThreadMergeApiTests(ThreadsApiTestCase):
@@ -183,3 +183,225 @@ class ThreadMergeApiTests(ThreadsApiTestCase):
         # first thread is gone
         # first thread is gone
         with self.assertRaises(Thread.DoesNotExist):
         with self.assertRaises(Thread.DoesNotExist):
             Thread.objects.get(pk=self.thread.pk)
             Thread.objects.get(pk=self.thread.pk)
+
+    def test_merge_threads_kept_poll(self):
+        """api merges two threads successfully, keeping poll from old thread"""
+        self.override_acl({
+            'can_merge_threads': 1
+        })
+
+        self.override_other_acl({
+            'can_merge_threads': 1
+        })
+
+        other_thread = testutils.post_thread(self.category_b)
+        poll = testutils.post_poll(other_thread, self.user)
+
+        response = self.client.post(self.api_link, {
+            'thread_url': other_thread.get_absolute_url()
+        })
+        self.assertContains(response, other_thread.get_absolute_url(), status_code=200)
+
+        # other thread has two posts now
+        self.assertEqual(other_thread.post_set.count(), 3)
+
+        # first thread is gone
+        with self.assertRaises(Thread.DoesNotExist):
+            Thread.objects.get(pk=self.thread.pk)
+
+        # poll and its votes were kept
+        self.assertEqual(Poll.objects.filter(pk=poll.pk, thread=other_thread).count(), 1)
+        self.assertEqual(PollVote.objects.filter(poll=poll, thread=other_thread).count(), 4)
+
+    def test_merge_threads_moved_poll(self):
+        """api merges two threads successfully, moving poll from other thread"""
+        self.override_acl({
+            'can_merge_threads': 1
+        })
+
+        self.override_other_acl({
+            'can_merge_threads': 1
+        })
+
+        other_thread = testutils.post_thread(self.category_b)
+        poll = testutils.post_poll(self.thread, self.user)
+
+        response = self.client.post(self.api_link, {
+            'thread_url': other_thread.get_absolute_url()
+        })
+        self.assertContains(response, other_thread.get_absolute_url(), status_code=200)
+
+        # other thread has two posts now
+        self.assertEqual(other_thread.post_set.count(), 3)
+
+        # first thread is gone
+        with self.assertRaises(Thread.DoesNotExist):
+            Thread.objects.get(pk=self.thread.pk)
+
+        # poll and its votes were moved
+        self.assertEqual(Poll.objects.filter(pk=poll.pk, thread=other_thread).count(), 1)
+        self.assertEqual(PollVote.objects.filter(poll=poll, thread=other_thread).count(), 4)
+
+    def test_threads_merge_conflict(self):
+        """api errors on merge conflict, returning list of available polls"""
+        self.override_acl({
+            'can_merge_threads': 1
+        })
+
+        self.override_other_acl({
+            'can_merge_threads': 1
+        })
+
+        other_thread = testutils.post_thread(self.category_b)
+        poll = testutils.post_poll(self.thread, self.user)
+        other_poll = testutils.post_poll(other_thread, self.user)
+
+        response = self.client.post(self.api_link, {
+            'thread_url': other_thread.get_absolute_url()
+        })
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.json(), {
+            'polls': [
+                [0, "Delete all polls"],
+                [poll.pk, poll.question],
+                [other_poll.pk, other_poll.question]
+            ]
+        })
+
+        # poll and its votes were untouched
+        self.assertEqual(Poll.objects.count(), 2)
+        self.assertEqual(PollVote.objects.count(), 8)
+
+    def test_threads_merge_conflict_invalid_resolution(self):
+        """api errors on invalid merge conflict resolution"""
+        self.override_acl({
+            'can_merge_threads': 1
+        })
+
+        self.override_other_acl({
+            'can_merge_threads': 1
+        })
+
+        other_thread = testutils.post_thread(self.category_b)
+        poll = testutils.post_poll(self.thread, self.user)
+        other_poll = testutils.post_poll(other_thread, self.user)
+
+        response = self.client.post(self.api_link, {
+            'thread_url': other_thread.get_absolute_url(),
+            'poll': 'jhdkajshdsak'
+        })
+        self.assertEqual(response.status_code, 400)
+        self.assertEqual(response.json(), {
+            'detail': "Invalid choice."
+        })
+
+        # poll and its votes were untouched
+        self.assertEqual(Poll.objects.count(), 2)
+        self.assertEqual(PollVote.objects.count(), 8)
+
+    def test_threads_merge_conflict_delete_all(self):
+        """api deletes all polls when delete all choice is selected"""
+        self.override_acl({
+            'can_merge_threads': 1
+        })
+
+        self.override_other_acl({
+            'can_merge_threads': 1
+        })
+
+        other_thread = testutils.post_thread(self.category_b)
+        poll = testutils.post_poll(self.thread, self.user)
+        other_poll = testutils.post_poll(other_thread, self.user)
+
+        response = self.client.post(self.api_link, {
+            'thread_url': other_thread.get_absolute_url(),
+            'poll': 0
+        })
+        self.assertContains(response, other_thread.get_absolute_url(), status_code=200)
+
+        # other thread has two posts now
+        self.assertEqual(other_thread.post_set.count(), 3)
+
+        # first thread is gone
+        with self.assertRaises(Thread.DoesNotExist):
+            Thread.objects.get(pk=self.thread.pk)
+
+        # polls and votes are gone
+        self.assertEqual(Poll.objects.filter(pk=poll.pk, thread=other_thread).count(), 0)
+        self.assertEqual(PollVote.objects.filter(poll=poll, thread=other_thread).count(), 0)
+
+    def test_threads_merge_conflict_keep_first_poll(self):
+        """api deletes other poll on merge"""
+        self.override_acl({
+            'can_merge_threads': 1
+        })
+
+        self.override_other_acl({
+            'can_merge_threads': 1
+        })
+
+        other_thread = testutils.post_thread(self.category_b)
+        poll = testutils.post_poll(self.thread, self.user)
+        other_poll = testutils.post_poll(other_thread, self.user)
+
+        response = self.client.post(self.api_link, {
+            'thread_url': other_thread.get_absolute_url(),
+            'poll': poll.pk
+        })
+        self.assertContains(response, other_thread.get_absolute_url(), status_code=200)
+
+        # other thread has two posts now
+        self.assertEqual(other_thread.post_set.count(), 3)
+
+        # first thread is gone
+        with self.assertRaises(Thread.DoesNotExist):
+            Thread.objects.get(pk=self.thread.pk)
+
+        # other poll and its votes are gone
+        self.assertEqual(Poll.objects.filter(thread=self.thread).count(), 0)
+        self.assertEqual(PollVote.objects.filter(thread=self.thread).count(), 0)
+
+        self.assertEqual(Poll.objects.filter(thread=other_thread).count(), 1)
+        self.assertEqual(PollVote.objects.filter(thread=other_thread).count(), 4)
+
+        Poll.objects.get(pk=poll.pk)
+        with self.assertRaises(Poll.DoesNotExist):
+            Poll.objects.get(pk=other_poll.pk)
+
+    def test_threads_merge_conflict_keep_other_poll(self):
+        """api deletes first poll on merge"""
+        self.override_acl({
+            'can_merge_threads': 1
+        })
+
+        self.override_other_acl({
+            'can_merge_threads': 1
+        })
+
+        other_thread = testutils.post_thread(self.category_b)
+        poll = testutils.post_poll(self.thread, self.user)
+        other_poll = testutils.post_poll(other_thread, self.user)
+
+        response = self.client.post(self.api_link, {
+            'thread_url': other_thread.get_absolute_url(),
+            'poll': other_poll.pk
+        })
+        self.assertContains(response, other_thread.get_absolute_url(), status_code=200)
+
+        # other thread has two posts now
+        self.assertEqual(other_thread.post_set.count(), 3)
+
+        # first thread is gone
+        with self.assertRaises(Thread.DoesNotExist):
+            Thread.objects.get(pk=self.thread.pk)
+
+        # other poll and its votes are gone
+        self.assertEqual(Poll.objects.filter(thread=self.thread).count(), 0)
+        self.assertEqual(PollVote.objects.filter(thread=self.thread).count(), 0)
+
+        self.assertEqual(Poll.objects.filter(thread=other_thread).count(), 1)
+        self.assertEqual(PollVote.objects.filter(thread=other_thread).count(), 4)
+
+        Poll.objects.get(pk=other_poll.pk)
+        with self.assertRaises(Poll.DoesNotExist):
+            Poll.objects.get(pk=poll.pk)

+ 4 - 1
misago/threads/testutils.py

@@ -135,7 +135,10 @@ def post_poll(thread, poster):
 
 
     # one user voted for Alpha choice
     # one user voted for Alpha choice
     User = get_user_model()
     User = get_user_model()
-    user = User.objects.create_user('bob', 'bob@test.com', 'Pass.123')
+    try:
+        user = User.objects.get(slug='bob')
+    except User.DoesNotExist:
+        user = User.objects.create_user('bob', 'bob@test.com', 'Pass.123')
 
 
     poll.pollvote_set.create(
     poll.pollvote_set.create(
         category=thread.category,
         category=thread.category,