diff --git a/django/contrib/gis/geos/libgeos.py b/django/contrib/gis/geos/libgeos.py index 500825a700..f4f93b57cc 100644 --- a/django/contrib/gis/geos/libgeos.py +++ b/django/contrib/gis/geos/libgeos.py @@ -9,12 +9,13 @@ import logging import os import re +import threading from ctypes import CDLL, CFUNCTYPE, POINTER, Structure, c_char_p from ctypes.util import find_library from django.contrib.gis.geos.error import GEOSException from django.core.exceptions import ImproperlyConfigured -from django.utils.functional import SimpleLazyObject +from django.utils.functional import SimpleLazyObject, cached_property logger = logging.getLogger('django.contrib.gis') @@ -63,10 +64,11 @@ def load_geos(): _lgeos = CDLL(lib_path) # Here we set up the prototypes for the initGEOS_r and finishGEOS_r # routines. These functions aren't actually called until they are - # attached to a GEOS context handle -- this actually occurs in - # geos/prototypes/threadsafe.py. + # attached to a GEOS context handle. _lgeos.initGEOS_r.restype = CONTEXT_PTR _lgeos.finishGEOS_r.argtypes = [CONTEXT_PTR] + # Ensures compatibility across 32 and 64-bit platforms. + _lgeos.GEOSversion.restype = c_char_p return _lgeos @@ -134,6 +136,27 @@ def get_pointer_arr(n): lgeos = SimpleLazyObject(load_geos) +class GEOSContextHandle(object): + def __init__(self): + # Initializing the context handle for this thread with + # the notice and error handler. + self.ptr = lgeos.initGEOS_r(notice_h, error_h) + + def __del__(self): + if self.ptr and lgeos: + lgeos.finishGEOS_r(self.ptr) + + +class GEOSContext(threading.local): + + @cached_property + def ptr(self): + # Assign handle so it will will garbage collected when + # thread is finished. + self.handle = GEOSContextHandle() + return self.handle.ptr + + class GEOSFuncFactory(object): """ Lazy loading of GEOS functions. @@ -141,6 +164,7 @@ class GEOSFuncFactory(object): argtypes = None restype = None errcheck = None + thread_context = GEOSContext() def __init__(self, func_name, *args, **kwargs): self.func_name = func_name @@ -154,21 +178,23 @@ class GEOSFuncFactory(object): def __call__(self, *args, **kwargs): if self.func is None: self.func = self.get_func(*self.args, **self.kwargs) - return self.func(*args, **kwargs) + # Call the threaded GEOS routine with pointer of the context handle + # as the first argument. + return self.func(self.thread_context.ptr, *args) def get_func(self, *args, **kwargs): - from django.contrib.gis.geos.prototypes.threadsafe import GEOSFunc - func = GEOSFunc(self.func_name) - func.argtypes = self.argtypes or [] + # GEOS thread-safe function signatures end with '_r' and + # take an additional context handle parameter. + func = getattr(lgeos, self.func_name + '_r') + func.argtypes = [CONTEXT_PTR] + (self.argtypes or []) func.restype = self.restype if self.errcheck: func.errcheck = self.errcheck return func -# Returns the string version of the GEOS library. Have to set the restype -# explicitly to c_char_p to ensure compatibility across 32 and 64-bit platforms. -geos_version = GEOSFuncFactory('GEOSversion', restype=c_char_p) +# Returns the string version of the GEOS library. +geos_version = lambda: lgeos.GEOSversion() # Regular expression should be able to parse version strings such as # '3.0.0rc4-CAPI-1.3.3', '3.0.0-CAPI-1.4.1', '3.4.0dev-CAPI-1.8.0' or '3.4.0dev-CAPI-1.8.0 r0' diff --git a/django/contrib/gis/geos/prototypes/threadsafe.py b/django/contrib/gis/geos/prototypes/threadsafe.py deleted file mode 100644 index 45c87d6004..0000000000 --- a/django/contrib/gis/geos/prototypes/threadsafe.py +++ /dev/null @@ -1,93 +0,0 @@ -import threading - -from django.contrib.gis.geos.libgeos import ( - CONTEXT_PTR, error_h, lgeos, notice_h, -) - - -class GEOSContextHandle(object): - """ - Python object representing a GEOS context handle. - """ - def __init__(self): - # Initializing the context handler for this thread with - # the notice and error handler. - self.ptr = lgeos.initGEOS_r(notice_h, error_h) - - def __del__(self): - if self.ptr and lgeos: - lgeos.finishGEOS_r(self.ptr) - - -# Defining a thread-local object and creating an instance -# to hold a reference to GEOSContextHandle for this thread. -class GEOSContext(threading.local): - handle = None - -thread_context = GEOSContext() - - -class GEOSFunc(object): - """ - Class that serves as a wrapper for GEOS C Functions, and will - use thread-safe function variants when available. - """ - def __init__(self, func_name): - try: - # GEOS thread-safe function signatures end with '_r', and - # take an additional context handle parameter. - self.cfunc = getattr(lgeos, func_name + '_r') - self.threaded = True - # Create a reference here to thread_context so it's not - # garbage-collected before an attempt to call this object. - self.thread_context = thread_context - except AttributeError: - # Otherwise, use usual function. - self.cfunc = getattr(lgeos, func_name) - self.threaded = False - - def __call__(self, *args): - if self.threaded: - # If a context handle does not exist for this thread, initialize one. - if not self.thread_context.handle: - self.thread_context.handle = GEOSContextHandle() - # Call the threaded GEOS routine with pointer of the context handle - # as the first argument. - return self.cfunc(self.thread_context.handle.ptr, *args) - else: - return self.cfunc(*args) - - def __str__(self): - return self.cfunc.__name__ - - # argtypes property - def _get_argtypes(self): - return self.cfunc.argtypes - - def _set_argtypes(self, argtypes): - if self.threaded: - new_argtypes = [CONTEXT_PTR] - new_argtypes.extend(argtypes) - self.cfunc.argtypes = new_argtypes - else: - self.cfunc.argtypes = argtypes - - argtypes = property(_get_argtypes, _set_argtypes) - - # restype property - def _get_restype(self): - return self.cfunc.restype - - def _set_restype(self, restype): - self.cfunc.restype = restype - - restype = property(_get_restype, _set_restype) - - # errcheck property - def _get_errcheck(self): - return self.cfunc.errcheck - - def _set_errcheck(self, errcheck): - self.cfunc.errcheck = errcheck - - errcheck = property(_get_errcheck, _set_errcheck) diff --git a/tests/gis_tests/geos_tests/test_geos.py b/tests/gis_tests/geos_tests/test_geos.py index 116e8d35f8..8c67139e92 100644 --- a/tests/gis_tests/geos_tests/test_geos.py +++ b/tests/gis_tests/geos_tests/test_geos.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import ctypes import json import random +import threading from binascii import a2b_hex, b2a_hex from io import BytesIO from unittest import skipUnless @@ -12,7 +13,7 @@ from django.contrib.gis.gdal import HAS_GDAL from django.contrib.gis.geos import ( HAS_GEOS, GeometryCollection, GEOSException, GEOSGeometry, LinearRing, LineString, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon, - fromfile, fromstr, + fromfile, fromstr, libgeos, ) from django.contrib.gis.geos.base import GEOSBase from django.contrib.gis.geos.libgeos import geos_version_info @@ -1232,6 +1233,48 @@ class GEOSTest(SimpleTestCase, TestDataMixin): self.assertEqual(m.group('version'), v_geos) self.assertEqual(m.group('capi_version'), v_capi) + def test_geos_threads(self): + pnt = Point() + context_ptrs = [] + + geos_init = libgeos.lgeos.initGEOS_r + geos_finish = libgeos.lgeos.finishGEOS_r + + def init(*args, **kwargs): + result = geos_init(*args, **kwargs) + context_ptrs.append(result) + return result + + def finish(*args, **kwargs): + result = geos_finish(*args, **kwargs) + destructor_called.set() + return result + + for i in range(2): + destructor_called = threading.Event() + patch_path = 'django.contrib.gis.geos.libgeos.lgeos' + with mock.patch.multiple(patch_path, initGEOS_r=mock.DEFAULT, finishGEOS_r=mock.DEFAULT) as mocked: + mocked['initGEOS_r'].side_effect = init + mocked['finishGEOS_r'].side_effect = finish + with mock.patch('django.contrib.gis.geos.prototypes.predicates.geos_hasz.func') as mocked_hasz: + thread = threading.Thread(target=lambda: pnt.hasz) + thread.start() + thread.join() + + # We can't be sure that members of thread locals are + # garbage collected right after `thread.join()` so + # we must wait until destructor is actually called. + # Fail if destructor wasn't called within a second. + self.assertTrue(destructor_called.wait(1)) + + context_ptr = context_ptrs[i] + self.assertIsInstance(context_ptr, libgeos.CONTEXT_PTR) + mocked_hasz.assert_called_once_with(context_ptr, pnt.ptr) + mocked['finishGEOS_r'].assert_called_once_with(context_ptr) + + # Check that different contexts were used for the different threads. + self.assertNotEqual(context_ptrs[0], context_ptrs[1]) + @ignore_warnings(category=RemovedInDjango20Warning) def test_deprecated_srid_getters_setters(self): p = Point(1, 2, srid=123)