Просмотр исходного кода

Track private threads reads correctly.

Ralfp 12 лет назад
Родитель
Сommit
7b7dd1f659

+ 3 - 0
misago/apps/privatethreads/thread.py

@@ -50,6 +50,9 @@ class ThreadView(ThreadBaseView, ThreadModeration, PostsModeration, TypeMixin):
         context['invite_form'] = FormFields(InviteMemberForm(request=self.request))
         return context
 
+    def tracker_queryset(self):
+        return self.forum.thread_set.filter(participants__id=self.request.user.pk)
+
     def tracker_update(self, last_post):
         super(ThreadView, self).tracker_update(last_post)
         unread = self.tracker.unread_count(self.forum.thread_set.filter(participants__id=self.request.user.pk))

+ 4 - 1
misago/apps/threadtype/thread/views.py

@@ -77,7 +77,10 @@ class ThreadBaseView(ViewBase):
 
     def tracker_update(self, last_post):
         self.tracker.set_read(self.thread, last_post)
-        self.tracker.sync()
+        try:
+            self.tracker.sync(self.tracker_queryset())
+        except AttributeError:
+            self.tracker.sync()
 
     def thread_actions(self):
         pass

+ 9 - 4
misago/readstrackers.py

@@ -67,20 +67,22 @@ class ThreadsTracker(object):
             except KeyError:
                 self.need_create = thread
 
-    def unread_count(self, queryset=None):
+    def unread_count(self, queryset):
         try:
             return self.unread_threads
         except AttributeError:
             self.unread_threads = 0
             if not queryset:
-                queryset = self.request.acl.threads.filter_threads(self.request, self.forum, self.forum.thread_set)
+                queryset = self.default_queryset()
             for thread in queryset.filter(last__gte=self.record.cleared):
                 if not self.is_read(thread):
                     self.unread_threads += 1
             return self.unread_threads
 
-    def sync(self):
+    def sync(self, queryset=None):
         now = timezone.now()
+        if not queryset:
+            queryset = self.default_queryset()
 
         if self.need_create:
             new_record = ThreadRead(
@@ -97,10 +99,13 @@ class ThreadsTracker(object):
             self.need_update.save(force_update=True)
 
         if self.need_create or self.need_update:
-            if not self.unread_count():
+            if not self.unread_count(queryset):
                 self.record.cleared = now
             self.record.updated = now
             if self.record.pk:
                 self.record.save(force_update=True)
             else:
                 self.record.save(force_insert=True)
+
+    def default_queryset(self):
+        return self.request.acl.threads.filter_threads(self.request, self.forum, self.forum.thread_set)