pgutils.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. from django.core.paginator import Paginator
  2. from django.db.models import Index
  3. class PgPartialIndex(Index):
  4. suffix = 'part'
  5. max_name_length = 31
  6. def __init__(self, fields=[], name=None, where=None):
  7. if not where:
  8. raise ValueError('partial index requires WHERE clause')
  9. self.where = where
  10. super().__init__(fields, name)
  11. def set_name_with_model(self, model):
  12. table_name = model._meta.db_table
  13. column_names = sorted(self.where.keys())
  14. where_items = []
  15. for key in sorted(self.where.keys()):
  16. where_items.append('%s:%s' % (key, repr(self.where[key])))
  17. # The length of the parts of the name is based on the default max
  18. # length of 30 characters.
  19. hash_data = [table_name] + self.fields + where_items + [self.suffix]
  20. self.name = '%s_%s_%s' % (
  21. table_name[:11],
  22. column_names[0][:7],
  23. '%s_%s' % (self._hash_generator(*hash_data), self.suffix),
  24. )
  25. assert len(self.name) <= self.max_name_length, (
  26. 'Index too long for multiple database support. Is self.suffix '
  27. 'longer than 3 characters?'
  28. )
  29. self.check_name()
  30. def __repr__(self):
  31. if self.where is not None:
  32. where_items = []
  33. for key in sorted(self.where.keys()):
  34. where_items.append('='.join([
  35. key,
  36. repr(self.where[key])
  37. ]))
  38. return '<%(name)s: fields=%(fields)s, where=%(where)s>' % {
  39. 'name': self.__class__.__name__,
  40. 'fields': "'%s'" % (', '.join(self.fields)),
  41. 'where': "'%s'" % (', '.join(where_items)),
  42. }
  43. else:
  44. return super().__repr__()
  45. def deconstruct(self):
  46. path, args, kwargs = super().deconstruct()
  47. kwargs['where'] = self.where
  48. return path, args, kwargs
  49. def get_sql_create_template_values(self, model, schema_editor, using):
  50. parameters = super().get_sql_create_template_values(
  51. model, schema_editor, '')
  52. parameters['extra'] = self.get_sql_extra(model, schema_editor)
  53. return parameters
  54. def get_sql_extra(self, model, schema_editor):
  55. quote_name = schema_editor.quote_name
  56. quote_value = schema_editor.quote_value
  57. clauses = []
  58. for field, condition in self.where.items():
  59. field_name = None
  60. compr = None
  61. if field.endswith('__lt'):
  62. field_name = field[:-4]
  63. compr = '<'
  64. elif field.endswith('__gt'):
  65. field_name = field[:-4]
  66. compr = '>'
  67. elif field.endswith('__lte'):
  68. field_name = field[:-5]
  69. compr = '<='
  70. elif field.endswith('__gte'):
  71. field_name = field[:-5]
  72. compr = '>='
  73. else:
  74. field_name = field
  75. compr = '='
  76. column = model._meta.get_field(field_name).column
  77. clauses.append('%s %s %s' % (quote_name(column), compr, quote_value(condition)))
  78. # sort clauses for their order to be determined and testable
  79. return ' WHERE %s' % (' AND '.join(sorted(clauses)))
  80. def chunk_queryset(queryset, chunk_size=20):
  81. ordered_queryset = queryset.order_by('-pk') # bias to newest items first
  82. chunk = ordered_queryset[:chunk_size]
  83. while chunk:
  84. last_pk = None
  85. for item in chunk:
  86. last_pk = item.pk
  87. yield item
  88. chunk = ordered_queryset.filter(pk__lt=last_pk)[:chunk_size]