pgutils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from __future__ import unicode_literals
  2. from django.core.paginator import Paginator
  3. from django.db.models import Index
  4. class PgPartialIndex(Index):
  5. suffix = 'part'
  6. max_name_length = 31
  7. def __init__(self, fields=[], name=None, where=None):
  8. if not where:
  9. raise ValueError('partial index requires WHERE clause')
  10. self.where = where
  11. super(PgPartialIndex, self).__init__(fields, name)
  12. def set_name_with_model(self, model):
  13. table_name = model._meta.db_table
  14. column_names = sorted(self.where.keys())
  15. where_items = []
  16. for key in sorted(self.where.keys()):
  17. where_items.append('{}:{}'.format(key, repr(self.where[key])))
  18. # The length of the parts of the name is based on the default max
  19. # length of 30 characters.
  20. hash_data = [table_name] + self.fields + where_items + [self.suffix]
  21. self.name = '%s_%s_%s' % (
  22. table_name[:11],
  23. column_names[0][:7],
  24. '%s_%s' % (self._hash_generator(*hash_data), self.suffix),
  25. )
  26. assert len(self.name) <= self.max_name_length, (
  27. 'Index too long for multiple database support. Is self.suffix '
  28. 'longer than 3 characters?'
  29. )
  30. self.check_name()
  31. def __repr__(self):
  32. if self.where is not None:
  33. where_items = []
  34. for key in sorted(self.where.keys()):
  35. where_items.append('='.join([
  36. key,
  37. repr(self.where[key])
  38. ]))
  39. return '<%(name)s: fields=%(fields)s, where=%(where)s>' % {
  40. 'name': self.__class__.__name__,
  41. 'fields': "'{}'".format(', '.join(self.fields)),
  42. 'where': "'{}'".format(', '.join(where_items)),
  43. }
  44. else:
  45. return super(PgPartialIndex, self).__repr__()
  46. def deconstruct(self):
  47. path, args, kwargs = super(PgPartialIndex, self).deconstruct()
  48. kwargs['where'] = self.where
  49. return path, args, kwargs
  50. def get_sql_create_template_values(self, model, schema_editor, using):
  51. parameters = super(PgPartialIndex, self).get_sql_create_template_values(
  52. model, schema_editor, '')
  53. parameters['extra'] = self.get_sql_extra(model, schema_editor)
  54. return parameters
  55. def get_sql_extra(self, model, schema_editor):
  56. quote_name = schema_editor.quote_name
  57. quote_value = schema_editor.quote_value
  58. clauses = []
  59. for field, condition in self.where.items():
  60. field_name = None
  61. compr = None
  62. if field.endswith('__lt'):
  63. field_name = field[:-4]
  64. compr = '<'
  65. elif field.endswith('__gt'):
  66. field_name = field[:-4]
  67. compr = '>'
  68. elif field.endswith('__lte'):
  69. field_name = field[:-5]
  70. compr = '<='
  71. elif field.endswith('__gte'):
  72. field_name = field[:-5]
  73. compr = '>='
  74. else:
  75. field_name = field
  76. compr = '='
  77. column = model._meta.get_field(field_name).column
  78. clauses.append('{} {} {}'.format(
  79. quote_name(column), compr, quote_value(condition)))
  80. # sort clauses for their order to be determined and testable
  81. return ' WHERE {}'.format(' AND '.join(sorted(clauses)))
  82. def batch_update(queryset, step=50):
  83. """util because psycopg2 iterators aren't memory effective in Dj<1.11"""
  84. paginator = Paginator(queryset.order_by('pk'), step)
  85. for page_number in paginator.page_range:
  86. for obj in paginator.page(page_number).object_list:
  87. yield obj
  88. def batch_delete(queryset, step=50):
  89. """another util cos paginator goes bobbins when you are deleting"""
  90. queryset_exists = True
  91. while queryset_exists:
  92. for obj in queryset[:step]:
  93. yield obj
  94. queryset_exists = queryset.exists()