pgutils.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from __future__ import unicode_literals
  2. from django.core.paginator import Paginator
  3. from django.db.models import Index
  4. from django.db.migrations.operations import RunSQL
  5. from django.utils.six import text_type
  6. class PgPartialIndex(Index):
  7. suffix = 'part'
  8. max_name_length = 31
  9. def __init__(self, fields=[], name=None, where=None):
  10. if not where:
  11. raise ValueError('partial index requires WHERE clause')
  12. self.where = where
  13. super(PgPartialIndex, self).__init__(fields, name)
  14. def set_name_with_model(self, model):
  15. table_name = model._meta.db_table
  16. column_names = sorted(self.where.keys())
  17. where_items = []
  18. for key in sorted(self.where.keys()):
  19. where_items.append('{}:{}'.format(key, repr(self.where[key])))
  20. # The length of the parts of the name is based on the default max
  21. # length of 30 characters.
  22. hash_data = [table_name] + self.fields + where_items + [self.suffix]
  23. self.name = '%s_%s_%s' % (
  24. table_name[:11],
  25. column_names[0][:7],
  26. '%s_%s' % (self._hash_generator(*hash_data), self.suffix),
  27. )
  28. assert len(self.name) <= self.max_name_length, (
  29. 'Index too long for multiple database support. Is self.suffix '
  30. 'longer than 3 characters?'
  31. )
  32. self.check_name()
  33. def __repr__(self):
  34. if self.where is not None:
  35. where_items = []
  36. for key in sorted(self.where.keys()):
  37. where_items.append('='.join([
  38. key,
  39. repr(self.where[key])
  40. ]))
  41. return '<%(name)s: fields=%(fields)s, where=%(where)s>' % {
  42. 'name': self.__class__.__name__,
  43. 'fields': "'{}'".format(', '.join(self.fields)),
  44. 'where': "'{}'".format(', '.join(where_items)),
  45. }
  46. else:
  47. return super(PgPartialIndex, self).__repr__()
  48. def deconstruct(self):
  49. path, args, kwargs = super(PgPartialIndex, self).deconstruct()
  50. kwargs['where'] = self.where
  51. return path, args, kwargs
  52. def get_sql_create_template_values(self, model, schema_editor, using):
  53. parameters = super(PgPartialIndex, self).get_sql_create_template_values(
  54. model, schema_editor, '')
  55. parameters['extra'] = self.get_sql_extra(model, schema_editor)
  56. return parameters
  57. def get_sql_extra(self, model, schema_editor):
  58. quote_name = schema_editor.quote_name
  59. quote_value = schema_editor.quote_value
  60. clauses = []
  61. for field, condition in self.where.items():
  62. field_name = None
  63. compr = None
  64. if field.endswith('__lt'):
  65. field_name = field[:-4]
  66. compr = '<'
  67. elif field.endswith('__gt'):
  68. field_name = field[:-4]
  69. compr = '>'
  70. elif field.endswith('__lte'):
  71. field_name = field[:-5]
  72. compr = '<='
  73. elif field.endswith('__gte'):
  74. field_name = field[:-5]
  75. compr = '>='
  76. else:
  77. field_name = field
  78. compr = '='
  79. column = model._meta.get_field(field_name).column
  80. clauses.append('{} {} {}'.format(
  81. quote_name(column), compr, quote_value(condition)))
  82. # sort clauses for their order to be determined and testable
  83. return ' WHERE {}'.format(' AND '.join(sorted(clauses)))
  84. class CreatePartialIndex(RunSQL):
  85. CREATE_SQL = """
  86. CREATE INDEX %(index_name)s ON %(table)s (%(field)s)
  87. WHERE %(condition)s;
  88. """
  89. REMOVE_SQL = """
  90. DROP INDEX %(index_name)s
  91. """
  92. def __init__(self, field, index_name, condition):
  93. self.model, self.field = field.split('.')
  94. self.index_name = index_name
  95. self.condition = condition
  96. @property
  97. def reversible(self):
  98. return True
  99. def state_forwards(self, app_label, state):
  100. pass
  101. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  102. model = from_state.apps.get_model(app_label, self.model)
  103. statement = self.CREATE_SQL % {
  104. 'index_name': self.index_name,
  105. 'table': model._meta.db_table,
  106. 'field': self.field,
  107. 'condition': self.condition,
  108. }
  109. schema_editor.execute(statement)
  110. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  111. schema_editor.execute(self.REMOVE_SQL % {'index_name': self.index_name})
  112. def describe(self):
  113. message = "Create PostgreSQL partial index on field %s in %s for %s"
  114. formats = (self.field, self.model_name, self.values)
  115. return message % formats
  116. class CreatePartialCompositeIndex(CreatePartialIndex):
  117. CREATE_SQL = """
  118. CREATE INDEX %(index_name)s ON %(table)s (%(fields)s)
  119. WHERE %(condition)s;
  120. """
  121. REMOVE_SQL = """
  122. DROP INDEX %(index_name)s
  123. """
  124. def __init__(self, model, fields, index_name, condition):
  125. self.model = model
  126. self.fields = fields
  127. self.index_name = index_name
  128. self.condition = condition
  129. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  130. model = from_state.apps.get_model(app_label, self.model)
  131. statement = self.CREATE_SQL % {
  132. 'index_name': self.index_name,
  133. 'table': model._meta.db_table,
  134. 'fields': ', '.join(self.fields),
  135. 'condition': self.condition,
  136. }
  137. schema_editor.execute(statement)
  138. def describe(self):
  139. message = ("Create PostgreSQL partial composite index on fields %s in %s for %s")
  140. formats = (', '.join(self.fields), self.model_name, self.values)
  141. return message % formats
  142. def batch_update(queryset, step=50):
  143. """util because psycopg2 iterators aren't memory effective in Dj<1.11"""
  144. paginator = Paginator(queryset.order_by('pk'), step)
  145. for page_number in paginator.page_range:
  146. for obj in paginator.page(page_number).object_list:
  147. yield obj
  148. def batch_delete(queryset, step=50):
  149. """another util cos paginator goes bobbins when you are deleting"""
  150. queryset_exists = True
  151. while queryset_exists:
  152. for obj in queryset[:step]:
  153. yield obj
  154. queryset_exists = queryset.exists()