Просмотр исходного кода

Merge pull request #593 from flaskbb/fix-sqlalchemy-1-4

Update SQLAlchemy to 1.4
Peter Justin 3 лет назад
Родитель
Сommit
cc083eb5f6

+ 1 - 1
flaskbb/extensions.py

@@ -41,7 +41,7 @@ metadata = MetaData(
         "pk": "pk_%(table_name)s",
     }
 )
-db = SQLAlchemy(metadata=metadata)
+db = SQLAlchemy(metadata=metadata, session_options={"future": True})
 
 # Whooshee (Full Text Search)
 whooshee = Whooshee()

+ 49 - 46
flaskbb/forum/models.py

@@ -206,7 +206,7 @@ class Post(HideableCRUDMixin, db.Model):
         """
         current_app.pluggy.hook.flaskbb_event_post_save_before(post=self)
 
-        # update/edit the post
+        # update a post
         if self.id:
             db.session.add(self)
             db.session.commit()
@@ -216,30 +216,31 @@ class Post(HideableCRUDMixin, db.Model):
 
         # Adding a new post
         if user and topic:
-            created = time_utcnow()
-            self.user = user
-            self.username = user.username
-            self.topic = topic
-            self.date_created = created
-
-            if not topic.hidden:
-                topic.last_updated = created
-                topic.last_post = self
-
-                # Update the last post info for the forum
-                topic.forum.last_post = self
-                topic.forum.last_post_user = self.user
-                topic.forum.last_post_title = topic.title
-                topic.forum.last_post_username = user.username
-                topic.forum.last_post_created = created
-
-                # Update the post counts
-                user.post_count += 1
-                topic.post_count += 1
-                topic.forum.post_count += 1
+            with db.session.no_autoflush:
+                created = time_utcnow()
+                self.user = user
+                self.username = user.username
+                self.topic = topic
+                self.date_created = created
+
+                if not topic.hidden:
+                    topic.last_updated = created
+                    topic.last_post = self
+
+                    # Update the last post info for the forum
+                    topic.forum.last_post = self
+                    topic.forum.last_post_user = self.user
+                    topic.forum.last_post_title = topic.title
+                    topic.forum.last_post_username = user.username
+                    topic.forum.last_post_created = created
+
+                    # Update the post counts
+                    user.post_count += 1
+                    topic.post_count += 1
+                    topic.forum.post_count += 1
 
             # And commit it!
-            db.session.add(topic)
+            db.session.add(self)
             db.session.commit()
             current_app.pluggy.hook.flaskbb_event_post_save_after(post=self,
                                                                   is_new=True)
@@ -669,31 +670,32 @@ class Topic(HideableCRUDMixin, db.Model):
             )
             return self
 
-        # Set the forum and user id
-        self.forum = forum
-        self.user = user
-        self.username = user.username
+        with db.session.no_autoflush:
+            # Set the forum and user id
+            self.forum = forum
+            self.user = user
+            self.username = user.username
 
-        # Set the last_updated time. Needed for the readstracker
-        self.date_created = self.last_updated = time_utcnow()
+            # Set the last_updated time. Needed for the readstracker
+            self.date_created = self.last_updated = time_utcnow()
 
-        # Insert and commit the topic
-        db.session.add(self)
-        db.session.commit()
+            # Insert and commit the topic
+            db.session.add(self)
+            db.session.commit()
 
-        if post is not None:
-            self._post = post
+            if post is not None:
+                self._post = post
 
-        # Create the topic post
-        self._post.save(user, self)
+            # Create the topic post
+            self._post.save(user, self)
 
-        # Update the first and last post id
-        self.last_post = self.first_post = self._post
+            # Update the first and last post id
+            self.last_post = self.first_post = self._post
 
-        # Update the topic count
-        forum.topic_count += 1
-        db.session.commit()
+            # Update the topic count
+            forum.topic_count += 1
 
+        db.session.commit()
         current_app.pluggy.hook.flaskbb_event_topic_save_after(topic=self,
                                                                is_new=True)
         return self
@@ -1067,11 +1069,12 @@ class Forum(db.Model, CRUDMixin):
         if self.id:
             db.session.merge(self)
         else:
-            if groups is None:
-                # importing here because of circular dependencies
-                from flaskbb.user.models import Group
-                self.groups = Group.query.order_by(Group.name.asc()).all()
-            db.session.add(self)
+            with db.session.no_autoflush:
+                if groups is None:
+                    # importing here because of circular dependencies
+                    from flaskbb.user.models import Group
+                    self.groups = Group.query.order_by(Group.name.asc()).all()
+                db.session.add(self)
 
         db.session.commit()
         return self

+ 4 - 4
flaskbb/user/models.py

@@ -445,10 +445,10 @@ class User(db.Model, UserMixin, CRUDMixin):
         """
         if groups is not None:
             # TODO: Only remove/add groups that are selected
