From 00fa1474d7d2dbd5530a650ea9fe09f36bff77c7 Mon Sep 17 00:00:00 2001
From: Claude Paroz <claude@2xlibre.net>
Date: Tue, 2 Dec 2014 17:53:04 +0100
Subject: [PATCH] Added raster support for GDAL Driver class

Based on Daniel Wiesmann's work. Refs #23804.
---
 django/contrib/gis/gdal/datasource.py        |   5 +-
 django/contrib/gis/gdal/driver.py            | 105 ++++++++++++-------
 django/contrib/gis/gdal/prototypes/ds.py     |   2 +-
 django/contrib/gis/gdal/tests/test_driver.py |  33 +++---
 4 files changed, 89 insertions(+), 56 deletions(-)

diff --git a/django/contrib/gis/gdal/datasource.py b/django/contrib/gis/gdal/datasource.py
index 6caa2a9068..d8f4438434 100644
--- a/django/contrib/gis/gdal/datasource.py
+++ b/django/contrib/gis/gdal/datasource.py
@@ -67,10 +67,7 @@ class DataSource(GDALBase):
         # See also http://trac.osgeo.org/gdal/wiki/rfc23_ogr_unicode
         self.encoding = encoding
 
-        # Registering all the drivers, this needs to be done
-        #  _before_ we try to open up a data source.
-        if not capi.get_driver_count():
-            capi.register_all()
+        Driver.ensure_registered()
 
         if isinstance(ds_input, six.string_types):
             # The data source driver is a void pointer.
diff --git a/django/contrib/gis/gdal/driver.py b/django/contrib/gis/gdal/driver.py
index 58f736b1a3..1d2796480d 100644
--- a/django/contrib/gis/gdal/driver.py
+++ b/django/contrib/gis/gdal/driver.py
@@ -1,70 +1,97 @@
-# prerequisites imports
 from ctypes import c_void_p
 from django.contrib.gis.gdal.base import GDALBase
 from django.contrib.gis.gdal.error import OGRException
-from django.contrib.gis.gdal.prototypes import ds as capi
+from django.contrib.gis.gdal.prototypes import ds as vcapi, raster as rcapi
 
 from django.utils import six
-from django.utils.encoding import force_bytes
+from django.utils.encoding import force_bytes, force_text
 
 
-# For more information, see the OGR C API source code:
-#  http://www.gdal.org/ogr/ogr__api_8h.html
-#
-# The OGR_Dr_* routines are relevant here.
 class Driver(GDALBase):
-    "Wraps an OGR Data Source Driver."
+    """
+    Wraps a GDAL/OGR Data Source Driver.
+    For more information, see the C API source code:
+    http://www.gdal.org/gdal_8h.html - http://www.gdal.org/ogr__api_8h.html
+    """
 
-    # Case-insensitive aliases for OGR Drivers.
-    _alias = {'esri': 'ESRI Shapefile',
-              'shp': 'ESRI Shapefile',
-              'shape': 'ESRI Shapefile',
-              'tiger': 'TIGER',
-              'tiger/line': 'TIGER',
-              }
+    # Case-insensitive aliases for some GDAL/OGR Drivers.
+    # For a complete list of original driver names see
+    # http://www.gdal.org/ogr_formats.html (vector)
+    # http://www.gdal.org/formats_list.html (raster)
+    _alias = {
+        # vector
+        'esri': 'ESRI Shapefile',
+        'shp': 'ESRI Shapefile',
+        'shape': 'ESRI Shapefile',
+        'tiger': 'TIGER',
+        'tiger/line': 'TIGER',
+        # raster
+        'tiff': 'GTiff',
+        'tif': 'GTiff',
+        'jpeg': 'JPEG',
+        'jpg': 'JPEG',
+    }
 
     def __init__(self, dr_input):
-        "Initializes an OGR driver on either a string or integer input."
-
+        """
+        Initializes an GDAL/OGR driver on either a string or integer input.
+        """
         if isinstance(dr_input, six.string_types):
             # If a string name of the driver was passed in
-            self._register()
+            self.ensure_registered()
 
-            # Checking the alias dictionary (case-insensitive) to see if an alias
-            #  exists for the given driver.
+            # Checking the alias dictionary (case-insensitive) to see if an
+            # alias exists for the given driver.
             if dr_input.lower() in self._alias:
                 name = self._alias[dr_input.lower()]
             else:
                 name = dr_input
 
-            # Attempting to get the OGR driver by the string name.
-            dr = capi.get_driver_by_name(force_bytes(name))
+            # Attempting to get the GDAL/OGR driver by the string name.
+            for iface in (vcapi, rcapi):
+                driver = iface.get_driver_by_name(force_bytes(name))
+                if driver:
+                    break
         elif isinstance(dr_input, int):
-            self._register()
-            dr = capi.get_driver(dr_input)
+            self.ensure_registered()
+            for iface in (vcapi, rcapi):
+                driver = iface.get_driver(dr_input)
+                if driver:
+                    break
         elif isinstance(dr_input, c_void_p):
-            dr = dr_input
+            driver = dr_input
         else:
