mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			226 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			226 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import unicode_literals
 | |
| 
 | |
| from collections import defaultdict
 | |
| import datetime
 | |
| import decimal
 | |
| import hashlib
 | |
| import logging
 | |
| from time import time
 | |
| 
 | |
| from django.conf import settings
 | |
| from django.utils.encoding import force_bytes
 | |
| from django.utils.timezone import utc
 | |
| 
 | |
| 
 | |
# Logger used by CursorDebugWrapper to emit each timed SQL statement at
# DEBUG level.
logger = logging.getLogger(name='django.db.backends')
 | |
| 
 | |
| 
 | |
class CursorWrapper(object):
    """
    Proxy around a DB-API cursor that routes work through the connection
    wrapper ``db``.

    Attributes not defined on this class are forwarded to the wrapped
    cursor. The row-fetching methods named in WRAP_ERROR_ATTRS (and
    iteration) run inside ``db.wrap_database_errors``. callproc(), execute()
    and executemany() first call ``db.validate_no_broken_transaction()`` and
    ``db.set_dirty()``, then run the cursor call inside the same wrapper.
    """

    def __init__(self, cursor, db):
        self.cursor = cursor
        self.db = db

    WRAP_ERROR_ATTRS = frozenset(['fetchone', 'fetchmany', 'fetchall', 'nextset'])

    def __getattr__(self, attr):
        cursor_attr = getattr(self.cursor, attr)
        if attr in CursorWrapper.WRAP_ERROR_ATTRS:
            return self.db.wrap_database_errors(cursor_attr)
        else:
            return cursor_attr

    def __iter__(self):
        # Iterating a cursor fetches rows just like fetchone()/fetchmany(),
        # so run it inside the same error wrapper the fetch* methods get
        # through __getattr__. Previously iteration bypassed it entirely.
        with self.db.wrap_database_errors:
            for item in self.cursor:
                yield item

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Ticket #17671 - Close instead of passing thru to avoid backend
        # specific behavior.
        self.close()

    # The following methods cannot be implemented in __getattr__, because the
    # code must run when the method is invoked, not just when it is accessed.

    def callproc(self, procname, params=None):
        self.db.validate_no_broken_transaction()
        self.db.set_dirty()
        with self.db.wrap_database_errors:
            if params is None:
                return self.cursor.callproc(procname)
            else:
                return self.cursor.callproc(procname, params)

    def execute(self, sql, params=None):
        self.db.validate_no_broken_transaction()
        self.db.set_dirty()
        with self.db.wrap_database_errors:
            # Only pass params when given, so the backend sees a plain
            # single-argument execute(sql) otherwise.
            if params is None:
                return self.cursor.execute(sql)
            else:
                return self.cursor.execute(sql, params)

    def executemany(self, sql, param_list):
        self.db.validate_no_broken_transaction()
        self.db.set_dirty()
        with self.db.wrap_database_errors:
            return self.cursor.executemany(sql, param_list)
 | |
| 
 | |
| 
 | |
class CursorDebugWrapper(CursorWrapper):
    """
    CursorWrapper that times execute() and executemany().

    After each call (successful or not) it appends a ``{'sql', 'time'}``
    entry to ``db.queries``, with the duration formatted to three decimals,
    and emits the statement on the module logger at DEBUG level.
    """

    # XXX callproc isn't instrumented at this time.

    def execute(self, sql, params=None):
        start = time()
        try:
            return super(CursorDebugWrapper, self).execute(sql, params)
        finally:
            stop = time()
            duration = stop - start
            # Record the query text as reported by the backend's ops helper.
            sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            self.db.queries.append({
                'sql': sql,
                'time': "%.3f" % duration,
            })
            # Hand the values to logging as arguments instead of
            # %-formatting eagerly: the message text is identical, but it is
            # only built when a handler actually emits the record.
            logger.debug('(%.3f) %s; args=%s', duration, sql, params,
                extra={'duration': duration, 'sql': sql, 'params': params}
            )

    def executemany(self, sql, param_list):
        start = time()
        try:
            return super(CursorDebugWrapper, self).executemany(sql, param_list)
        finally:
            stop = time()
            duration = stop - start
            try:
                times = len(param_list)
            except TypeError:           # param_list could be an iterator
                times = '?'
            self.db.queries.append({
                'sql': '%s times: %s' % (times, sql),
                'time': "%.3f" % duration,
            })
            # Lazy logging arguments, as in execute().
            logger.debug('(%.3f) %s; args=%s', duration, sql, param_list,
                extra={'duration': duration, 'sql': sql, 'params': param_list}
            )
 | |
| 
 | |
| 
 | |
| ###############################################
 | |
| # Converters from database (string) to Python #
 | |
| ###############################################
 | |
| 
 | |
def typecast_date(s):
    """
    Convert a "YYYY-MM-DD" database string into a ``datetime.date``.

    Falsy input (None, '') is returned as None.
    """
    if not s:
        return None  # SQL NULL / empty value
    parts = [int(piece) for piece in s.split('-')]
    return datetime.date(*parts)
 | |
| 
 | |
| 
 | |
