1
0
mirror of https://github.com/django/django.git synced 2025-05-04 22:17:34 +00:00

Fixed #34411 -- Updated GDAL API to handle vector DataSource's.

Co-authored-by: David Smith <smithdc@gmail.com>
This commit is contained in:
Claude Paroz 2023-03-13 18:38:23 +01:00 committed by Mariusz Felisiak
parent 77278929c8
commit 08306bad57
5 changed files with 58 additions and 76 deletions

View File

@ -33,7 +33,6 @@
# OFTReal returns floats, all else returns string. # OFTReal returns floats, all else returns string.
val = field.value val = field.value
""" """
from ctypes import byref
from pathlib import Path from pathlib import Path
from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.base import GDALBase
@ -54,21 +53,22 @@ class DataSource(GDALBase):
def __init__(self, ds_input, ds_driver=False, write=False, encoding="utf-8"): def __init__(self, ds_input, ds_driver=False, write=False, encoding="utf-8"):
# The write flag. # The write flag.
if write: self._write = capi.GDAL_OF_UPDATE if write else capi.GDAL_OF_READONLY
self._write = 1
else:
self._write = 0
# See also https://gdal.org/development/rfc/rfc23_ogr_unicode.html # See also https://gdal.org/development/rfc/rfc23_ogr_unicode.html
self.encoding = encoding self.encoding = encoding
Driver.ensure_registered() Driver.ensure_registered()
if isinstance(ds_input, (str, Path)): if isinstance(ds_input, (str, Path)):
# The data source driver is a void pointer.
ds_driver = Driver.ptr_type()
try: try:
# OGROpen will auto-detect the data source type. # GDALOpenEx will auto-detect the data source type.
ds = capi.open_ds(force_bytes(ds_input), self._write, byref(ds_driver)) ds = capi.open_ds(
force_bytes(ds_input),
self._write | capi.GDAL_OF_VECTOR,
None,
None,
None,
)
except GDALException: except GDALException:
# Making the error message more clear rather than something # Making the error message more clear rather than something
# like "Invalid pointer returned from OGROpen". # like "Invalid pointer returned from OGROpen".
@ -82,7 +82,8 @@ class DataSource(GDALBase):
if ds: if ds:
self.ptr = ds self.ptr = ds
self.driver = Driver(ds_driver) driver = capi.get_dataset_driver(ds)
self.driver = Driver(driver)
else: else:
# Raise an exception if the returned pointer is NULL # Raise an exception if the returned pointer is NULL
raise GDALException('Invalid data source file "%s"' % ds_input) raise GDALException('Invalid data source file "%s"' % ds_input)

View File

@ -2,8 +2,7 @@ from ctypes import c_void_p
from django.contrib.gis.gdal.base import GDALBase from django.contrib.gis.gdal.base import GDALBase
from django.contrib.gis.gdal.error import GDALException from django.contrib.gis.gdal.error import GDALException
from django.contrib.gis.gdal.prototypes import ds as vcapi from django.contrib.gis.gdal.prototypes import ds as capi
from django.contrib.gis.gdal.prototypes import raster as rcapi
from django.utils.encoding import force_bytes, force_str from django.utils.encoding import force_bytes, force_str
@ -49,16 +48,10 @@ class Driver(GDALBase):
name = dr_input name = dr_input
# Attempting to get the GDAL/OGR driver by the string name. # Attempting to get the GDAL/OGR driver by the string name.
for iface in (vcapi, rcapi): driver = c_void_p(capi.get_driver_by_name(force_bytes(name)))
driver = c_void_p(iface.get_driver_by_name(force_bytes(name)))
if driver:
break
elif isinstance(dr_input, int): elif isinstance(dr_input, int):
self.ensure_registered() self.ensure_registered()
for iface in (vcapi, rcapi): driver = capi.get_driver(dr_input)
driver = iface.get_driver(dr_input)
if driver:
break
elif isinstance(dr_input, c_void_p): elif isinstance(dr_input, c_void_p):
driver = dr_input driver = dr_input
else: else:
@ -81,23 +74,21 @@ class Driver(GDALBase):
""" """
Attempt to register all the data source drivers. Attempt to register all the data source drivers.
""" """
# Only register all if the driver counts are 0 (or else all drivers # Only register all if the driver count is 0 (or else all drivers will
# will be registered over and over again) # be registered over and over again).
if not vcapi.get_driver_count(): if not capi.get_driver_count():
vcapi.register_all() capi.register_all()
if not rcapi.get_driver_count():
rcapi.register_all()
@classmethod @classmethod
def driver_count(cls): def driver_count(cls):
""" """
Return the number of GDAL/OGR data source drivers registered. Return the number of GDAL/OGR data source drivers registered.
""" """
return vcapi.get_driver_count() + rcapi.get_driver_count() return capi.get_driver_count()
@property @property
def name(self): def name(self):
""" """
Return description/name string for this driver. Return description/name string for this driver.
""" """
return force_str(rcapi.get_driver_description(self.ptr)) return force_str(capi.get_driver_description(self.ptr))