-            raise OGRException('Unrecognized input type for OGR Driver: %s' % str(type(dr_input)))
+            raise OGRException('Unrecognized input type for GDAL/OGR Driver: %s' % str(type(dr_input)))
 
         # Making sure we get a valid pointer to the OGR Driver
-        if not dr:
-            raise OGRException('Could not initialize OGR Driver on input: %s' % str(dr_input))
-        self.ptr = dr
+        if not driver:
+            raise OGRException('Could not initialize GDAL/OGR Driver on input: %s' % str(dr_input))
+        self.ptr = driver
 
     def __str__(self):
-        "Returns the string name of the OGR Driver."
-        return capi.get_driver_name(self.ptr)
+        return self.name
 
-    def _register(self):
-        "Attempts to register all the data source drivers."
+    @classmethod
+    def ensure_registered(cls):
+        """
+        Attempts to register all the data source drivers.
+        """
         # Only register all if the driver count is 0 (or else all drivers
         # will be registered over and over again)
-        if not self.driver_count:
-            capi.register_all()
+        if not cls.driver_count():
+            vcapi.register_all()
+            rcapi.register_all()
+
+    @classmethod
+    def driver_count(cls):
+        """
+        Returns the number of GDAL/OGR data source drivers registered.
+        """
+        return vcapi.get_driver_count() + rcapi.get_driver_count()
 
-    # Driver properties
     @property
-    def driver_count(self):
-        "Returns the number of OGR data source drivers registered."
-        return capi.get_driver_count()
+    def name(self):
+        """
+        Returns description/name string for this driver.
+        """
+        return force_text(rcapi.get_driver_description(self.ptr))
diff --git a/django/contrib/gis/gdal/prototypes/ds.py b/django/contrib/gis/gdal/prototypes/ds.py
index 175503eba6..69257a1140 100644
--- a/django/contrib/gis/gdal/prototypes/ds.py
+++ b/django/contrib/gis/gdal/prototypes/ds.py
@@ -15,7 +15,7 @@ c_int_p = POINTER(c_int)  # shortcut type
 register_all = void_output(lgdal.OGRRegisterAll, [], errcheck=False)
 cleanup_all = void_output(lgdal.OGRCleanupAll, [], errcheck=False)
 get_driver = voidptr_output(lgdal.OGRGetDriver, [c_int])
-get_driver_by_name = voidptr_output(lgdal.OGRGetDriverByName, [c_char_p])
+get_driver_by_name = voidptr_output(lgdal.OGRGetDriverByName, [c_char_p], errcheck=False)
 get_driver_count = int_output(lgdal.OGRGetDriverCount, [])
 get_driver_name = const_string_output(lgdal.OGR_Dr_GetName, [c_void_p], decoding='ascii')
 
diff --git a/django/contrib/gis/gdal/tests/test_driver.py b/django/contrib/gis/gdal/tests/test_driver.py
index 89c65afeb9..88a8c43585 100644
--- a/django/contrib/gis/gdal/tests/test_driver.py
+++ b/django/contrib/gis/gdal/tests/test_driver.py
@@ -1,5 +1,4 @@
 import unittest
-from unittest import skipUnless
 
 from django.contrib.gis.gdal import HAS_GDAL
 
@@ -7,29 +6,39 @@ if HAS_GDAL:
     from django.contrib.gis.gdal import Driver, OGRException
 
 
-valid_drivers = ('ESRI Shapefile', 'MapInfo File', 'TIGER', 'S57', 'DGN',
-                 'Memory', 'CSV', 'GML', 'KML')
+valid_drivers = (
+    # vector
+    'ESRI Shapefile', 'MapInfo File', 'TIGER', 'S57', 'DGN', 'Memory', 'CSV',
+    'GML', 'KML',
+    # raster
+    'GTiff', 'JPEG', 'netCDF', 'MEM', 'PNG',
+)
 
-invalid_drivers = ('Foo baz', 'clucka', 'ESRI Shp')
+invalid_drivers = ('Foo baz', 'clucka', 'ESRI Shp', 'ESRI rast')
 
-aliases = {'eSrI': 'ESRI Shapefile',
-           'TigER/linE': 'TIGER',
-           'SHAPE': 'ESRI Shapefile',
-           'sHp': 'ESRI Shapefile',
-           }
+aliases = {
+    'eSrI': 'ESRI Shapefile',
+    'TigER/linE': 'TIGER',
+    'SHAPE': 'ESRI Shapefile',
+    'sHp': 'ESRI Shapefile',
+    'tiFf': 'GTiff',
+    'tIf': 'GTiff',
+    'jPEg': 'JPEG',
+    'jpG': 'JPEG',
+}
 
 
-@skipUnless(HAS_GDAL, "GDAL is required")
+@unittest.skipUnless(HAS_GDAL, "GDAL is required")
 class DriverTest(unittest.TestCase):
 
     def test01_valid_driver(self):
-        "Testing valid OGR Data Source Drivers."
+        "Testing valid GDAL/OGR Data Source Drivers."
         for d in valid_drivers:
             dr = Driver(d)
             self.assertEqual(d, str(dr))
 
     def test02_invalid_driver(self):
-        "Testing invalid OGR Data Source Drivers."
+        "Testing invalid GDAL/OGR Data Source Drivers."
         for i in invalid_drivers:
             self.assertRaises(OGRException, Driver, i)