diff --git a/django_hstore/fields.py b/django_hstore/fields.py index 21e064c..6654e06 100755 --- a/django_hstore/fields.py +++ b/django_hstore/fields.py @@ -78,6 +78,7 @@ def south_field_triple(self): HStoreField.register_lookup(HStoreLessThanOrEqual) HStoreField.register_lookup(HStoreContains) HStoreField.register_lookup(HStoreIContains) + HStoreField.register_lookup(HStoreIsNull) class DictionaryField(HStoreField): diff --git a/django_hstore/lookups.py b/django_hstore/lookups.py index bd280a1..8f5b954 100755 --- a/django_hstore/lookups.py +++ b/django_hstore/lookups.py @@ -7,10 +7,11 @@ LessThan, LessThanOrEqual, Contains, - IContains + IContains, + IsNull ) -from django_hstore.query import get_cast_for_param +from django_hstore.query import get_cast_for_param, get_value_annotations __all__ = [ @@ -20,7 +21,8 @@ 'HStoreLessThan', 'HStoreLessThanOrEqual', 'HStoreContains', - 'HStoreIContains' + 'HStoreIContains', + 'HStoreIsNull' ] @@ -28,7 +30,7 @@ class HStoreLookupMixin(object): def __init__(self, lhs, rhs, *args, **kwargs): # We need to record the types of the rhs parameters before they are converted to strings if isinstance(rhs, dict): - self.value_annot = dict((key, type(subvalue)) for key, subvalue in six.iteritems(rhs)) + self.value_annot = get_value_annotations(rhs) super(HStoreLookupMixin, self).__init__(lhs, rhs) @@ -118,3 +120,22 @@ def as_postgresql(self, qn, connection): class HStoreIContains(IContains, HStoreContains): pass + + +class HStoreIsNull(IsNull): + + def as_postgresql(self, qn, connection): + lhs, lhs_params = self.process_lhs(qn, connection) + + if isinstance(self.rhs, dict): + param = self.rhs + param_keys = list(param.keys()) + conditions = [] + + for key in param_keys: + op = 'IS NULL' if param[key] else 'IS NOT NULL' + conditions.append('(%s->\'%s\') %s' % (lhs, key, op)) + + return (" AND ".join(conditions), lhs_params) + + return super(HStoreIsNull, self).as_sql(qn, connection) diff --git a/django_hstore/query.py b/django_hstore/query.py index dc37561..a6383b4 100755 --- a/django_hstore/query.py +++ b/django_hstore/query.py @@ -88,12 +88,18 @@ def get_cast_for_param(value_annot, key): return '::float8' elif issubclass(value_annot[key], Decimal): return '::numeric' - elif issubclass(value_annot[key], bool): + elif value_annot[key] in (True, False): return '::boolean' else: return '' +def get_value_annotations(param): + # We need to store the actual value for booleans, not just the type, for isnull + get_type = lambda v: v if isinstance(v, bool) else type(v) + return dict((key, get_type(subvalue)) for key, subvalue in six.iteritems(param)) + + class HStoreWhereNode(WhereNode): def add(self, data, *args, **kwargs): @@ -108,8 +114,8 @@ def add(self, data, *args, **kwargs): if isinstance(original_value, dict): len_children = len(self.children) if self.children else 0 - value_annot = dict((key, type(subvalue)) - for key, subvalue in six.iteritems(original_value)) + value_annot = get_value_annotations(original_value) + # We should be able to get the normal child node here, but it is not returned in Django 1.5 super(HStoreWhereNode, self).add(data, *args, **kwargs) @@ -202,6 +208,15 @@ def make_atom(self, child, qn, connection): raise ValueError('invalid value') elif lookup_type == 'isnull': + if isinstance(param, dict): + param_keys = list(param.keys()) + conditions = [] + + for key in param_keys: + op = 'IS NULL' if value_annot[key] else 'IS NOT NULL' + conditions.append('(%s->\'%s\') %s' % (field, key, op)) + + return (" AND ".join(conditions), []) # do not perform any special format return super(HStoreWhereNode, self).make_atom(child, qn, connection) diff --git a/doc/doc.asciidoc b/doc/doc.asciidoc index a83f6bb..3016a67 100755 --- a/doc/doc.asciidoc +++ b/doc/doc.asciidoc @@ -449,6 +449,13 @@ Something.objects.filter(data__contains=['a', 'b']) # subset by single key Something.objects.filter(data__contains=['a']) + +# filter by is null on individual key/value pairs +Something.objects.filter(data__isnull={'a': True}) +Something.objects.filter(data__isnull={'a': True, 'b': False}) + +# filter by is null on the column works as normal +Something.objects.filter(data__isnull=True) ---- diff --git a/tests/django_hstore_tests/tests.py b/tests/django_hstore_tests/tests.py index c4238e1..0d614b1 100755 --- a/tests/django_hstore_tests/tests.py +++ b/tests/django_hstore_tests/tests.py @@ -153,6 +153,18 @@ def test_nullable_getitem(self): with self.assertRaises(KeyError): n.data['test'] + def test_null_values(self): + null_v = DataBag.objects.create(name="test", data={"v": None}) + nonnull_v = DataBag.objects.create(name="test", data={"v": "item"}) + + r = DataBag.objects.filter(data__isnull={"v": True}) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], null_v) + + r = DataBag.objects.filter(data__isnull={"v": False}) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], nonnull_v) + def test_named_querying(self): alpha, beta = self._create_bags() self.assertEqual(DataBag.objects.get(name='alpha'), alpha)