Browse Source

Make forum models comparable with eq/ne

Alec Nikolas Reiter 7 years ago
parent
commit
287f3c84d9
2 changed files with 17 additions and 2 deletions
  1. 6 1
      flaskbb/forum/models.py
  2. 11 1
      flaskbb/utils/database.py

+ 6 - 1
flaskbb/forum/models.py

@@ -16,7 +16,7 @@ from sqlalchemy.orm import aliased
 from flaskbb.extensions import db
 from flaskbb.utils.helpers import (slugify, get_categories_and_forums,
                                    get_forums, time_utcnow, topic_is_unread)
-from flaskbb.utils.database import CRUDMixin, UTCDateTime
+from flaskbb.utils.database import CRUDMixin, UTCDateTime, make_comparable
 from flaskbb.utils.settings import flaskbb_config
 
 
@@ -79,6 +79,7 @@ class ForumsRead(db.Model, CRUDMixin):
     cleared = db.Column(UTCDateTime(timezone=True), nullable=True)
 
 
+@make_comparable
 class Report(db.Model, CRUDMixin):
     __tablename__ = "reports"
 
@@ -130,6 +131,7 @@ class Report(db.Model, CRUDMixin):
         return self
 
 
+@make_comparable
 class Post(db.Model, CRUDMixin):
     __tablename__ = "posts"
 
@@ -274,6 +276,7 @@ class Post(db.Model, CRUDMixin):
         return self
 
 
+@make_comparable
 class Topic(db.Model, CRUDMixin):
     __tablename__ = "topics"
 
@@ -601,6 +604,7 @@ class Topic(db.Model, CRUDMixin):
         return self
 
 
+@make_comparable
 class Forum(db.Model, CRUDMixin):
     __tablename__ = "forums"
 
@@ -914,6 +918,7 @@ class Forum(db.Model, CRUDMixin):
         return topics
 
 
+@make_comparable
 class Category(db.Model, CRUDMixin):
     __tablename__ = "categories"
 

+ 11 - 1
flaskbb/utils/database.py

@@ -9,10 +9,20 @@
     :license: BSD, see LICENSE for more details.
 """
 import pytz
-
 from flaskbb.extensions import db
 
 
+def make_comparable(cls):
+    def __eq__(self, other):
+        return isinstance(other, cls) and self.id == other.id
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    cls.__eq__ = __eq__
+    return cls
+
+
 class CRUDMixin(object):
     def __repr__(self):
         return "<{}>".format(self.__class__.__name__)