pgutils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from django.db.models import Index
  2. class PgPartialIndex(Index):
  3. suffix = "part"
  4. max_name_length = 31
  5. def __init__(self, fields=None, name=None, where=None):
  6. if not where:
  7. raise ValueError("partial index requires WHERE clause")
  8. self.where = where
  9. fields = fields or []
  10. super().__init__(fields, name)
  11. def set_name_with_model(self, model):
  12. table_name = model._meta.db_table
  13. column_names = sorted(self.where.keys())
  14. where_items = []
  15. for key in sorted(self.where.keys()):
  16. where_items.append("%s:%s" % (key, repr(self.where[key])))
  17. # The length of the parts of the name is based on the default max
  18. # length of 30 characters.
  19. hash_data = [table_name] + self.fields + where_items + [self.suffix]
  20. self.name = "%s_%s_%s" % (
  21. table_name[:11],
  22. column_names[0][:7],
  23. "%s_%s" % (self._hash_generator(*hash_data), self.suffix),
  24. )
  25. assert len(self.name) <= self.max_name_length, (
  26. "Index too long for multiple database support. Is self.suffix "
  27. "longer than 3 characters?"
  28. )
  29. self.check_name()
  30. def __repr__(self):
  31. if self.where is None:
  32. return super().__repr__()
  33. where_items = []
  34. for key in sorted(self.where.keys()):
  35. where_items.append("=".join([key, repr(self.where[key])]))
  36. return "<%(name)s: fields=%(fields)s, where=%(where)s>" % {
  37. "name": self.__class__.__name__,
  38. "fields": "'%s'" % (", ".join(self.fields)),
  39. "where": "'%s'" % (", ".join(where_items)),
  40. }
  41. def deconstruct(self):
  42. path, args, kwargs = super().deconstruct()
  43. kwargs["where"] = self.where
  44. return path, args, kwargs
  45. def get_sql_create_template_values(self, model, schema_editor, using):
  46. parameters = super().get_sql_create_template_values(model, schema_editor, "")
  47. parameters["extra"] = self.get_sql_extra(model, schema_editor)
  48. return parameters
  49. def get_sql_extra(self, model, schema_editor):
  50. quote_name = schema_editor.quote_name
  51. quote_value = schema_editor.quote_value
  52. clauses = []
  53. for field, condition in self.where.items():
  54. field_name = None
  55. compr = None
  56. if field.endswith("__lt"):
  57. field_name = field[:-4]
  58. compr = "<"
  59. elif field.endswith("__gt"):
  60. field_name = field[:-4]
  61. compr = ">"
  62. elif field.endswith("__lte"):
  63. field_name = field[:-5]
  64. compr = "<="
  65. elif field.endswith("__gte"):
  66. field_name = field[:-5]
  67. compr = ">="
  68. else:
  69. field_name = field
  70. compr = "="
  71. column = model._meta.get_field(field_name).column
  72. clauses.append(
  73. "%s %s %s" % (quote_name(column), compr, quote_value(condition))
  74. )
  75. # sort clauses for their order to be determined and testable
  76. return " WHERE %s" % (" AND ".join(sorted(clauses)))
  77. def chunk_queryset(queryset, chunk_size=20):
  78. ordered_queryset = queryset.order_by("-pk") # bias to newest items first
  79. chunk = ordered_queryset[:chunk_size]
  80. while chunk:
  81. last_pk = None
  82. for item in chunk:
  83. last_pk = item.pk
  84. yield item
  85. chunk = ordered_queryset.filter(pk__lt=last_pk)[:chunk_size]