Просмотр исходного кода

Merge pull request #304 from justanr/Set-Object-Rather-Than-ID

Update relationships rather than IDs (in most cases)
Peter Justin 7 лет назад
Родитель
Сommit
19edc98d2f

+ 61 - 51
flaskbb/forum/models.py

@@ -16,7 +16,7 @@ from sqlalchemy.orm import aliased
 from flaskbb.extensions import db
 from flaskbb.utils.helpers import (slugify, get_categories_and_forums,
                                    get_forums, time_utcnow, topic_is_unread)
-from flaskbb.utils.database import CRUDMixin, UTCDateTime
+from flaskbb.utils.database import CRUDMixin, UTCDateTime, make_comparable
 from flaskbb.utils.settings import flaskbb_config
 
 
@@ -53,14 +53,17 @@ class TopicsRead(db.Model, CRUDMixin):
 
     user_id = db.Column(db.Integer, db.ForeignKey("users.id"),
                         primary_key=True)
+    user = db.relationship('User', uselist=False, foreign_keys=[user_id])
     topic_id = db.Column(db.Integer,
                          db.ForeignKey("topics.id", use_alter=True,
                                        name="fk_tr_topic_id"),
                          primary_key=True)
+    topic = db.relationship('Topic', uselist=False, foreign_keys=[topic_id])
     forum_id = db.Column(db.Integer,
                          db.ForeignKey("forums.id", use_alter=True,
                                        name="fk_tr_forum_id"),
                          primary_key=True)
+    forum = db.relationship('Forum', uselist=False, foreign_keys=[forum_id])
     last_read = db.Column(UTCDateTime(timezone=True), default=time_utcnow,
                           nullable=False)
 
@@ -70,15 +73,18 @@ class ForumsRead(db.Model, CRUDMixin):
 
     user_id = db.Column(db.Integer, db.ForeignKey("users.id"),
                         primary_key=True)
+    user = db.relationship('User', uselist=False, foreign_keys=[user_id])
     forum_id = db.Column(db.Integer,
                          db.ForeignKey("forums.id", use_alter=True,
                                        name="fk_fr_forum_id"),
                          primary_key=True)
+    forum = db.relationship('Forum', uselist=False, foreign_keys=[forum_id])
     last_read = db.Column(UTCDateTime(timezone=True), default=time_utcnow,
                           nullable=False)
     cleared = db.Column(UTCDateTime(timezone=True), nullable=True)
 
 
+@make_comparable
 class Report(db.Model, CRUDMixin):
     __tablename__ = "reports"
 
@@ -121,15 +127,16 @@ class Report(db.Model, CRUDMixin):
             return self
 
         if post and user:
-            self.reporter_id = user.id
+            self.reporter = user
             self.reported = time_utcnow()
-            self.post_id = post.id
+            self.post = post
 
         db.session.add(self)
         db.session.commit()
         return self
 
 
+@make_comparable
 class Post(db.Model, CRUDMixin):
     __tablename__ = "posts"
 
@@ -166,6 +173,8 @@ class Post(db.Model, CRUDMixin):
             self.content = content
 
         if user:
+            # setting user here -- even with setting the user id explicitly
+            # breaks the bulk insert for some reason
             self.user_id = user.id
             self.username = user.username
 
@@ -197,24 +206,18 @@ class Post(db.Model, CRUDMixin):
         # Adding a new post
         if user and topic:
             created = time_utcnow()
-            self.user_id = user.id
+            self.user = user
             self.username = user.username
-            self.topic_id = topic.id
+            self.topic = topic
             self.date_created = created
 
             topic.last_updated = created
-
-            # This needs to be done before the last_post_id gets updated.
-            db.session.add(self)
-            db.session.commit()
-
-            # Now lets update the last post id
-            topic.last_post_id = self.id
+            topic.last_post = self
 
             # Update the last post info for the forum
-            topic.forum.last_post_id = self.id
+            topic.forum.last_post = self
+            topic.forum.last_post_user = self.user
             topic.forum.last_post_title = topic.title