View File

@ -3,7 +3,7 @@
related data structures. OGR_Dr_*, OGR_DS_*, OGR_L_*, OGR_F_*, related data structures. OGR_Dr_*, OGR_DS_*, OGR_L_*, OGR_F_*,
OGR_Fld_* routines are relevant here. OGR_Fld_* routines are relevant here.
""" """
from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_void_p from ctypes import POINTER, c_char_p, c_double, c_int, c_long, c_uint, c_void_p
from django.contrib.gis.gdal.envelope import OGREnvelope from django.contrib.gis.gdal.envelope import OGREnvelope
from django.contrib.gis.gdal.libgdal import lgdal from django.contrib.gis.gdal.libgdal import lgdal
@ -21,26 +21,36 @@ from django.contrib.gis.gdal.prototypes.generation import (
c_int_p = POINTER(c_int) # shortcut type c_int_p = POINTER(c_int) # shortcut type
GDAL_OF_READONLY = 0x00
GDAL_OF_UPDATE = 0x01
GDAL_OF_ALL = 0x00
GDAL_OF_RASTER = 0x02
GDAL_OF_VECTOR = 0x04
# Driver Routines # Driver Routines
register_all = void_output(lgdal.OGRRegisterAll, [], errcheck=False) register_all = void_output(lgdal.GDALAllRegister, [], errcheck=False)
cleanup_all = void_output(lgdal.OGRCleanupAll, [], errcheck=False) cleanup_all = void_output(lgdal.GDALDestroyDriverManager, [], errcheck=False)
get_driver = voidptr_output(lgdal.OGRGetDriver, [c_int]) get_driver = voidptr_output(lgdal.GDALGetDriver, [c_int])
get_driver_by_name = voidptr_output( get_driver_by_name = voidptr_output(
lgdal.OGRGetDriverByName, [c_char_p], errcheck=False lgdal.GDALGetDriverByName, [c_char_p], errcheck=False
)
get_driver_count = int_output(lgdal.OGRGetDriverCount, [])
get_driver_name = const_string_output(
lgdal.OGR_Dr_GetName, [c_void_p], decoding="ascii"
) )
get_driver_count = int_output(lgdal.GDALGetDriverCount, [])
get_driver_description = const_string_output(lgdal.GDALGetDescription, [c_void_p])
# DataSource # DataSource
open_ds = voidptr_output(lgdal.OGROpen, [c_char_p, c_int, POINTER(c_void_p)]) open_ds = voidptr_output(
destroy_ds = void_output(lgdal.OGR_DS_Destroy, [c_void_p], errcheck=False) lgdal.GDALOpenEx,
release_ds = void_output(lgdal.OGRReleaseDataSource, [c_void_p]) [c_char_p, c_uint, POINTER(c_char_p), POINTER(c_char_p), POINTER(c_char_p)],
get_ds_name = const_string_output(lgdal.OGR_DS_GetName, [c_void_p]) )
get_layer = voidptr_output(lgdal.OGR_DS_GetLayer, [c_void_p, c_int]) destroy_ds = void_output(lgdal.GDALClose, [c_void_p], errcheck=False)
get_layer_by_name = voidptr_output(lgdal.OGR_DS_GetLayerByName, [c_void_p, c_char_p]) get_ds_name = const_string_output(lgdal.GDALGetDescription, [c_void_p])
get_layer_count = int_output(lgdal.OGR_DS_GetLayerCount, [c_void_p]) get_dataset_driver = voidptr_output(lgdal.GDALGetDatasetDriver, [c_void_p])
get_layer = voidptr_output(lgdal.GDALDatasetGetLayer, [c_void_p, c_int])
get_layer_by_name = voidptr_output(
lgdal.GDALDatasetGetLayerByName, [c_void_p, c_char_p]
)
get_layer_count = int_output(lgdal.GDALDatasetGetLayerCount, [c_void_p])
# Layer Routines # Layer Routines
get_extent = void_output(lgdal.OGR_L_GetExtent, [c_void_p, POINTER(OGREnvelope), c_int]) get_extent = void_output(lgdal.OGR_L_GetExtent, [c_void_p, POINTER(OGREnvelope), c_int])

View File

@ -25,15 +25,6 @@ void_output = partial(void_output, cpl=True)
const_string_output = partial(const_string_output, cpl=True) const_string_output = partial(const_string_output, cpl=True)
double_output = partial(double_output, cpl=True) double_output = partial(double_output, cpl=True)
# Raster Driver Routines
register_all = void_output(std_call("GDALAllRegister"), [], errcheck=False)
get_driver = voidptr_output(std_call("GDALGetDriver"), [c_int])
get_driver_by_name = voidptr_output(
std_call("GDALGetDriverByName"), [c_char_p], errcheck=False
)
get_driver_count = int_output(std_call("GDALGetDriverCount"), [])
get_driver_description = const_string_output(std_call("GDALGetDescription"), [c_void_p])
# Raster Data Source Routines # Raster Data Source Routines
create_ds = voidptr_output( create_ds = voidptr_output(
std_call("GDALCreate"), [c_void_p, c_char_p, c_int, c_int, c_int, c_int, c_void_p] std_call("GDALCreate"), [c_void_p, c_char_p, c_int, c_int, c_int, c_int, c_void_p]

View File

@ -54,32 +54,21 @@ class DriverTest(unittest.TestCase):
dr = Driver(alias) dr = Driver(alias)
self.assertEqual(full_name, str(dr)) self.assertEqual(full_name, str(dr))
@mock.patch("django.contrib.gis.gdal.driver.vcapi.get_driver_count") @mock.patch("django.contrib.gis.gdal.driver.capi.get_driver_count")
@mock.patch("django.contrib.gis.gdal.driver.rcapi.get_driver_count") @mock.patch("django.contrib.gis.gdal.driver.capi.register_all")
@mock.patch("django.contrib.gis.gdal.driver.vcapi.register_all") def test_registered(self, reg, count):
@mock.patch("django.contrib.gis.gdal.driver.rcapi.register_all")
def test_registered(self, rreg, vreg, rcount, vcount):
""" """
Prototypes are registered only if their respective driver counts are Prototypes are registered only if the driver count is zero.
zero.
""" """
def check(rcount_val, vcount_val): def check(count_val):
vreg.reset_mock() reg.reset_mock()
rreg.reset_mock() count.return_value = count_val
rcount.return_value = rcount_val
vcount.return_value = vcount_val
Driver.ensure_registered() Driver.ensure_registered()
if rcount_val: if count_val:
self.assertFalse(rreg.called) self.assertFalse(reg.called)
else: else:
rreg.assert_called_once_with() reg.assert_called_once_with()
if vcount_val:
self.assertFalse(vreg.called)
else:
vreg.assert_called_once_with()
check(0, 0) check(0)
check(120, 0) check(120)
check(0, 120)
check(120, 120)