1
0
mirror of https://github.com/django/django.git synced 2025-10-25 06:36:07 +00:00

Refs #33308 -- Added get_type_oids() hook and simplified registering type handlers on PostgreSQL.

This commit is contained in:
Daniele Varrazzo
2022-12-01 10:59:20 +01:00
committed by Mariusz Felisiak
parent 149b55fefa
commit d3e746ace5
2 changed files with 27 additions and 33 deletions

View File

@@ -1,22 +1,16 @@
import functools import functools
import psycopg2 import psycopg2
from psycopg2 import ProgrammingError
from psycopg2.extras import register_hstore from psycopg2.extras import register_hstore
from django.db import connections from django.db import connections
from django.db.backends.base.base import NO_DB_ALIAS from django.db.backends.base.base import NO_DB_ALIAS
@functools.lru_cache def get_type_oids(connection_alias, type_name):
def get_hstore_oids(connection_alias):
"""Return hstore and hstore array OIDs."""
with connections[connection_alias].cursor() as cursor: with connections[connection_alias].cursor() as cursor:
cursor.execute( cursor.execute(
"SELECT t.oid, typarray " "SELECT oid, typarray FROM pg_type WHERE typname = %s", (type_name,)
"FROM pg_type t "
"JOIN pg_namespace ns ON typnamespace = ns.oid "
"WHERE typname = 'hstore'"
) )
oids = [] oids = []
array_oids = [] array_oids = []
@@ -26,43 +20,42 @@ def get_hstore_oids(connection_alias):
return tuple(oids), tuple(array_oids) return tuple(oids), tuple(array_oids)
@functools.lru_cache
def get_hstore_oids(connection_alias):
"""Return hstore and hstore array OIDs."""
return get_type_oids(connection_alias, "hstore")
@functools.lru_cache @functools.lru_cache
def get_citext_oids(connection_alias): def get_citext_oids(connection_alias):
"""Return citext array OIDs.""" """Return citext and citext array OIDs."""
with connections[connection_alias].cursor() as cursor: return get_type_oids(connection_alias, "citext")
cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'")
return tuple(row[0] for row in cursor)
def register_type_handlers(connection, **kwargs): def register_type_handlers(connection, **kwargs):
if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS: if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
return return
try: oids, array_oids = get_hstore_oids(connection.alias)
oids, array_oids = get_hstore_oids(connection.alias) # Don't register handlers when hstore is not available on the database.
#
# If someone tries to create an hstore field it will error there. This is
# necessary as someone may be using PSQL without extensions installed but
# be using other features of contrib.postgres.
#
# This is also needed in order to create the connection in order to install
# the hstore extension.
if oids:
register_hstore( register_hstore(
connection.connection, globally=True, oid=oids, array_oid=array_oids connection.connection, globally=True, oid=oids, array_oid=array_oids
) )
except ProgrammingError:
# Hstore is not available on the database.
#
# If someone tries to create an hstore field it will error there.
# This is necessary as someone may be using PSQL without extensions
# installed but be using other features of contrib.postgres.
#
# This is also needed in order to create the connection in order to
# install the hstore extension.
pass
try: oids, citext_oids = get_citext_oids(connection.alias)
citext_oids = get_citext_oids(connection.alias) # Don't register handlers when citext is not available on the database.
#
# The same comments in the above call to register_hstore() also apply here.
if oids:
array_type = psycopg2.extensions.new_array_type( array_type = psycopg2.extensions.new_array_type(
citext_oids, "citext[]", psycopg2.STRING citext_oids, "citext[]", psycopg2.STRING
) )
psycopg2.extensions.register_type(array_type, None) psycopg2.extensions.register_type(array_type, None)
except ProgrammingError:
# citext is not available on the database.
#
# The same comments in the except block of the above call to
# register_hstore() also apply here.
pass

View File

@@ -34,8 +34,9 @@ class OIDTests(PostgreSQLTestCase):
self.assertOIDs(array_oids) self.assertOIDs(array_oids)
def test_citext_values(self): def test_citext_values(self):
oids = get_citext_oids(connection.alias) oids, citext_oids = get_citext_oids(connection.alias)
self.assertOIDs(oids) self.assertOIDs(oids)
self.assertOIDs(citext_oids)
def test_register_type_handlers_no_db(self): def test_register_type_handlers_no_db(self):
"""Registering type handlers for the nodb connection does nothing.""" """Registering type handlers for the nodb connection does nothing."""