-            topic.forum.last_post_user_id = user.id
             topic.forum.last_post_username = user.username
             topic.forum.last_post_created = created
 
@@ -231,15 +234,15 @@ class Post(db.Model, CRUDMixin):
     def delete(self):
         """Deletes a post and returns self."""
         # This will delete the whole topic
-        if self.topic.first_post_id == self.id:
+        if self.topic.first_post == self:
             self.topic.delete()
             return self
 
         # Delete the last post
-        if self.topic.last_post_id == self.id:
+        if self.topic.last_post == self:
 
             # update the last post in the forum
-            if self.topic.last_post_id == self.topic.forum.last_post_id:
+            if self.topic.last_post == self.topic.forum.last_post:
                 # We need the second last post in the forum here,
                 # because the last post will be deleted
                 second_last_post = Post.query.\
@@ -250,9 +253,9 @@ class Post(db.Model, CRUDMixin):
 
                 # now lets update the second last post to the last post
                 last_post = second_last_post[1]
-                self.topic.forum.last_post_id = last_post.id
+                self.topic.forum.last_post = last_post
                 self.topic.forum.last_post_title = last_post.topic.title
-                self.topic.forum.last_post_user_id = last_post.user_id
+                self.topic.forum.last_post_user = last_post.user
                 self.topic.forum.last_post_username = last_post.username
                 self.topic.forum.last_post_created = last_post.date_created
 
@@ -264,22 +267,21 @@ class Post(db.Model, CRUDMixin):
             # there is no second last post, now the last post is also the
             # first post
             else:
-                self.topic.last_post_id = self.topic.first_post_id
+                self.topic.last_post = self.topic.first_post
 
-            post = Post.query.get(self.topic.last_post_id)
-            self.topic.last_updated = post.date_created
+            self.topic.last_updated = self.topic.last_post.date_created
 
         # Update the post counts
         self.user.post_count -= 1
         self.topic.post_count -= 1
         self.topic.forum.post_count -= 1
-        db.session.commit()
 
         db.session.delete(self)
         db.session.commit()
         return self
 
 
+@make_comparable
 class Topic(db.Model, CRUDMixin):
     __tablename__ = "topics"
 
@@ -370,6 +372,9 @@ class Topic(db.Model, CRUDMixin):
             self.title = title
 
         if user:
+            # setting the user here, even with setting the id, breaks the bulk insert
+            # stuff as they use the session.bulk_save_objects which does not trigger
+            # relationships
             self.user_id = user.id
             self.username = user.username
 
@@ -458,9 +463,9 @@ class Topic(db.Model, CRUDMixin):
         # the TopicsRead model.
         elif not topicsread:
             topicsread = TopicsRead()
-            topicsread.user_id = user.id
-            topicsread.topic_id = self.id
-            topicsread.forum_id = self.forum_id
+            topicsread.user = user
+            topicsread.topic = self
+            topicsread.forum = self.forum
             topicsread.last_read = time_utcnow()
             topicsread.save()
             updated = True