-            secondary_groups = self.secondary_groups.all()
-            for group in secondary_groups:
-                self.remove_from_group(group)
-            db.session.commit()
+            with db.session.no_autoflush:
+                secondary_groups = self.secondary_groups.all()
+                for group in secondary_groups:
+                    self.remove_from_group(group)
 
             for group in groups:
                 # Do not add the primary group to the secondary groups

+ 13 - 16
flaskbb/utils/database.py

@@ -12,7 +12,7 @@ import logging
 import pytz
 from flask_login import current_user
 from flask_sqlalchemy import BaseQuery
-from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.orm import declarative_mixin, declared_attr
 from flaskbb.extensions import db
 from ..core.exceptions import PersistenceError
 
@@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
 
 
 def make_comparable(cls):
-
     def __eq__(self, other):
         return isinstance(other, cls) and self.id == other.id
 
@@ -38,7 +37,6 @@ def make_comparable(cls):
 
 
 class CRUDMixin(object):
-
     def __repr__(self):
         return "<{}>".format(self.__class__.__name__)
 
@@ -62,6 +60,7 @@ class CRUDMixin(object):
 
 class UTCDateTime(db.TypeDecorator):
     impl = db.DateTime
+    cache_ok = True
 
     def process_bind_param(self, value, dialect):
         """Way into the database."""
@@ -83,28 +82,26 @@ class UTCDateTime(db.TypeDecorator):
 
 
 class HideableQuery(BaseQuery):
+    _with_hidden = False
 
     def __new__(cls, *args, **kwargs):
-        inst = super(HideableQuery, cls).__new__(cls)
+        obj = super(HideableQuery, cls).__new__(cls)
         include_hidden = kwargs.pop("_with_hidden", False)
         has_view_hidden = current_user and current_user.permissions.get(
             "viewhidden", False
         )
-        with_hidden = include_hidden or has_view_hidden
+        obj._with_hidden = include_hidden or has_view_hidden
         if args or kwargs:
-            super(HideableQuery, inst).__init__(*args, **kwargs)
-            entity = inst._mapper_zero().class_
-            return inst.filter(
-                entity.hidden != True
-            ) if not with_hidden else inst
-        return inst
+            super(HideableQuery, obj).__init__(*args, **kwargs)
+            return obj.filter_by(hidden=False) if not obj._with_hidden else obj
+        return obj
 
     def __init__(self, *args, **kwargs):
         pass
 
     def with_hidden(self):
         return self.__class__(
-            db.class_mapper(self._mapper_zero().class_),
+            self._only_full_mapper_zero("get"),
             session=db.session(),
             _with_hidden=True,
         )
@@ -113,13 +110,13 @@ class HideableQuery(BaseQuery):
         return super(HideableQuery, self).get(*args, **kwargs)
 
     def get(self, *args, **kwargs):
-        include_hidden = kwargs.pop("include_hidden", False)
         obj = self.with_hidden()._get(*args, **kwargs)
-        return obj if obj is not None and (
-            include_hidden or not obj.hidden
-        ) else None
+        return (
+            obj if obj is None or self._with_hidden or not obj.hidden else None
+        )
 
 
+@declarative_mixin
 class HideableMixin(object):
     query_class = HideableQuery
 

+ 50 - 68
flaskbb/utils/populate.py

