Skip to content

Commit

Permalink
Swap out get_field_by_name
Browse files Browse the repository at this point in the history
  • Loading branch information
jarekwg committed Jul 6, 2016
1 parent 0c7b8df commit 4b05de9
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions django_hstore/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import unicode_literals, absolute_import
from __future__ import absolute_import, unicode_literals

from django import VERSION
import django
from django.db import transaction
from django.db.models.query import QuerySet
from django.db.models.query_utils import QueryWrapper
Expand All @@ -9,16 +9,24 @@
from django.db.models.sql.query import Query
from django.db.models.sql.subqueries import UpdateQuery
from django.db.models.sql.where import WhereNode
from django.utils import six

from django_hstore.apps import GEODJANGO_INSTALLED
from django_hstore.utils import get_cast_for_param, get_value_annotations

try:
# django <= 1.8
from django.db.models.sql.where import EmptyShortCircuit
except ImportError:
# django >= 1.9
EmptyShortCircuit = Exception
from django.utils import six

from django_hstore.apps import GEODJANGO_INSTALLED
from django_hstore.utils import get_cast_for_param, get_value_annotations

def get_field(self, name):
if django.VERSION >= (1, 8):
return self.model._meta.get_field(name)
else:
return self.model._meta.get_field_by_name(name)[0]


def select_query(method):
Expand All @@ -34,7 +42,6 @@ def update_query(method):
def updater(self, *args, **params):
self._for_write = True
query = method(self, self.query.clone(UpdateQuery), *args, **params)
forced_managed = False
with transaction.atomic(using=self.db):
rows = query.get_compiler(self.db).execute_sql(None)
self._result_cache = None
Expand Down Expand Up @@ -175,7 +182,8 @@ def hpeek(self, query, attr, key):
query.add_extra({'_': '%s -> %%s' % attr}, [key], None, None, None, None)
result = query.get_compiler(self.db).execute_sql(SINGLE)
if result and result[0]:
field = self.model._meta.get_field_by_name(attr)[0]

field = get_field(self, attr)
return field._value_to_python(result[0])

@select_query
Expand All @@ -186,7 +194,7 @@ def hslice(self, query, attr, keys):
query.add_extra({'_': 'slice("%s", %%s)' % attr}, [keys], None, None, None, None)
result = query.get_compiler(self.db).execute_sql(SINGLE)
if result and result[0]:
field = self.model._meta.get_field_by_name(attr)[0]
field = get_field(self, attr)
return dict((key, field._value_to_python(value)) for key, value in result[0].items())
return {}

Expand All @@ -196,7 +204,7 @@ def hremove(self, query, attr, keys):
Removes the specified keys in the specified hstore.
"""
value = QueryWrapper('delete("%s", %%s)' % attr, [keys])
field, model, direct, m2m = self.model._meta.get_field_by_name(attr)
field = get_field(self, attr)
query.add_update_fields([(field, None, value)])
return query

Expand All @@ -205,7 +213,7 @@ def hupdate(self, query, attr, updates):
"""
Updates the specified hstore.
"""
field, model, direct, m2m = self.model._meta.get_field_by_name(attr)
field = get_field(self, attr)
if hasattr(field, 'serializer'):
updates = field.get_prep_value(updates)
value = QueryWrapper('"%s" || %%s' % attr, [updates])
Expand All @@ -216,7 +224,7 @@ def hupdate(self, query, attr, updates):
if GEODJANGO_INSTALLED:
from django.contrib.gis.db.models.query import GeoQuerySet

if VERSION[:2] <= (1, 7):
if django.VERSION[:2] <= (1, 7):
from django.contrib.gis.db.models.sql.query import GeoQuery
from django.contrib.gis.db.models.sql.where import GeoWhereNode, GeoConstraint

Expand Down

0 comments on commit 4b05de9

Please sign in to comment.