1
1
import datetime
2
+ import sys
3
+ import warnings
2
4
3
- from django import forms
5
+ import django
6
+ import pytz
4
7
from django .core import exceptions , checks
5
-
6
- from django .db import models
7
-
8
+ from django .db .models import DateTimeField
8
9
from django .utils import timezone
9
10
from django .utils .dateparse import parse_date , parse_datetime
10
11
from django .utils .translation import gettext_lazy as _
11
12
13
+ if django .VERSION >= (1 , 11 ):
14
+ from django .db .models .functions .datetime import TruncBase , Extract , ExtractYear
15
+ from django .db .models .lookups import Exact , GreaterThan , GreaterThanOrEqual , \
16
+ LessThan , LessThanOrEqual
17
+
12
18
13
- class NaiveDateTimeField (models . DateField ):
19
+ class NaiveDateTimeField (DateTimeField ):
14
20
description = _ ("Naive Date (with time)" )
15
21
22
+ default_error_messages = {
23
+ 'tzaware' : _ ("TZ-aware datetimes cannot be coerced to naive datetimes" ),
24
+ }
25
+
16
26
def get_internal_type (self ):
17
- return "NaiveDateTime "
27
+ return "DateTimeField "
18
28
19
29
def db_type (self , connection ):
20
- if connection .settings_dict ["ENGINE" ] in [
21
- "django_prometheus.db.backends.postgresql" ,
22
- "django_prometheus.db.backends.postgresql_psycopg2" ,
23
- "django.db.backends.postgresql" ,
24
- "django.db.backends.postgresql_psycopg2" ,
25
- ]:
30
+ if connection .vendor == "postgresql" :
26
31
return "timestamp without time zone"
27
-
28
- raise NotImplementedError ("Only postgresql is supported at this time." )
32
+ return super (NaiveDateTimeField , self ).db_type (connection )
29
33
30
34
def _check_fix_default_value (self ):
31
35
"""
@@ -81,13 +85,17 @@ def to_python(self, value):
81
85
if value is None :
82
86
return value
83
87
if isinstance (value , datetime .datetime ):
84
- return value .replace (tzinfo = None )
88
+ if timezone .is_aware (value ):
89
+ raise exceptions .ValidationError (self .error_messages ['tzaware' ])
90
+ return value
85
91
if isinstance (value , datetime .date ):
86
92
return datetime .datetime (value .year , value .month , value .day )
87
93
88
94
try :
89
95
parsed = parse_datetime (value )
90
96
if parsed is not None :
97
+ if timezone .is_aware (parsed ):
98
+ raise exceptions .ValidationError (self .error_messages ['tzaware' ])
91
99
return parsed
92
100
except ValueError :
93
101
raise exceptions .ValidationError (
@@ -112,104 +120,91 @@ def to_python(self, value):
112
120
)
113
121
114
122
def get_prep_value (self , value ):
115
- """
116
- Ensure we have a naive datetime ready for insertion
117
- """
118
- value = super (NaiveDateTimeField , self ).get_prep_value (value )
119
- value = self .to_python (value )
120
-
121
- if value is not None and timezone .is_aware (value ):
122
- # We were given an aware datetime, strip off tzinfo
123
- value = value .replace (tzinfo = None )
124
-
125
- return value
126
-
127
- def get_db_prep_value (self , value , connection , prepared = False ):
128
- if not prepared :
129
- value = self .get_prep_value (value )
130
-
131
- if value is None :
132
- return None
133
-
134
- if hasattr (value , "resolve_expression" ):
123
+ return super (DateTimeField , self ).get_prep_value (value )
124
+
125
+ def from_db_value (self , value , expression , connection , context ):
126
+ is_truncbase = django .VERSION >= (1 , 11 ) and isinstance (expression , TruncBase )
127
+ if is_truncbase and not isinstance (expression , NaiveAsSQLMixin ):
128
+ raise TypeError (
129
+ "Django's %s cannot be used with a NaiveDateTimeField"
130
+ % expression .__class__ .__name__
131
+ )
132
+ if connection .vendor == "postgresql" :
133
+ if is_truncbase :
134
+ return timezone .make_naive (value , pytz .utc )
135
135
return value
136
-
137
- if connection .settings_dict ["ENGINE" ] == "django.db.backends.mysql" :
138
- return str (value )
139
-
140
- elif connection .settings_dict ["ENGINE" ] == "django.db.backends.sqlite3" :
141
- return str (value )
142
-
143
- elif connection .settings_dict ["ENGINE" ] == "django.db.backends.oracle" :
144
- from django .db .backends .oracle .utils import Oracle_datetime
145
-
146
- return Oracle_datetime .from_datetime (value )
147
-
136
+ if timezone .is_aware (value ):
137
+ if django .VERSION < (1 , 9 ):
138
+ return timezone .make_naive (value , pytz .utc )
139
+ return timezone .make_naive (value , connection .timezone )
148
140
return value
149
141
150
142
def pre_save (self , model_instance , add ):
151
143
if self .auto_now or (self .auto_now_add and add ):
152
- value = timezone .now (). replace ( tzinfo = None )
144
+ value = timezone .make_naive ( timezone . now () )
153
145
setattr (model_instance , self .attname , value )
154
146
return value
155
147
else :
156
148
return super (NaiveDateTimeField , self ).pre_save (model_instance , add )
157
149
158
- def value_to_string (self , obj ):
159
- val = self .value_from_object (obj )
160
- return "" if val is None else val .isoformat ()
161
150
162
- def formfield (self , ** kwargs ):
163
- defaults = {"form_class" : forms .DateTimeField }
164
- defaults .update (kwargs )
165
- return super (NaiveDateTimeField , self ).formfield (** defaults )
151
+ class NaiveTimezoneMixin (object ):
152
+ def get_tzname (self ):
153
+ if isinstance (self .output_field , NaiveDateTimeField ):
154
+ if self .tzinfo is not None :
155
+ warnings .warn (
156
+ "tzinfo argument provided when truncating a NaiveDateTimeField. "
157
+ "This argument will have no effect."
158
+ )
159
+ return 'UTC'
160
+ return super (NaiveTimezoneMixin , self ).get_tzname ()
166
161
167
162
168
- # try to register our field for the __time and __date lookups
169
- try :
170
- from django .db .models import TimeField
171
- from django .db .models .functions .datetime import TruncBase
163
+ class NaiveConvertValueMixin (object ):
164
+ def convert_value (self , value , * args , ** kwargs ):
165
+ if isinstance (self .output_field , NaiveDateTimeField ):
166
+ return value
167
+ return super (NaiveConvertValueMixin , self ).convert_value (value , * args , ** kwargs )
172
168
173
- class TruncTimeNaive (TruncBase ):
174
- kind = "time"
175
- lookup_name = "time"
176
- output_field = TimeField ()
177
169
178
- def as_sql (self , compiler , connection ):
179
- # Cast to date rather than truncate to date.
180
- lhs , lhs_params = compiler .compile (self .lhs )
170
+ class NaiveAsSQLMixin (object ):
171
+ def as_sql (self , compiler , connection ):
172
+ if isinstance (self .lhs .output_field , NaiveDateTimeField ):
173
+ with timezone .override (pytz .utc ):
174
+ return super (NaiveAsSQLMixin , self ).as_sql (compiler , connection )
175
+ return super (NaiveAsSQLMixin , self ).as_sql (compiler , connection )
181
176
182
- # this is a postgresql only compatible cast, replacing
183
- # a call to connection.ops.datetime_cast_time_sql that
184
- # wouldn't work with None tzinfo
185
- sql = "(%s)::time" % lhs
186
177
187
- return sql , lhs_params
178
+ _monkeypatching = False
188
179
189
- NaiveDateTimeField .register_lookup (TruncTimeNaive )
190
- except ImportError :
191
- pass
192
180
193
- try :
194
- from django .db .models import DateField
195
- from django .db .models .functions .datetime import TruncBase
181
+ if django .VERSION >= (1 , 11 ):
182
+ _this_module = sys .modules [__name__ ]
183
+ _db_functions = sys .modules ['django.db.models.functions' ]
184
+ _lookups = set (DateTimeField .get_lookups ().values ())
185
+ _patch_classes = [
186
+ (Extract , [NaiveAsSQLMixin , NaiveTimezoneMixin ]),
187
+ (TruncBase , [NaiveAsSQLMixin , NaiveTimezoneMixin , NaiveConvertValueMixin ]),
188
+ ]
189
+ for original , mixins in _patch_classes :
190
+ for cls in original .__subclasses__ ():
196
191
197
- class TruncDateNaive (TruncBase ):
198
- kind = "date"
199
- lookup_name = "date"
200
- output_field = DateField ()
192
+ bases = tuple (mixins ) + (cls ,)
193
+ naive_cls = type (cls .__name__ , bases , {})
201
194
202
- def as_sql (self , compiler , connection ):
203
- # Cast to date rather than truncate to date.
204
- lhs , lhs_params = compiler .compile (self .lhs )
195
+ if _monkeypatching :
196
+ setattr (_db_functions , cls .__name__ , naive_cls )
205
197
206
- # this is a postgresql only compatible cast, replacing
207
- # a call to connection.ops.datetime_cast_date_sql that
208
- # wouldn't work with None tzinfo
209
- sql = "(%s)::date" % lhs
198
+ if cls in _lookups :
199
+ NaiveDateTimeField .register_lookup (naive_cls )
210
200
211
- return sql , lhs_params
201
+ # Year lookups don't need special handling with naive fields
202
+ if cls is ExtractYear :
203
+ naive_cls .register_lookup (Exact )
204
+ naive_cls .register_lookup (GreaterThan )
205
+ naive_cls .register_lookup (GreaterThanOrEqual )
206
+ naive_cls .register_lookup (LessThan )
207
+ naive_cls .register_lookup (LessThanOrEqual )
212
208
213
- NaiveDateTimeField .register_lookup (TruncDateNaive )
214
- except ImportError :
215
- pass
209
+ # Add an attribute to this module so these functions can be imported
210
+ setattr (_this_module , cls .__name__ , naive_cls )
0 commit comments