pgutils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import hashlib
  2. from django.db.models import Index, Q
  3. from django.utils.encoding import force_bytes
  4. class PgPartialIndex(Index):
  5. suffix = "part"
  6. max_name_length = 31
  7. def __init__(self, fields=None, name=None, where=None):
  8. if not where:
  9. raise ValueError("partial index requires WHERE clause")
  10. self.where = where
  11. if isinstance(where, dict):
  12. condition = Q(**where)
  13. else:
  14. condition = where
  15. if not name:
  16. name = "_".join(where.keys())[:30]
  17. fields = fields or []
  18. super().__init__(fields=fields, name=name, condition=condition)
  19. def set_name_with_model(self, model):
  20. table_name = model._meta.db_table
  21. column_names = sorted(self.where.keys())
  22. where_items = []
  23. for key in sorted(self.where.keys()):
  24. where_items.append("%s:%s" % (key, repr(self.where[key])))
  25. # The length of the parts of the name is based on the default max
  26. # length of 30 characters.
  27. hash_data = [table_name] + self.fields + where_items + [self.suffix]
  28. self.name = "%s_%s_%s" % (
  29. table_name[:11],
  30. column_names[0][:7],
  31. "%s_%s" % (self._hash_generator(*hash_data), self.suffix),
  32. )
  33. assert len(self.name) <= self.max_name_length, (
  34. "Index too long for multiple database support. Is self.suffix "
  35. "longer than 3 characters?"
  36. )
  37. self.check_name()
  38. @staticmethod
  39. def _hash_generator(*args):
  40. """
  41. Method Index._hash_generator is removed in django 2.2
  42. This method is copy from old django 2.1
  43. """
  44. h = hashlib.md5()
  45. for arg in args:
  46. h.update(force_bytes(arg))
  47. return h.hexdigest()[:6]
  48. def deconstruct(self):
  49. path, args, kwargs = super().deconstruct()
  50. # TODO: check this patch
  51. kwargs["where"] = self.condition
  52. del kwargs["condition"]
  53. return path, args, kwargs
  54. def get_sql_create_template_values(self, model, schema_editor, using):
  55. parameters = super().get_sql_create_template_values(model, schema_editor, "")
  56. parameters["extra"] = self.get_sql_extra(model, schema_editor)
  57. return parameters
  58. def get_sql_extra(self, model, schema_editor):
  59. quote_name = schema_editor.quote_name
  60. quote_value = schema_editor.quote_value
  61. clauses = []
  62. for field, condition in self.where.items():
  63. field_name = None
  64. compr = None
  65. if field.endswith("__lt"):
  66. field_name = field[:-4]
  67. compr = "<"
  68. elif field.endswith("__gt"):
  69. field_name = field[:-4]
  70. compr = ">"
  71. elif field.endswith("__lte"):
  72. field_name = field[:-5]
  73. compr = "<="
  74. elif field.endswith("__gte"):
  75. field_name = field[:-5]
  76. compr = ">="
  77. else:
  78. field_name = field
  79. compr = "="
  80. column = model._meta.get_field(field_name).column
  81. clauses.append(
  82. "%s %s %s" % (quote_name(column), compr, quote_value(condition))
  83. )
  84. # sort clauses for their order to be determined and testable
  85. return " WHERE %s" % (" AND ".join(sorted(clauses)))
  86. def chunk_queryset(queryset, chunk_size=20):
  87. ordered_queryset = queryset.order_by("-pk") # bias to newest items first
  88. chunk = ordered_queryset[:chunk_size]
  89. while chunk:
  90. last_pk = None
  91. for item in chunk:
  92. last_pk = item.pk
  93. yield item
  94. chunk = ordered_queryset.filter(pk__lt=last_pk)[:chunk_size]