Browse Source

Refactor cursor paginator into function

rafalp 6 years ago
parent
commit
83c3dc45b6

+ 27 - 39
misago/core/cursorpaginator.py

@@ -1,49 +1,37 @@
 from django.core.paginator import EmptyPage, InvalidPage
 
 
-class CursorPaginator:
-    def __init__(self, queryset, order_by, per_page):
-        self.queryset = queryset
-        self.per_page = int(per_page)
+def get_page(queryset, order_by, per_page, start=0):
+    if start < 0:
+        raise InvalidPage()
 
+    object_list = list(_slice_queryset(queryset, order_by, per_page, start))
+    if start and not object_list:
+        raise EmptyPage()
+
+    next_cursor = None
+    if len(object_list) > per_page:
+        next_slice_first_item = object_list.pop(-1)
+        next_cursor = getattr(next_slice_first_item, order_by)
+
+    return CursorPage(start, object_list, next_cursor)
+
+
+def _slice_queryset(queryset, order_by, per_page, start):
+    page_len = int(per_page) + 1
+    if start:
         if order_by.startswith("-"):
-            self.order_by = order_by[1:]
-            self.desc = True
+            filter_name = "%s__lte" % order_by[1:]
         else:
-            self.order_by = order_by
-            self.desc = False
-
-    def get_page(self, start=0):
-        if start < 0:
-            raise InvalidPage()
-
-        object_list = list(self._get_slice(start))
-        if start and not object_list:
-            raise EmptyPage()
-
-        next_cursor = None
-        if len(object_list) > self.per_page:
-            next_slice_first_item = object_list.pop(-1)
-            next_cursor = getattr(next_slice_first_item, self.order_by)
-
-        return Page(start, object_list, next_cursor)
-
-    def _get_slice(self, start):
-        page_len = self.per_page + 1
-        if start:
-            print(start)
-            if self.desc:
-                filter_name = "%s__lte" % self.order_by
-            else:
-                filter_name = "%s__gte" % self.order_by
-            print({filter_name: start})
-            return self.queryset.filter(**{filter_name: start})[:page_len]
-        return self.queryset[:page_len]
-
-
-class Page:
-    def __init__(self, start, object_list, next_):
+            filter_name = "%s__gte" % order_by
+        return queryset.filter(**{filter_name: start})[:page_len]
+    return queryset[:page_len]
+
+
+class CursorPage:
+    def __init__(self, start, object_list, next_=None):
         self.start = start or 0
+        self.first = self.start == 0
         self.object_list = object_list
         self.next = next_
 

+ 52 - 34
misago/core/tests/test_cursor_paginator.py

@@ -1,6 +1,6 @@
 import pytest
 
-from ..cursorpaginator import CursorPaginator, EmptyPage, InvalidPage
+from ..cursorpaginator import CursorPage, EmptyPage, InvalidPage, get_page
 
 
 @pytest.fixture
@@ -14,64 +14,82 @@ def mock_queryset(mocker, mock_objects):
 
 
 def test_paginator_returns_first_page(mock_objects):
-    paginator = CursorPaginator(mock_objects, "post", 6)
-    assert paginator.get_page()
+    assert get_page(mock_objects, "post", 6)
 
 
-def test_first_page_has_no_start(mock_objects):
-    paginator = CursorPaginator(mock_objects, "post", 6)
-    assert paginator.get_page().start is None
+def test_first_page_has_first_flag(mock_objects):
+    page = get_page(mock_objects, "post", 6)
+    assert page.first
+
+
+def test_first_page_start_is_zero(mock_objects):
+    page = get_page(mock_objects, "post", 6)
+    assert page.start == 0
 
 
 def test_first_page_has_correct_length(mock_objects):
-    paginator = CursorPaginator(mock_objects, "post", 6)
-    assert len(paginator.get_page().object_list) == 6
+    page = get_page(mock_objects, "post", 6)
+    assert len(page.object_list) == 6
 
 
 def test_first_page_has_correct_items(mock_objects):
-    paginator = CursorPaginator(mock_objects, "post", 6)
-    assert paginator.get_page().object_list == mock_objects[:6]
+    page = get_page(mock_objects, "post", 6)
+    assert page.object_list == mock_objects[:6]
 
 
 def test_page_has_next_attr_pointing_to_first_item_of_next_page(mock_objects):
