pgutils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from django.core.paginator import Paginator
  2. from django.db.models import Index
  3. class PgPartialIndex(Index):
  4. suffix = "part"
  5. max_name_length = 31
  6. def __init__(self, fields=[], name=None, where=None):
  7. if not where:
  8. raise ValueError("partial index requires WHERE clause")
  9. self.where = where
  10. super().__init__(fields, name)
  11. def set_name_with_model(self, model):
  12. table_name = model._meta.db_table
  13. column_names = sorted(self.where.keys())
  14. where_items = []
  15. for key in sorted(self.where.keys()):
  16. where_items.append("%s:%s" % (key, repr(self.where[key])))
  17. # The length of the parts of the name is based on the default max
  18. # length of 30 characters.
  19. hash_data = [table_name] + self.fields + where_items + [self.suffix]
  20. self.name = "%s_%s_%s" % (
  21. table_name[:11],
  22. column_names[0][:7],
  23. "%s_%s" % (self._hash_generator(*hash_data), self.suffix),
  24. )
  25. assert len(self.name) <= self.max_name_length, (
  26. "Index too long for multiple database support. Is self.suffix "
  27. "longer than 3 characters?"
  28. )
  29. self.check_name()
  30. def __repr__(self):
  31. if self.where is not None:
  32. where_items = []
  33. for key in sorted(self.where.keys()):
  34. where_items.append("=".join([key, repr(self.where[key])]))
  35. return "<%(name)s: fields=%(fields)s, where=%(where)s>" % {
  36. "name": self.__class__.__name__,
  37. "fields": "'%s'" % (", ".join(self.fields)),
  38. "where": "'%s'" % (", ".join(where_items)),
  39. }
  40. else:
  41. return super().__repr__()
  42. def deconstruct(self):
  43. path, args, kwargs = super().deconstruct()
  44. kwargs["where"] = self.where
  45. return path, args, kwargs
  46. def get_sql_create_template_values(self, model, schema_editor, using):
  47. parameters = super().get_sql_create_template_values(model, schema_editor, "")
  48. parameters["extra"] = self.get_sql_extra(model, schema_editor)
  49. return parameters
  50. def get_sql_extra(self, model, schema_editor):
  51. quote_name = schema_editor.quote_name
  52. quote_value = schema_editor.quote_value
  53. clauses = []
  54. for field, condition in self.where.items():
  55. field_name = None
  56. compr = None
  57. if field.endswith("__lt"):
  58. field_name = field[:-4]
  59. compr = "<"
  60. elif field.endswith("__gt"):
  61. field_name = field[:-4]
  62. compr = ">"
  63. elif field.endswith("__lte"):
  64. field_name = field[:-5]
  65. compr = "<="
  66. elif field.endswith("__gte"):
  67. field_name = field[:-5]
  68. compr = ">="
  69. else:
  70. field_name = field
  71. compr = "="
  72. column = model._meta.get_field(field_name).column
  73. clauses.append(
  74. "%s %s %s" % (quote_name(column), compr, quote_value(condition))
  75. )
  76. # sort clauses for their order to be determined and testable
  77. return " WHERE %s" % (" AND ".join(sorted(clauses)))
  78. def chunk_queryset(queryset, chunk_size=20):
  79. ordered_queryset = queryset.order_by("-pk") # bias to newest items first
  80. chunk = ordered_queryset[:chunk_size]
  81. while chunk:
  82. last_pk = None
  83. for item in chunk:
  84. last_pk = item.pk
  85. yield item
  86. chunk = ordered_queryset.filter(pk__lt=last_pk)[:chunk_size]