Skip to content

Commit

Permalink
fix!: [py] use full word for namespace and add test (#4485)
Browse files Browse the repository at this point in the history
  • Loading branch information
lalo authored Feb 2, 2023
1 parent a2dc620 commit dcbcd07
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
23 changes: 11 additions & 12 deletions python/pylibvw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -624,10 +624,8 @@ void ex_push_feature(example_ptr ec, unsigned char ns, feature_index fid, float
}

// List[Union[Tuple[Union[str,int], float], str,int]]
void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list& a)
void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns_first_letter, uint64_t ns_hash, py::list& a)
{ // warning: assumes namespace exists!
char ns_str[2] = {(char)ns, 0};
uint64_t ns_hash = VW::hash_space(*vw, ns_str);
size_t count = 0;
for (ssize_t i = 0; i < len(a); i++)
{
Expand Down Expand Up @@ -678,7 +676,7 @@ void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list&
}
if (got)
{
ec->feature_space[ns].push_back(f.x, f.weight_index);
ec->feature_space[ns_first_letter].push_back(f.x, f.weight_index);
count++;
}
}
Expand All @@ -688,11 +686,9 @@ void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list&
}

// Dict[Union[str,int],Union[int,float]]
void ex_push_feature_dict(example_ptr ec, vw_ptr vw, unsigned char ns, PyObject* o)
void ex_push_feature_dict(example_ptr ec, vw_ptr vw, unsigned char ns_first_letter, uint64_t ns_hash, PyObject* o)
{
// warning: assumes namespace exists!
char ns_str[2] = {(char)ns, 0};
uint64_t ns_hash = VW::hash_space(*vw, ns_str);
size_t count = 0;
const char* key_chars;

Expand Down Expand Up @@ -729,7 +725,7 @@ void ex_push_feature_dict(example_ptr ec, vw_ptr vw, unsigned char ns, PyObject*
continue;
}

ec->feature_space[ns].push_back(feat_value, feat_index);
ec->feature_space[ns_first_letter].push_back(feat_value, feat_index);
count++;
}

Expand Down Expand Up @@ -759,15 +755,18 @@ void ex_push_dictionary(example_ptr ec, vw_ptr vw, PyObject* o)
{
py::extract<std::string> ns_e(ns_raw);
if (ns_e().length() < 1) continue;
unsigned char ns = ns_e()[0];

ex_ensure_namespace_exists(ec, ns);
std::string ns_full = ns_e();
unsigned char ns_first_letter = ns_full[0];
uint64_t ns_hash = VW::hash_space(*vw, ns_full);

if (PyDict_Check(feats)) { ex_push_feature_dict(ec, vw, ns, feats); }
ex_ensure_namespace_exists(ec, ns_first_letter);

if (PyDict_Check(feats)) { ex_push_feature_dict(ec, vw, ns_first_letter, ns_hash, feats); }
else
{
py::list list = py::extract<py::list>(feats);
ex_push_feature_list(ec, vw, ns, list);
ex_push_feature_list(ec, vw, ns_first_letter, ns_hash, list);
}
}
}
Expand Down
14 changes: 11 additions & 3 deletions python/tests/test_pyvw.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,14 +540,22 @@ def test_example_features():
def test_example_features_dict():
vw = Workspace(quiet=True)
ex = vw.example(
{"a": {"two": 1, "features": 1.0}, "b": {"more": 1, "features": 1, 5: 1.5}}
{
"a": {"two": 1, "features": 1.0},
"namespace": {"more": 1, "feature": 1, 5: 1.5},
}
)
fs = list(ex.iter_features())
fs_keys = [f[0] for f in fs]

expected = [53373, 165129, 24716, 242309, 5]

assert set(fs_keys) == set(expected)

assert (ex.get_feature_id("a", "two"), 1) in fs
assert (ex.get_feature_id("a", "features"), 1) in fs
assert (ex.get_feature_id("b", "more"), 1) in fs
assert (ex.get_feature_id("b", "features"), 1) in fs
assert (ex.get_feature_id("namespace", "more"), 1) in fs
assert (ex.get_feature_id("namespace", "feature"), 1) in fs
assert (5, 1.5) in fs


Expand Down
14 changes: 12 additions & 2 deletions python/vowpalwabbit/pyvw.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,7 @@ def __init__(self, ex: "Example", id: Union[int, str]):
- If int, uses that as an index into this Example's list of feature groups to get the namespace id character
- If str, uses the first character as the namespace id character
"""
self.full = None
if isinstance(id, int): # you've specified a namespace by index
if id < 0 or id >= ex.num_namespaces():
raise Exception("namespace " + str(id) + " out of bounds")
Expand All @@ -983,6 +984,7 @@ def __init__(self, ex: "Example", id: Union[int, str]):
elif isinstance(id, str): # you've specified a namespace by string
if len(id) == 0:
id = " "
self.full = id
self.id = None # we don't know and we don't want to do the linear search required to find it
self.ns = id[0]
self.ord_ns = ord(self.ns)
Expand Down Expand Up @@ -1695,6 +1697,7 @@ def num_features_in(self, ns: Union[NamespaceId, str, int]) -> int:
"""
return pylibvw.example.num_features_in(self, self.get_ns(ns).ord_ns)

# pytype: disable=attribute-error
def get_feature_id(
self,
ns: Union[NamespaceId, str, int],
Expand Down Expand Up @@ -1722,7 +1725,13 @@ def get_feature_id(
return feature
if isinstance(feature, str):
if ns_hash is None:
ns_hash = self.vw.hash_space(self.get_ns(ns).ns)
if type(ns) != NamespaceId:
ns = self.get_ns(ns)
ns_hash = (
self.vw.hash_space(ns.full)
if ns.full
else self.vw.hash_space(ns.ns)
)
return self.vw.hash_feature(feature, ns_hash)
raise Exception("cannot extract feature of type: " + str(type(feature)))

Expand Down Expand Up @@ -1839,8 +1848,9 @@ def push_features(
"""
ns = self.get_ns(ns)
self.ensure_namespace_exists(ns)
ns_hash = self.vw.hash_space(ns.full) if ns.full else self.vw.hash_space(ns.ns)
self.push_feature_list(
self.vw, ns.ord_ns, featureList
self.vw, ns.ord_ns, ns_hash, featureList
) # much faster just to do it in C++
# ns_hash = self.vw.hash_space( ns.ns )
# for feature in featureList:
Expand Down

0 comments on commit dcbcd07

Please sign in to comment.