@@ -280,37 +280,38 @@ def create_test_data(users=5, categories=2, forums=2, topics=1, posts=1):
     user2 = User.query.filter_by(id=2).first()
 
     # create 2 categories
-    for i in range(1, categories + 1):
-        category_title = "Test Category %s" % i
-        category = Category(title=category_title,
-                            description="Test Description")
-        category.save()
-        data_created['categories'] += 1
-
-        # create 2 forums in each category
-        for j in range(1, forums + 1):
-            if i == 2:
-                j += 2
-
-            forum_title = "Test Forum %s %s" % (j, i)
-            forum = Forum(title=forum_title, description="Test Description",
-                          category_id=i)
-            forum.save()
-            data_created['forums'] += 1
-
-            for _ in range(1, topics + 1):
-                # create a topic
-                topic = Topic(title="Test Title %s" % j)
-                post = Post(content="Test Content")
-
-                topic.save(post=post, user=user1, forum=forum)
-                data_created['topics'] += 1
-
-                for _ in range(1, posts + 1):
-                    # create a second post in the forum
-                    post = Post(content="Test Post")
-                    post.save(user=user2, topic=topic)
-                    data_created['posts'] += 1
+    with db.session.no_autoflush:
+        for i in range(1, categories + 1):
+            category_title = "Test Category %s" % i
+            category = Category(title=category_title,
+                                description="Test Description")
+            category.save()
+            data_created['categories'] += 1
+
+            # create 2 forums in each category
+            for j in range(1, forums + 1):
+                if i == 2:
+                    j += 2
+
+                forum_title = "Test Forum %s %s" % (j, i)
+                forum = Forum(title=forum_title, description="Test Description",
+                              category_id=i)
+                forum.save()
+                data_created['forums'] += 1
+
+                for _ in range(1, topics + 1):
+                    # create a topic
+                    topic = Topic(title="Test Title %s" % j)
+                    post = Post(content="Test Content")
+                    topic.save(post=post, user=user1, forum=forum)
+                    data_created['topics'] += 1
+
+                    for _ in range(1, posts + 1):
+                        # create a second post in the forum
+                        post = Post(content="Test Post")
+                        #db.session.add_all([post, user2, topic])
+                        post.save(user=user2, topic=topic)
+                        data_created['posts'] += 1
 
     return data_created
 
@@ -337,46 +338,27 @@ def insert_bulk_data(topic_count=10, post_count=100):
     if not (user1 or user2 or forum):
         return False
 
-    db.session.begin(subtransactions=True)
+    with db.session.no_autoflush:
+        for i in range(1, topic_count + 1):
+            last_post_id += 1
 
-    for i in range(1, topic_count + 1):
-        last_post_id += 1
+            # create a topic
+            topic = Topic(title="Test Title %s" % i)
+            post = Post(content="First Post")
+            topic.save(post=post, user=user1, forum=forum)
+            created_topics += 1
 
-        # create a topic
-        topic = Topic(title="Test Title %s" % i)
-        post = Post(content="First Post")
-        topic.save(post=post, user=user1, forum=forum)
-        created_topics += 1
+            # create some posts in the topic
+            for _ in range(1, post_count + 1):
+                last_post_id += 1
+                post = Post(content="Some other Post", user=user2, topic=topic.id)
+                topic.last_updated = post.date_created
+                topic.post_count += 1
 
-        # create some posts in the topic
-        for _ in range(1, post_count + 1):
-            last_post_id += 1
-            post = Post(content="Some other Post", user=user2, topic=topic.id)
-            topic.last_updated = post.date_created
-            topic.post_count += 1
-
-            # FIXME: Is there a way to ignore IntegrityErrors?
-            # At the moment, the first_post_id is also the last_post_id.
-            # This does no harm, except that in the forums view, you see
-            # the information for the first post instead of the last one.
-            # I run a little benchmark:
-            # 5.3643078804 seconds to create 100 topics and 10000 posts
-            # Using another method (where data integrity is ok) I benchmarked
-            # these stats:
-            # 49.7832770348 seconds to create 100 topics and 10000 posts
-
-            # Uncomment the line underneath and the other line to reduce
-            # performance but fixes the above mentioned problem.
-            # topic.last_post_id = last_post_id
-
-            created_posts += 1
-            posts.append(post)
-
-        # uncomment this and delete the one below, also uncomment the
-        # topic.last_post_id line above. This will greatly reduce the
-        # performance.
-        # db.session.bulk_save_objects(posts)
-    db.session.bulk_save_objects(posts)
+                created_posts += 1
+                posts.append(post)
+
+        db.session.bulk_save_objects(posts)
 
     # and finally, lets update some stats
     forum.recalculate(last_post=True)