@@ -489,13 +494,13 @@ class Topic(db.Model, CRUDMixin):
         """
 
         # if the target forum is the current forum, abort
-        if self.forum_id == new_forum.id:
+        if self.forum == new_forum:
             return False
 
         old_forum = self.forum
         self.forum.post_count -= self.post_count
         self.forum.topic_count -= 1
-        self.forum_id = new_forum.id
+        self.forum = new_forum
 
         new_forum.post_count += self.post_count
         new_forum.topic_count += 1
@@ -525,8 +530,8 @@ class Topic(db.Model, CRUDMixin):
             return self
 
         # Set the forum and user id
-        self.forum_id = forum.id
-        self.user_id = user.id
+        self.forum = forum
+        self.user = user
         self.username = user.username
 
         # Set the last_updated time. Needed for the readstracker
@@ -540,7 +545,7 @@ class Topic(db.Model, CRUDMixin):
         post.save(user, self)
 
         # Update the first and last post id
-        self.last_post_id = self.first_post_id = post.id
+        self.last_post = self.first_post = post
 
         # Update the topic count
         forum.topic_count += 1
@@ -559,26 +564,23 @@ class Topic(db.Model, CRUDMixin):
             order_by(Topic.last_post_id.desc()).limit(2).offset(0).all()
 
         # do we want to delete the topic with the last post in the forum?
-        if topic and topic[0].id == self.id:
+        if topic and topic[0] == self:
             try:
                 # Now the second last post will be the last post
-                self.forum.last_post_id = topic[1].last_post_id
+                self.forum.last_post = topic[1].last_post
                 self.forum.last_post_title = topic[1].title
-                self.forum.last_post_user_id = topic[1].user_id
+                self.forum.last_post_user = topic[1].user
                 self.forum.last_post_username = topic[1].username
                 self.forum.last_post_created = topic[1].last_updated
             # Catch an IndexError when you delete the last topic in the forum
             # There is no second last post
             except IndexError:
-                self.forum.last_post_id = None
+                self.forum.last_post = None
                 self.forum.last_post_title = None
-                self.forum.last_post_user_id = None
+                self.forum.last_post_user = None
                 self.forum.last_post_username = None
                 self.forum.last_post_created = None
 
-            # Commit the changes
-            db.session.commit()
-
         # These things needs to be stored in a variable before they are deleted
         forum = self.forum
 
@@ -586,13 +588,11 @@ class Topic(db.Model, CRUDMixin):
 
         # Delete the topic
         db.session.delete(self)
-        db.session.commit()
 
         # Update the post counts
         if users:
             for user in users:
                 user.post_count = Post.query.filter_by(user_id=user.id).count()
-                db.session.commit()
 
         forum.topic_count = Topic.query.\
             filter_by(forum_id=self.forum_id).\
@@ -607,6 +607,7 @@ class Topic(db.Model, CRUDMixin):
         return self
 
 
+@make_comparable
 class Forum(db.Model, CRUDMixin):
     __tablename__ = "forums"
 
@@ -629,11 +630,18 @@ class Forum(db.Model, CRUDMixin):
     last_post = db.relationship("Post", backref="last_post_forum",
                                 uselist=False, foreign_keys=[last_post_id])
 
+    last_post_user_id = db.Column(db.Integer, db.ForeignKey("users.id"),
+                                  nullable=True)
+
+    last_post_user = db.relationship(
+        "User",
+        uselist=False,
+        foreign_keys=[last_post_user_id]
+    )
+
     # Not nice, but needed to improve the performance; can be set to NULL
     # if the forum has no posts
     last_post_title = db.Column(db.String(255), nullable=True)
-    last_post_user_id = db.Column(db.Integer, db.ForeignKey("users.id"),
-                                  nullable=True)
     last_post_username = db.Column(db.String(255), nullable=True)
     last_post_created = db.Column(UTCDateTime(timezone=True),
                                   default=time_utcnow, nullable=True)
@@ -693,14 +701,15 @@ class Forum(db.Model, CRUDMixin):
             filter(Post.topic_id == Topic.id,
                    Topic.forum_id == self.id).\
             order_by(Post.date_created.desc()).\
-            first()
+            limit(1)\
+            .first()
 
         # Last post is none when there are no topics in the forum
         if last_post is not None:
 
             # a new last post was found in the forum
-            if not last_post.id == self.last_post_id:
-                self.last_post_id = last_post.id
+            if last_post != self.last_post:
+                self.last_post = last_post
                 self.last_post_title = last_post.topic.title
                 self.last_post_user_id = last_post.user_id
                 self.last_post_username = last_post.username
@@ -708,9 +717,9 @@ class Forum(db.Model, CRUDMixin):
 
         # No post found..
         else:
-            self.last_post_id = None
+            self.last_post = None
             self.last_post_title = None
-            self.last_post_user_id = None
+            self.last_post_user = None
             self.last_post_username = None
             self.last_post_created = None
 
@@ -753,7 +762,7 @@ class Forum(db.Model, CRUDMixin):
                               ForumsRead.user_id == user.id)).\
             filter(Topic.forum_id == self.id,
                    Topic.last_updated > read_cutoff,
-                   db.or_(TopicsRead.last_read == None,
+                   db.or_(TopicsRead.last_read == None,  # noqa: E711
                           TopicsRead.last_read < Topic.last_updated)).\
             count()
 
@@ -773,8 +782,8 @@ class Forum(db.Model, CRUDMixin):
 
             # No ForumRead Entry existing - creating one.
             forumsread = ForumsRead()
-            forumsread.user_id = user.id
-            forumsread.forum_id = self.id
+            forumsread.user = user
+            forumsread.forum = self
             forumsread.last_read = time_utcnow()
             forumsread.save()
             return True
@@ -913,6 +922,7 @@ class Forum(db.Model, CRUDMixin):
         return topics
 
 
+@make_comparable
 class Category(db.Model, CRUDMixin):
     __tablename__ = "categories"
 

+ 37 - 37
flaskbb/forum/views.py

@@ -21,7 +21,7 @@ from flaskbb.extensions import db, allows
 from flaskbb.utils.settings import flaskbb_config
 from flaskbb.utils.helpers import (get_online_users, time_diff, time_utcnow,
                                    format_quote, render_template,
-                                   do_topic_action)
+                                   do_topic_action, real)
 from flaskbb.utils.requirements import (CanAccessForum, CanAccessTopic,
                                         CanDeletePost, CanDeleteTopic,
                                         CanEditPost, CanPostReply,
@@ -38,7 +38,7 @@ forum = Blueprint("forum", __name__)
 
 @forum.route("/")
 def index():
-    categories = Category.get_all(user=current_user)
+    categories = Category.get_all(user=real(current_user))
 
     # Fetch a few stats about the forum
     user_count = User.query.count()
@@ -71,7 +71,7 @@ def index():
 @forum.route("/category/<int:category_id>-<slug>")
 def view_category(category_id, slug=None):
     category, forums = Category.\
-        get_forums(category_id=category_id, user=current_user)
+        get_forums(category_id=category_id, user=real(current_user))
 
     return render_template("forum/category.html", forums=forums,
                            category=category)
@@ -84,14 +84,14 @@ def view_forum(forum_id, slug=None):
     page = request.args.get('page', 1, type=int)
 
     forum_instance, forumsread = Forum.get_forum(
-        forum_id=forum_id, user=current_user
+        forum_id=forum_id, user=real(current_user)
     )
 
     if forum_instance.external:
         return redirect(forum_instance.external)
 
     topics = Forum.get_topics(
-        forum_id=forum_instance.id, user=current_user, page=page,
+        forum_id=forum_instance.id, user=real(current_user), page=page,
         per_page=flaskbb_config["TOPICS_PER_PAGE"]
     )
 
@@ -108,7 +108,7 @@ def view_topic(topic_id, slug=None):
     page = request.args.get('page', 1, type=int)
 
     # Fetch some information about the topic
-    topic = Topic.get_topic(topic_id=topic_id, user=current_user)
+    topic = Topic.get_topic(topic_id=topic_id, user=real(current_user))
 
     # Count the topic views
     topic.views += 1
@@ -130,16 +130,16 @@ def view_topic(topic_id, slug=None):
     forumsread = None
     if current_user.is_authenticated:
         forumsread = ForumsRead.query.\
-            filter_by(user_id=current_user.id,
+            filter_by(user_id=real(current_user).id,
                       forum_id=topic.forum.id).first()
 
-    topic.update_read(current_user, topic.forum, forumsread)
+    topic.update_read(real(current_user), topic.forum, forumsread)
 
     form = None
     if Permission(CanPostReply):
         form = QuickreplyForm()
         if form.validate_on_submit():
-            post = form.save(current_user, topic)
+            post = form.save(real(current_user), topic)
             return view_post(post.id)
 
     return render_template("forum/topic.html", topic=topic, posts=posts,
@@ -179,7 +179,7 @@ def new_topic(forum_id, slug=None):
                 form=form, preview=form.content.data
             )
         if "submit" in request.form and form.validate():
-            topic = form.save(current_user, forum_instance)
+            topic = form.save(real(current_user), forum_instance)
             # redirect to the new topic
             return redirect(url_for('forum.view_topic', topic_id=topic.id))
 
@@ -277,7 +277,7 @@ def manage_forum(forum_id, slug=None):
     page = request.args.get('page', 1, type=int)
 
     forum_instance, forumsread = Forum.get_forum(forum_id=forum_id,
-                                                 user=current_user)
+                                                 user=real(current_user))
 
     # remove the current forum from the select field (move).
     available_forums = Forum.query.order_by(Forum.position).all()
@@ -292,7 +292,7 @@ def manage_forum(forum_id, slug=None):
         return redirect(forum_instance.external)
 
     topics = Forum.get_topics(
-        forum_id=forum_instance.id, user=current_user, page=page,
+        forum_id=forum_instance.id, user=real(current_user), page=page,
         per_page=flaskbb_config["TOPICS_PER_PAGE"]
     )
 
@@ -312,34 +312,34 @@ def manage_forum(forum_id, slug=None):
 
         # locking/unlocking
         if "lock" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="locked", reverse=False)
 
             flash(_("%(count)s topics locked.", count=changed), "success")
             return redirect(mod_forum_url)
 
         elif "unlock" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="locked", reverse=True)
             flash(_("%(count)s topics unlocked.", count=changed), "success")
             return redirect(mod_forum_url)
 
         # highlighting/trivializing
         elif "highlight" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="important", reverse=False)
             flash(_("%(count)s topics highlighted.", count=changed), "success")
             return redirect(mod_forum_url)
 
         elif "trivialize" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="important", reverse=True)
             flash(_("%(count)s topics trivialized.", count=changed), "success")
             return redirect(mod_forum_url)
 
         # deleting
         elif "delete" in request.form:
-            changed = do_topic_action(topics=tmp_topics, user=current_user,
+            changed = do_topic_action(topics=tmp_topics, user=real(current_user),
                                       action="delete", reverse=False)
             flash(_("%(count)s topics deleted.", count=changed), "success")
             return redirect(mod_forum_url)
@@ -393,7 +393,7 @@ def new_post(topic_id, slug=None):
                 form=form, preview=form.content.data
             )
         else:
-            post = form.save(current_user, topic)
+            post = form.save(real(current_user), topic)
             return view_post(post.id)
 
     return render_template("forum/new_post.html", topic=topic, form=form)
@@ -420,7 +420,7 @@ def reply_post(topic_id, post_id):
                 form=form, preview=form.content.data
             )
         else:
-            post = form.save(current_user, topic)
+            post = form.save(real(current_user), topic)
             return view_post(post.id)
     else:
         form.content.data = format_quote(post.username, post.content)
@@ -452,7 +452,7 @@ def edit_post(post_id):
         else:
             form.populate_obj(post)
             post.date_modified = time_utcnow()
-            post.modified_by = current_user.username
+            post.modified_by = real(current_user).username
             post.save()
 
             if post.first_post:
@@ -500,7 +500,7 @@ def report_post(post_id):
 
     form = ReportForm()
     if form.validate_on_submit():
-        form.save(current_user, post)
+        form.save(real(current_user), post)
         flash(_("Thanks for reporting."), "success")
 
     return render_template("forum/report_post.html", form=form)
@@ -521,15 +521,15 @@ def markread(forum_id=None, slug=None):
     if forum_id:
         forum_instance = Forum.query.filter_by(id=forum_id).first_or_404()
         forumsread = ForumsRead.query.filter_by(
-            user_id=current_user.id, forum_id=forum_instance.id
+            user_id=real(current_user).id, forum_id=forum_instance.id
         ).first()
-        TopicsRead.query.filter_by(user_id=current_user.id,
+        TopicsRead.query.filter_by(user_id=real(current_user).id,
                                    forum_id=forum_instance.id).delete()
 
         if not forumsread:
             forumsread = ForumsRead()
-            forumsread.user_id = current_user.id
-            forumsread.forum_id = forum_instance.id
+            forumsread.user = real(current_user)
+            forumsread.forum = forum_instance
 
         forumsread.last_read = time_utcnow()
         forumsread.cleared = time_utcnow()
@@ -543,15 +543,15 @@ def markread(forum_id=None, slug=None):
         return redirect(forum_instance.url)
 
     # Mark all forums as read
-    ForumsRead.query.filter_by(user_id=current_user.id).delete()
-    TopicsRead.query.filter_by(user_id=current_user.id).delete()
+    ForumsRead.query.filter_by(user_id=real(current_user).id).delete()
+    TopicsRead.query.filter_by(user_id=real(current_user).id).delete()
 
     forums = Forum.query.all()
     forumsread_list = []
     for forum_instance in forums:
         forumsread = ForumsRead()
-        forumsread.user_id = current_user.id
-        forumsread.forum_id = forum_instance.id
+        forumsread.user = real(current_user)
+        forumsread.forum = forum_instance
         forumsread.last_read = time_utcnow()
         forumsread.cleared = time_utcnow()
         forumsread_list.append(forumsread)
@@ -611,10 +611,10 @@ def memberlist():
 @login_required
 def topictracker():
     page = request.args.get("page", 1, type=int)
-    topics = current_user.tracked_topics.\
+    topics = real(current_user).tracked_topics.\
         outerjoin(TopicsRead,
                   db.and_(TopicsRead.topic_id == Topic.id,
-                          TopicsRead.user_id == current_user.id)).\
+                          TopicsRead.user_id == real(current_user).id)).\
         add_entity(TopicsRead).\
         order_by(Topic.last_updated.desc()).\
         paginate(page, flaskbb_config['TOPICS_PER_PAGE'], True)
@@ -625,8 +625,8 @@ def topictracker():
         tmp_topics = Topic.query.filter(Topic.id.in_(topic_ids)).all()
 
         for topic in tmp_topics:
-            current_user.untrack_topic(topic)
-        current_user.save()
+            real(current_user).untrack_topic(topic)
+        real(current_user).save()
 
         flash(_("%(topic_count)s topics untracked.",
                 topic_count=len(tmp_topics)), "success")
@@ -640,8 +640,8 @@ def topictracker():
 @login_required
 def track_topic(topic_id, slug=None):
     topic = Topic.query.filter_by(id=topic_id).first_or_404()
-    current_user.track_topic(topic)
-    current_user.save()
+    real(current_user).track_topic(topic)
+    real(current_user).save()
     return redirect(topic.url)
 
 
@@ -650,8 +650,8 @@ def track_topic(topic_id, slug=None):
 @login_required
 def untrack_topic(topic_id, slug=None):
     topic = Topic.query.filter_by(id=topic_id).first_or_404()
-    current_user.untrack_topic(topic)
-    current_user.save()
+    real(current_user).untrack_topic(topic)
+    real(current_user).save()
     return redirect(topic.url)
 
 

+ 1 - 1
flaskbb/message/models.py

@@ -103,7 +103,7 @@ class Message(db.Model, CRUDMixin):
                              belongs to.
         """
         if conversation is not None:
