mirror of
https://github.com/django/django.git
synced 2025-10-24 06:06:09 +00:00
Refs #3254 -- Added full text search to contrib.postgres.
Adds a reasonably feature-complete implementation of full text search using the built-in PostgreSQL engine. It uses public APIs from Expression and Lookup. With thanks to Tim Graham, Simon Charette, Josh Smeaton, Mikey Ariel and many others for their advice and review. Particular thanks also go to the supporters of the contrib.postgres kickstarter.
This commit is contained in:
@@ -3,7 +3,7 @@ from django.db.backends.signals import connection_created
|
||||
from django.db.models import CharField, TextField
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
from .lookups import Unaccent
|
||||
from .lookups import SearchLookup, Unaccent
|
||||
from .signals import register_hstore_handler
|
||||
|
||||
|
||||
@@ -15,3 +15,5 @@ class PostgresConfig(AppConfig):
|
||||
connection_created.connect(register_hstore_handler)
|
||||
CharField.register_lookup(Unaccent)
|
||||
TextField.register_lookup(Unaccent)
|
||||
CharField.register_lookup(SearchLookup)
|
||||
TextField.register_lookup(SearchLookup)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from django.db.models import Lookup, Transform
|
||||
|
||||
from .search import SearchVector, SearchVectorExact, SearchVectorField
|
||||
|
||||
|
||||
class PostgresSimpleLookup(Lookup):
|
||||
def as_sql(self, qn, connection):
|
||||
@@ -43,3 +45,13 @@ class Unaccent(Transform):
|
||||
bilateral = True
|
||||
lookup_name = 'unaccent'
|
||||
function = 'UNACCENT'
|
||||
|
||||
|
||||
class SearchLookup(SearchVectorExact):
    """
    ``search`` lookup built on ``SearchVectorExact``.

    A left-hand side whose output field is not a ``SearchVectorField`` is
    wrapped in a ``SearchVector`` before the inherited processing runs.
    """
    lookup_name = 'search'

    def process_lhs(self, qn, connection):
        already_vector = isinstance(self.lhs.output_field, SearchVectorField)
        if not already_vector:
            # Replace the stored left-hand side with its SearchVector wrapper.
            self.lhs = SearchVector(self.lhs)
        lhs_sql, lhs_params = super(SearchLookup, self).process_lhs(qn, connection)
        return lhs_sql, lhs_params
|
||||
|
||||
187
django/contrib/postgres/search.py
Normal file
187
django/contrib/postgres/search.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from django.db.models import Field, FloatField
|
||||
from django.db.models.expressions import CombinedExpression, Func, Value
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.db.models.lookups import Lookup
|
||||
|
||||
|
||||
class SearchVectorExact(Lookup):
    """
    ``exact`` lookup rendered with the ``@@`` match operator.

    A right-hand side that is not an expression (no ``resolve_expression``
    attribute) is wrapped in a ``SearchQuery``; the ``config`` attribute of
    the left-hand side, if any, is handed to that query.
    """
    lookup_name = 'exact'

    def process_rhs(self, qn, connection):
        is_expression = hasattr(self.rhs, 'resolve_expression')
        if not is_expression:
            self.rhs = SearchQuery(
                self.rhs,
                config=getattr(self.lhs, 'config', None),
            )
        rhs_sql, rhs_params = super(SearchVectorExact, self).process_rhs(qn, connection)
        return rhs_sql, rhs_params

    def as_sql(self, qn, connection):
        lhs_sql, lhs_params = self.process_lhs(qn, connection)
        rhs_sql, rhs_params = self.process_rhs(qn, connection)
        # Left-hand parameters come first, matching placeholder order.
        sql = '%s @@ %s = true' % (lhs_sql, rhs_sql)
        return sql, lhs_params + rhs_params
|
||||
|
||||
|
||||
class SearchVectorField(Field):
    """Field whose database column type is ``tsvector``."""

    def db_type(self, connection):
        # The same type name is returned for every connection.
        return 'tsvector'
|
||||
|
||||
|
||||
class SearchQueryField(Field):
    """Field whose database column type is ``tsquery``."""

    def db_type(self, connection):
        # The same type name is returned for every connection.
        return 'tsquery'
