database.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # -*- coding: utf-8 -*-
  2. """
  3. flaskbb.utils.database
  4. ~~~~~~~~~~~~~~~~~~~~~~
  5. Some database helpers such as a CRUD mixin.
  6. :copyright: (c) 2015 by the FlaskBB Team.
  7. :license: BSD, see LICENSE for more details.
  8. """
  9. import logging
  10. import pytz
  11. from flask_login import current_user
  12. from flask_sqlalchemy import BaseQuery
  13. from sqlalchemy.ext.declarative import declared_attr
  14. from flaskbb.extensions import db
  15. from ..core.exceptions import PersistenceError
  16. logger = logging.getLogger(__name__)
  17. def make_comparable(cls):
  18. def __eq__(self, other):
  19. return isinstance(other, cls) and self.id == other.id
  20. def __ne__(self, other):
  21. return not self.__eq__(other)
  22. def __hash__(self):
  23. return hash((cls, self.id))
  24. cls.__eq__ = __eq__
  25. cls.__ne__ = __ne__
  26. cls.__hash__ = __hash__
  27. return cls
  28. class CRUDMixin(object):
  29. def __repr__(self):
  30. return "<{}>".format(self.__class__.__name__)
  31. @classmethod
  32. def create(cls, **kwargs):
  33. instance = cls(**kwargs)
  34. return instance.save()
  35. def save(self):
  36. """Saves the object to the database."""
  37. db.session.add(self)
  38. db.session.commit()
  39. return self
  40. def delete(self):
  41. """Delete the object from the database."""
  42. db.session.delete(self)
  43. db.session.commit()
  44. return self
  45. class UTCDateTime(db.TypeDecorator):
  46. impl = db.DateTime
  47. def process_bind_param(self, value, dialect):
  48. """Way into the database."""
  49. if value is not None:
  50. # store naive datetime for sqlite and mysql
  51. if dialect.name in ("sqlite", "mysql"):
  52. return value.replace(tzinfo=None)
  53. return value.astimezone(pytz.UTC)
  54. def process_result_value(self, value, dialect):
  55. """Way out of the database."""
  56. # convert naive datetime to non naive datetime
  57. if dialect.name in ("sqlite", "mysql") and value is not None:
  58. return value.replace(tzinfo=pytz.UTC)
  59. # other dialects are already non-naive
  60. return value
  61. class HideableQuery(BaseQuery):
  62. def __new__(cls, *args, **kwargs):
  63. inst = super(HideableQuery, cls).__new__(cls)
  64. include_hidden = kwargs.pop("_with_hidden", False)
  65. has_view_hidden = current_user and current_user.permissions.get(
  66. "viewhidden", False
  67. )
  68. with_hidden = include_hidden or has_view_hidden
  69. if args or kwargs:
  70. super(HideableQuery, inst).__init__(*args, **kwargs)
  71. entity = inst._mapper_zero().class_
  72. return inst.filter(
  73. entity.hidden != True
  74. ) if not with_hidden else inst
  75. return inst
  76. def __init__(self, *args, **kwargs):
  77. pass
  78. def with_hidden(self):
  79. return self.__class__(
  80. db.class_mapper(self._mapper_zero().class_),
  81. session=db.session(),
  82. _with_hidden=True,
  83. )
  84. def _get(self, *args, **kwargs):
  85. return super(HideableQuery, self).get(*args, **kwargs)
  86. def get(self, *args, **kwargs):
  87. include_hidden = kwargs.pop("include_hidden", False)
  88. obj = self.with_hidden()._get(*args, **kwargs)
  89. return obj if obj is not None and (
  90. include_hidden or not obj.hidden
  91. ) else None
  92. class HideableMixin(object):
  93. query_class = HideableQuery
  94. hidden = db.Column(db.Boolean, default=False, nullable=True)
  95. hidden_at = db.Column(UTCDateTime(timezone=True), nullable=True)
  96. @declared_attr
  97. def hidden_by_id(cls): # noqa: B902
  98. return db.Column(
  99. db.Integer,
  100. db.ForeignKey(
  101. "users.id", name="fk_{}_hidden_by".format(cls.__name__)
  102. ),
  103. nullable=True,
  104. )
  105. @declared_attr
  106. def hidden_by(cls): # noqa: B902
  107. return db.relationship(
  108. "User", uselist=False, foreign_keys=[cls.hidden_by_id]
  109. )
  110. def hide(self, user, *args, **kwargs):
  111. from flaskbb.utils.helpers import time_utcnow
  112. self.hidden_by = user
  113. self.hidden = True
  114. self.hidden_at = time_utcnow()
  115. return self
  116. def unhide(self, *args, **kwargs):
  117. self.hidden_by = None
  118. self.hidden = False
  119. self.hidden_at = None
  120. return self
  121. class HideableCRUDMixin(HideableMixin, CRUDMixin):
  122. pass
  123. def try_commit(session, message="Error while saving"):
  124. try:
  125. session.commit()
  126. except Exception:
  127. raise PersistenceError(message)