-            self.conversation_id = conversation.id
+            self.conversation = conversation
             conversation.date_modified = time_utcnow()
             self.date_created = time_utcnow()
 

+ 1 - 1
flaskbb/templates/forum/topic.html

@@ -110,7 +110,7 @@
                             <!-- Edit Post -->
                             <a href="{{ url_for('forum.edit_post', post_id=post.id) }}" class="btn btn-icon icon-edit" data-toggle="tooltip" data-placement="top" title="Edit this post"></a>
                             {% endif %}
-                            {% if topic.first_post_id == post.id %}
+                            {% if topic.first_post == post %}
                                 {% if current_user|delete_topic(topic) %}
                                 <form class="inline-form" method="post" action="{{ url_for('forum.delete_topic', topic_id=topic.id, slug=topic.slug) }}">
                                     <input type="hidden" name="csrf_token" value="{{ csrf_token() }}" />

+ 1 - 1
flaskbb/templates/forum/topic_horizontal.html

@@ -111,7 +111,7 @@
                             <!-- Edit Post -->
                             <a href="{{ url_for('forum.edit_post', post_id=post.id) }}" class="btn btn-icon icon-edit" data-toggle="tooltip" data-placement="top" title="Edit this post"></a>
                             {% endif %}
-                            {% if topic.first_post_id == post.id %}
+                            {% if topic.first_post == post %}
                                 {% if current_user|delete_topic(topic) %}
                                 <form class="inline-form" method="post" action="{{ url_for('forum.delete_topic', topic_id=topic.id, slug=topic.slug) }}">
                                     <input type="hidden" name="csrf_token" value="{{ csrf_token() }}" />