|
||||
|
||||
|
||||
class SearchVectorCombinable(object):
    """
    Mixin that lets search vector expressions be concatenated with ``||``.

    Instances are expected to expose a ``config`` attribute; two operands
    may only be combined when their configs compare equal.
    """
    ADD = '||'

    def _combine(self, other, connector, reversed, node=None):
        """
        Return a ``CombinedSearchVector`` joining ``self`` and ``other``.

        ``reversed`` puts ``other`` on the left-hand side. ``node`` is
        accepted for signature compatibility and is not used here.

        Raises ``TypeError`` when ``other`` is not a search vector, and also
        when both operands are search vectors but use different configs. The
        two cases carry distinct messages so a config mismatch is not
        reported as a type problem.
        """
        if not isinstance(other, SearchVectorCombinable):
            raise TypeError('SearchVector can only be combined with other SearchVectors')
        # Spelled ``not ... == ...`` so that only ``__eq__`` is consulted.
        if not self.config == other.config:
            raise TypeError(
                'SearchVector can only be combined with other SearchVectors '
                'that use the same config'
            )
        if reversed:
            return CombinedSearchVector(other, connector, self, self.config)
        return CombinedSearchVector(self, connector, other, self.config)
|
||||
|
||||
|
||||
class SearchVector(SearchVectorCombinable, Func):
    """
    ``to_tsvector()`` over one or more expressions.

    Each source expression is wrapped in ``Coalesce(expression, Value(''))``
    and the pieces are joined with a single space. The optional ``config``
    and ``weight`` keyword arguments are read back from ``extra``.
    """
    function = 'to_tsvector'
    arg_joiner = " || ' ' || "
    _output_field = SearchVectorField()
    config = None

    def __init__(self, *expressions, **extra):
        super(SearchVector, self).__init__(*expressions, **extra)
        coalesced = []
        for expression in self.source_expressions:
            coalesced.append(Coalesce(expression, Value('')))
        self.source_expressions = coalesced
        self.config = self.extra.get('config', self.config)
        weight = self.extra.get('weight')
        # Plain values are wrapped; expressions and None are stored as given.
        needs_wrapping = weight is not None and not hasattr(weight, 'resolve_expression')
        self.weight = Value(weight) if needs_wrapping else weight

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        resolved = super(SearchVector, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
        config = self.config
        if config:
            # A non-expression config is turned into a Value before resolving.
            if not hasattr(config, 'resolve_expression'):
                config = Value(config)
            resolved.config = config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return resolved

    def as_sql(self, compiler, connection, function=None, template=None):
        config_params = []
        if template is None:
            template = self.template
            if self.config:
                config_sql, config_params = compiler.compile(self.config)
                # The template is %-interpolated later, so literal percent
                # signs in the compiled config SQL are doubled.
                escaped_config = config_sql.replace('%', '%%')
                template = "%(function)s({}::regconfig, %(expressions)s)".format(escaped_config)
        sql, params = super(SearchVector, self).as_sql(compiler, connection, function=function, template=template)
        if not self.weight:
            return sql, config_params + params
        weight_sql, weight_params = compiler.compile(self.weight)
        weighted_sql = 'setweight({}, {})'.format(sql, weight_sql)
        return weighted_sql, config_params + params + weight_params
|
||||
|
||||
|
||||
class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
    """Combined expression of two operands that also records a ``config``."""

    def __init__(self, lhs, connector, rhs, config, output_field=None):
        # config is stored first; the remaining arguments go to the base
        # initializer unchanged.
        self.config = config
        super(CombinedSearchVector, self).__init__(lhs, connector, rhs, output_field)
|
||||
|
||||
|
||||
class SearchQuery(Value):
    """
    Value rendered through ``plainto_tsquery()``.

    ``config`` (optional) is emitted as a ``::regconfig`` first argument and
    ``invert`` wraps the rendered call in ``!!(...)``. Both are taken from
    the keyword arguments of the constructor.
    """
    invert = False
    _output_field = SearchQueryField()
    config = None

    BITAND = '&&'
    BITOR = '||'

    def __init__(self, value, output_field=None, **extra):
        # config and invert fall back to the class-level defaults.
        self.config = extra.pop('config', self.config)
        self.invert = extra.pop('invert', self.invert)
        super(SearchQuery, self).__init__(value, output_field=output_field)

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        resolved = super(SearchQuery, self).resolve_expression(query, allow_joins, reuse, summarize, for_save)
        config = self.config
        if config:
            # A non-expression config is turned into a Value before resolving.
            if not hasattr(config, 'resolve_expression'):
                config = Value(config)
            resolved.config = config.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return resolved

    def as_sql(self, compiler, connection):
        if self.config:
            config_sql, config_params = compiler.compile(self.config)
            template = 'plainto_tsquery({}::regconfig, %s)'.format(config_sql)
            # Config parameters precede the search value, as in the SQL.
            params = config_params + [self.value]
        else:
            template = 'plainto_tsquery(%s)'
            params = [self.value]
        if self.invert:
            template = '!!({})'.format(template)
        return template, params

    def _combine(self, other, connector, reversed, node=None):
        result = super(SearchQuery, self)._combine(other, connector, reversed, node)
        result.output_field = SearchQueryField()
        return result

    # On Combinable, these are not implemented to reduce confusion with Q. In
    # this case we are actually (ab)using them to do logical combination so
    # it's consistent with other usage in Django.
    def __or__(self, other):
        return self._combine(other, self.BITOR, False)

    def __ror__(self, other):
        return self._combine(other, self.BITOR, True)

    def __and__(self, other):
        return self._combine(other, self.BITAND, False)

    def __rand__(self, other):
        return self._combine(other, self.BITAND, True)

    def __invert__(self):
        # Build a new instance of the same class with the flag flipped.
        return type(self)(self.value, invert=not self.invert, config=self.config)
|
||||
|
||||
|
||||
class SearchRank(Func):
    """
    ``ts_rank()`` over a search vector and a search query.

    Arguments that are not expressions are wrapped in ``SearchVector`` and
    ``SearchQuery`` respectively. An optional ``weights`` keyword argument
    is rendered as an additional leading argument.
    """
    function = 'ts_rank'
    _output_field = FloatField()

    def __init__(self, vector, query, **extra):
        vector = vector if hasattr(vector, 'resolve_expression') else SearchVector(vector)
        query = query if hasattr(query, 'resolve_expression') else SearchQuery(query)
        weights = extra.get('weights')
        # Plain values are wrapped; expressions and None are stored as given.
        needs_wrapping = weights is not None and not hasattr(weights, 'resolve_expression')
        self.weights = Value(weights) if needs_wrapping else weights
        super(SearchRank, self).__init__(vector, query, **extra)

    def as_sql(self, compiler, connection, function=None, template=None):
        weight_params = []
        context = {}
        # Weights are only rendered when the caller supplied no template.
        use_weights = template is None and self.extra.get('weights') and self.weights
        if use_weights:
            template = '%(function)s(%(weights)s, %(expressions)s)'
            weight_sql, weight_params = compiler.compile(self.weights)
            context['weights'] = weight_sql
        sql, params = super(SearchRank, self).as_sql(
            compiler, connection,
            function=function, template=template, **context
        )
        # Weight parameters lead because the weights argument comes first.
        return sql, weight_params + params
|
||||
|
||||
|
||||
# Register SearchVectorExact (lookup_name 'exact') on SearchVectorField.
SearchVectorField.register_lookup(SearchVectorExact)
|
||||
@@ -254,3 +254,9 @@ class DatabaseOperations(BaseDatabaseOperations):
|
||||
rhs_sql, rhs_params = rhs
|
||||
return "age(%s, %s)" % (lhs_sql, rhs_sql), lhs_params + rhs_params
|
||||
return super(DatabaseOperations, self).subtract_temporals(internal_type, lhs, rhs)
|
||||
|
||||
def fulltext_search_sql(self, field_name):
    """
    Always raise ``NotImplementedError``.

    The message tells the caller to add ``django.contrib.postgres`` to
    ``INSTALLED_APPS``. ``field_name`` is not used.
    """
    message = (
        "Add 'django.contrib.postgres' to settings.INSTALLED_APPS to use "
        "the search operator."
    )
    raise NotImplementedError(message)
|
||||
|
||||
@@ -125,9 +125,11 @@ class BaseExpression(object):
|
||||
|
||||
# aggregate specific fields
|
||||
is_summary = False
|
||||
_output_field = None
|
||||
|
||||
def __init__(self, output_field=None):
    """
    Store ``output_field`` on the instance when one is given.

    When ``output_field`` is ``None`` nothing is assigned, so a
    ``_output_field`` defined at class level remains visible through the
    instance instead of being shadowed by ``None``.
    """
    if output_field is not None:
        self._output_field = output_field
|
||||
|
||||
def get_db_converters(self, connection):
|
||||
return [self.convert_value] + self.output_field.get_db_converters(connection)
|
||||
|
||||
Reference in New Issue
Block a user