-    paginator = CursorPaginator(mock_objects, "post", 6)
-    assert paginator.get_page().next == 7
-
-
-def test_page_can_be_tested_to_see_if_next_page_exists(mock_objects):
-    paginator = CursorPaginator(mock_objects, "post", 6)
-    assert paginator.get_page().has_next()
+    page = get_page(mock_objects, "post", 6)
+    assert page.next == 7
 
 
-def test_paginator_returns_empty_first_page_without_errors():
-    paginator = CursorPaginator([], "post", 6)
-    assert paginator.get_page().object_list == []
+def test_requesting_next_page_filters_queryset_using_filter_name(mock_queryset):
+    page = get_page(mock_queryset, "post", 6, 7)
+    mock_queryset.filter.assert_called_once_with(post__gte=7)
 
 
-def test_paginator_returns_page_starting_at_requested_address(mock_queryset):
-    paginator = CursorPaginator(mock_queryset, "post", 6)
-    assert paginator.get_page(7)
+def test_requesting_next_page_for_reversed_order_filters_queryset_with_descending(
+    mock_queryset
+):
+    page = get_page(mock_queryset, "-post", 6, 7)
+    mock_queryset.filter.assert_called_once_with(post__lte=7)
 
 
-def test_requesting_next_page_filters_queryset_using_filter_name(mock_queryset):
-    paginator = CursorPaginator(mock_queryset, "post", 6)
-    paginator.get_page(7)
-    mock_queryset.filter.assert_called_once_with(post__gte=7)
+def test_requesting_next_page_limits_queryset_to_specified_length(mock_queryset):
+    page = get_page(mock_queryset, "post", 6, 7)
+    assert len(page.object_list) == 6
 
 
-def test_requesting_next_page_limits_queryset_to_specified_length(mock_queryset):
-    paginator = CursorPaginator(mock_queryset, "post", 6)
-    assert len(paginator.get_page(7).object_list) == 6
+def test_paginator_returns_empty_first_page_without_errors():
+    get_page([], "post", 6)
 
 
 def test_paginator_raises_empty_page_error_if_nth_page_is_empty(mocker):
     queryset = mocker.Mock(filter=lambda **_: [])
-    paginator = CursorPaginator(queryset, "post", 6)
     with pytest.raises(EmptyPage):
-        paginator.get_page(20)
+        get_page(queryset, "post", 6, 20)
 
 
 def test_paginator_raises_invalid_page_error_if_starting_position_is_negative():
-    paginator = CursorPaginator(None, None, 0)
     with pytest.raises(InvalidPage):
-        paginator.get_page(-1)
+        get_page(None, None, 0, -1)
+
+
+def test_page_can_be_tested_to_see_if_next_page_exists(mock_objects):
+    page = get_page(mock_objects, "post", 6)
+    assert page.has_next()
+
+
+def test_last_page_has_no_next(mock_objects):
+    page = get_page([], "post", 6)
+    assert not page.next
+    assert not page.has_next()
+
+
+def test_cursor_page_is_first_if_start_is_zero():
+    page = CursorPage(0, [])
+    assert page.first
+
+
+def test_cursor_page_is_not_first_if_start_is_not_zero():
+    page = CursorPage(1, [])
+    assert not page.first

+ 14 - 10
misago/threads/viewmodels/threads.py

@@ -1,4 +1,5 @@
 from django.core.exceptions import PermissionDenied
+from django.core.paginator import EmptyPage, InvalidPage
 from django.db.models import Q
 from django.http import Http404
 from django.utils.translation import gettext as _
@@ -6,7 +7,7 @@ from django.utils.translation import gettext_lazy
 
 from ...acl.objectacl import add_acl_to_obj
 from ...conf import settings
-from ...core.cursorpaginator import CursorPaginator
+from ...core.cursorpaginator import get_queryset_slice
 from ...readtracker import threadstracker
 from ...readtracker.dates import get_cutoff_date
 from ..models import Post, Thread
@@ -59,22 +60,25 @@ class ViewModel:
             base_queryset, category_model, threads_categories
         )
 
-        paginator = CursorPaginator(
-            threads_queryset,
-            "-last_post_id",
-            settings.MISAGO_THREADS_PER_PAGE
-        )
-        list_page = paginator.get_page(start)
+        try:
+            list_page = get_page(
+                threads_queryset,
+                "-last_post_id",
+                settings.MISAGO_THREADS_PER_PAGE,
+                start,
+            )
+        except (EmptyPage, InvalidPage):
+            raise Http404()
 
-        if list_page.start:
-            threads = list(list_page.object_list)
-        else:
+        if list_page.first:
             pinned_threads = list(
                 self.get_pinned_threads(
                     base_queryset, category_model, threads_categories
                 )
             )
             threads = list(pinned_threads) + list(list_page.object_list)
+        else:
+            threads = list(list_page.object_list)
 
         add_categories_to_items(category_model, category.categories, threads)
         add_acl_to_obj(request.user_acl, threads)