+ 5 - 4
flaskbb/user/models.py

@@ -16,7 +16,7 @@ from flaskbb.extensions import db, cache
 from flaskbb.exceptions import AuthenticationError
 from flaskbb.utils.helpers import time_utcnow
 from flaskbb.utils.settings import flaskbb_config
-from flaskbb.utils.database import CRUDMixin, UTCDateTime
+from flaskbb.utils.database import CRUDMixin, UTCDateTime, make_comparable
 from flaskbb.forum.models import (Post, Topic, Forum, topictracker, TopicsRead,
                                   ForumsRead)
 from flaskbb.message.models import Conversation
@@ -31,6 +31,7 @@ groups_users = db.Table(
 )
 
 
+@make_comparable
 class Group(db.Model, CRUDMixin):
     __tablename__ = "groups"
 
@@ -413,7 +414,7 @@ class User(db.Model, UserMixin, CRUDMixin):
                 Group.banned == True
             ).first()
 
-            self.primary_group_id = banned_group.id
+            self.primary_group = banned_group
             self.save()
             self.invalidate_cache()
             return True
@@ -430,7 +431,7 @@ class User(db.Model, UserMixin, CRUDMixin):
                 Group.banned == False
             ).first()
 
-            self.primary_group_id = member_group.id
+            self.primary_group = member_group
             self.save()
             self.invalidate_cache()
             return True