+ 2 - 2
requirements.txt

@@ -29,7 +29,7 @@ Flask-SQLAlchemy==2.5.1
 Flask-Themes2==0.1.5
 flask-whooshee==0.8.1
 Flask-WTF==0.15.1
-flaskbb-plugin-conversations==1.0.7
+flaskbb-plugin-conversations==1.0.8
 flaskbb-plugin-portal==1.1.3
 future==0.18.2
 idna==3.2
@@ -53,7 +53,7 @@ requests==2.26.0
 simplejson==3.17.3
 six==1.16.0
 speaklater==1.3
-SQLAlchemy==1.3.24
+SQLAlchemy==1.4.21
 SQLAlchemy-Utils==0.37.8
 Unidecode==1.2.0
 urllib3==1.26.6

+ 1 - 1
setup.py

@@ -40,7 +40,7 @@ install_requires = [
     "flask-redis>=0.4.0",
     "Flask-SQLAlchemy>=2.4.4",
     "Flask-Themes2>=0.1.5",
-    "flask-whooshee>=0.7.0",
+    "flask-whooshee>=0.8.1",
     "Flask-WTF>=0.14.3",
     "flaskbb-plugin-conversations>=1.0.7",
     "flaskbb-plugin-portal>=1.1.3",

+ 10 - 7
tests/unit/forum/test_forum_utils.py

@@ -12,17 +12,20 @@ class TestForceLoginHelpers(object):
     def test_would_not_force_login_for_anon_in_guest_allowed(self, forum, guest):
         assert not utils.should_force_login(guest, forum)
 
-    def test_would_force_login_for_anon_in_guest_unallowed(self, guest, category):
-        forum = Forum(title="no guest", category=category)
-        forum.groups = Group.query.filter(Group.guest == False).all()
-
+    def test_would_force_login_for_anon_in_guest_unallowed(self, database, guest, category):
+        with database.session.no_autoflush:
+            forum = Forum(title="no guest", category=category)
+            forum.groups = Group.query.filter(Group.guest == False).all()
+            forum.save()
         assert utils.should_force_login(guest, forum)
 
     def test_redirects_to_login_with_anon(
-        self, guest, category, request_context, application
+        self, database, guest, category, request_context, application
     ):
-        forum = Forum(title="no guest", category=category)
-        forum.groups = Group.query.filter(Group.guest == False).all()
+        with database.session.no_autoflush:
+            forum = Forum(title="no guest", category=category)
+            forum.groups = Group.query.filter(Group.guest == False).all()
+            forum.save()
         # sets current_forum
         _request_ctx_stack.top.forum = forum
 

+ 2 - 2
tests/unit/test_forum_models.py

@@ -645,7 +645,7 @@ def test_retrieving_hidden_posts(topic, user):
     new_post.hide(user)
 
     assert Post.query.get(new_post.id) is None
-    assert Post.query.get(new_post.id, include_hidden=True) == new_post
+    assert Post.query.with_hidden().get(new_post.id) == new_post
     assert Post.query.filter(Post.id == new_post.id).first() is None
     hidden_post = Post.query\
         .with_hidden()\
@@ -658,7 +658,7 @@ def test_retrieving_hidden_topics(topic, user):
     topic.hide(user)
 
     assert Topic.query.get(topic.id) is None
-    assert Topic.query.get(topic.id, include_hidden=True) == topic
+    assert Topic.query.with_hidden().get(topic.id) == topic
     assert Topic.query.filter(Topic.id == topic.id).first() is None
     hidden_topic = Topic.query\
         .with_hidden()\