def typecast_time(s):  # does NOT store time zone information
    """
    Convert a database time string such as "15:48:00" or "15:48:00.590358"
    into a naive ``datetime.time``.

    Returns None for falsy input. Fractional seconds are right-padded with
    zeros and truncated to six digits to give microseconds.
    """
    if not s:
        return None
    hour, minutes, seconds = s.split(':')
    if '.' in seconds:  # check whether seconds have a fractional part
        seconds, microseconds = seconds.split('.')
    else:
        microseconds = '0'
    # Derive microseconds from the digit string itself (the same approach
    # typecast_timestamp uses) rather than int(float('.' + digits) * 1e6):
    # binary floating-point rounding made that truncate some values one
    # microsecond low, e.g. ".500002" became 500001.
    return datetime.time(int(hour), int(minutes), int(seconds),
                         int((microseconds + '000000')[:6]))
 | |
| 
 | |
| 
 | |
def typecast_timestamp(s):  # does NOT store time zone information
    """
    Convert a database timestamp string into a ``datetime.datetime``.

    Example inputs: "2005-07-29 15:48:00.590358-05" and
    "2005-07-29 09:56:00-05". Falsy input yields None; input containing no
    space is delegated to typecast_date(). A trailing "-" / "+" offset on the
    time part is discarded. The result gets ``utc`` as tzinfo when
    settings.USE_TZ is true and is naive otherwise. Fractional seconds are
    right-padded with zeros and truncated to six digits.
    """
    if not s:
        return None
    if ' ' not in s:
        return typecast_date(s)
    date_str, time_str = s.split()
    # Strip the UTC offset, if present. It is currently thrown away, but it
    # may be used in the future.
    for sign in ('-', '+'):
        if sign in time_str:
            time_str = time_str.split(sign, 1)[0]
            break
    date_bits = date_str.split('-')
    time_bits = time_str.split(':')
    whole_seconds = time_bits[2]
    if '.' in whole_seconds:  # fractional seconds present
        whole_seconds, fraction = whole_seconds.split('.')
    else:
        fraction = '0'
    tzinfo = None
    if settings.USE_TZ:
        tzinfo = utc
    return datetime.datetime(
        int(date_bits[0]), int(date_bits[1]), int(date_bits[2]),
        int(time_bits[0]), int(time_bits[1]), int(whole_seconds),
        int((fraction + '000000')[:6]), tzinfo,
    )
 | |
| 
 | |
| 
 | |
def typecast_decimal(s):
    """
    Convert a database string into a ``decimal.Decimal``.

    None and the empty string both map to None.
    """
    if s is not None and s != '':
        return decimal.Decimal(s)
    return None
 | |
| 
 | |
| 
 | |
| ###############################################
 | |
| # Converters from Python to database (string) #
 | |
| ###############################################
 | |
| 
 | |
def rev_typecast_decimal(d):
    """Render ``d`` with str() for the database; None passes through."""
    return None if d is None else str(d)
 | |
| 
 | |
| 
 | |
def truncate_name(name, length=None, hash_len=4):
    """
    Shortens a string to a repeatable mangled version with the given length.

    If ``length`` is None or ``name`` already fits, ``name`` is returned
    unchanged. Otherwise the result is a prefix of ``name`` followed by the
    first ``hash_len`` hex digits of the MD5 of the full name. The hash makes
    clashes between long names that share a prefix less likely, and the same
    input always produces the same output.
    """
    if length is None or len(name) <= length:
        return name

    # The hash suffix can never be longer than the whole result. Without
    # this clamp, a ``length`` smaller than ``hash_len`` would make
    # ``name[:length - hash_len]`` use a negative index, and the returned
    # value would exceed ``length``.
    hash_len = min(hash_len, length)
    hsh = hashlib.md5(force_bytes(name)).hexdigest()[:hash_len]
    return '%s%s' % (name[:length - hash_len], hsh)
 | |
| 
 | |
| 
 | |
def format_number(value, max_digits, decimal_places):
    """
    Formats a number into a string with the requisite number of digits and
    decimal places.

    ``Decimal`` values are quantized to ``decimal_places`` places using a copy
    of the current decimal context whose precision is set to ``max_digits``.
    When ``max_digits`` is None, the context's existing precision is kept.
    Any other number is rendered with ``"%.*f"``, which ignores
    ``max_digits``. Returns None when ``value`` is None.
    """
    if value is None:
        return None
    if isinstance(value, decimal.Decimal):
        # Work on a copy so the caller's thread-local context is untouched.
        context = decimal.getcontext().copy()
        if max_digits is not None:
            context.prec = max_digits
        exponent = decimal.Decimal(".1") ** decimal_places
        return "{0:f}".format(value.quantize(exponent, context=context))
    else:
        return "%.*f" % (decimal_places, value)
 | |
| 
 | |
# Registry keyed by vendor name; each value maps a query element class to the
# implementation function registered for that vendor.
compile_implementations = defaultdict(dict)


def get_implementations(vendor):
    """
    Return the mutable ``{class: function}`` mapping for ``vendor``.

    Because the registry is a ``defaultdict(dict)``, asking for an unknown
    vendor creates and stores an empty mapping for it.
    """
    vendor_registry = compile_implementations[vendor]
    return vendor_registry
 | |
| 
 | |
| 
 | |
class add_implementation(object):
    """
    Decorator that registers a vendor-specific implementation for a query
    expression class in the per-vendor registry from get_implementations().

    Example::

        @add_implementation(Exact, 'mysql')
        def mysql_exact(node, qn, connection):
            # Play with the node here.
            return somesql, list_of_params

    The intent is that Exact nodes are then compiled to SQL with mysql_exact
    instead of Exact.as_sql() on the MySQL backend. The decorated function
    itself is returned unchanged.
    """

    def __init__(self, klass, vendor):
        self.klass, self.vendor = klass, vendor

    def __call__(self, func):
        get_implementations(self.vendor)[self.klass] = func
        return func
 |