@@ -452,7 +453,7 @@ class User(db.Model, UserMixin, CRUDMixin):
 
             for group in groups:
                 # Do not add the primary group to the secondary groups
-                if group.id == self.primary_group_id:
+                if group == self.primary_group:
                     continue
                 self.add_to_group(group)
 

+ 11 - 1
flaskbb/utils/database.py

@@ -9,10 +9,20 @@
     :license: BSD, see LICENSE for more details.
 """
 import pytz
-
 from flaskbb.extensions import db
 
 
+def make_comparable(cls):
+    def __eq__(self, other):
+        return isinstance(other, cls) and self.id == other.id
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    cls.__eq__ = __eq__
+    return cls
+
+
 class CRUDMixin(object):
     def __repr__(self):
         return "<{}>".format(self.__class__.__name__)

+ 11 - 0
flaskbb/utils/helpers.py

@@ -29,6 +29,8 @@ from flask_babelplus import lazy_gettext as _
 from flask_themes2 import render_theme_template, get_themes_list
 from flask_login import current_user
 
+from werkzeug.local import LocalProxy
+
 from flaskbb._compat import range_method, text_type, iteritems
 from flaskbb.extensions import redis_store, babel
 from flaskbb.utils.settings import flaskbb_config
@@ -614,3 +616,12 @@ class ReverseProxyPathFix(object):
             environ['wsgi.url_scheme'] = 'https'
 
         return self.app(environ, start_response)
+
+
+def real(obj):
+    """
+    Unwraps a werkzeug.local.LocalProxy object if given one, else returns the object
+    """
+    if isinstance(obj, LocalProxy):
+        return obj._get_current_object()
+    return obj

+ 1 - 1
flaskbb/utils/populate.py

@@ -197,7 +197,7 @@ def update_user(username, password, email, groupname):
 
     user.password = password
     user.email = email
-    user.primary_group_id = group.id
+    user.primary_group = group
     return user.save()