Skip to content

Commit ce7d156

Browse files
authored
Merge branch 'skrub-data:main' into chore-fix-issue1729
2 parents f67ac8f + 85c1c06 commit ce7d156

File tree

7 files changed

+87
-46
lines changed

7 files changed

+87
-46
lines changed

CHANGES.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ New features
1414

1515
Changes
1616
-------
17+
- The :class:`StringEncoder` now exposes the ``vocabulary`` parameter of the underlying
18+
:class:`TfidfVectorizer`.
19+
  :pr:`1819` by :user:`Eloi Massoulié <emassoulie>`.
20+
21+
22+
- :func:`compute_ngram_distance` has been renamed to :func:`_compute_ngram_distance` and is now a private function.
23+
:pr:`1838` by :user:`Siddharth Baleja <siddharthbaleja>`.
1724

1825
Bugfixes
1926
--------

examples/0050_deduplication.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -148,33 +148,6 @@
148148
# If we want to adapt the translation table, we can
149149
# modify it manually.
150150

151-
###############################################################################
152-
# Visualizing string pair-wise distance between names
153-
# ---------------------------------------------------
154-
#
155-
# Below, we use a heatmap to visualize the pairwise-distance between medication
156-
# names. A darker color means that two medication names are closer together
157-
# (i.e. more similar), a lighter color means a larger distance.
158-
#
159-
160-
from scipy.spatial.distance import squareform
161-
162-
from skrub import compute_ngram_distance
163-
164-
ngram_distances = compute_ngram_distance(unique_examples)
165-
square_distances = squareform(ngram_distances)
166-
167-
import seaborn as sns
168-
169-
fig, ax = plt.subplots(figsize=(14, 12))
170-
sns.heatmap(
171-
square_distances, yticklabels=unique_examples, xticklabels=unique_examples, ax=ax
172-
)
173-
plt.show()
174-
175-
###############################################################################
176-
# We have three clusters appearing - the original medication
177-
# names and their misspellings that form a cluster around them.
178151

179152
###############################################################################
180153
# Conclusion

skrub/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
y,
3030
)
3131
from ._datetime_encoder import DatetimeEncoder
32-
from ._deduplicate import compute_ngram_distance, deduplicate
32+
from ._deduplicate import deduplicate
3333
from ._drop_uninformative import DropUninformative
3434
from ._fuzzy_join import fuzzy_join
3535
from ._gap_encoder import GapEncoder
@@ -77,7 +77,7 @@
7777
"Cleaner",
7878
"DropUninformative",
7979
"deduplicate",
80-
"compute_ngram_distance",
8181
"ToCategorical",
8282
"to_datetime",
8383
"AggJoiner",

skrub/_deduplicate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.metrics import silhouette_score
1313

1414

15-
def compute_ngram_distance(
15+
def _compute_ngram_distance(
1616
unique_words,
1717
ngram_range=(2, 4),
1818
analyzer="char_wb",
@@ -260,7 +260,7 @@ def deduplicate(
260260
9 white 9 white
261261
"""
262262
unique_words, counts = np.unique(X, return_counts=True)
263-
distance_mat = compute_ngram_distance(
263+
distance_mat = _compute_ngram_distance(
264264
unique_words, ngram_range=ngram_range, analyzer=analyzer
265265
)
266266

skrub/_string_encoder.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ class StringEncoder(TransformerMixin, SingleColumnTransformer):
6161
Used during randomized svd. Pass an int for reproducible results across
6262
multiple function calls.
6363
64+
vocabulary : Mapping or iterable, default=None
65+
Only used when ``vectorizer="tfidf"``: the vocabulary passed to the underlying vectorizer.
66+
Either a Mapping (e.g., a dict) where keys are terms and values are
67+
indices in the feature matrix, or an iterable over terms.
68+
6469
Attributes
6570
----------
6671
input_name_ : str
@@ -131,13 +136,15 @@ def __init__(
131136
analyzer="char_wb",
132137
stop_words=None,
133138
random_state=None,
139+
vocabulary=None,
134140
):
135141
self.n_components = n_components
136142
self.vectorizer = vectorizer
137143
self.ngram_range = ngram_range
138144
self.analyzer = analyzer
139145
self.stop_words = stop_words
140146
self.random_state = random_state
147+
self.vocabulary = vocabulary
141148

142149
def fit_transform(self, X, y=None):
143150
"""Fit the encoder and transform a column.
@@ -165,21 +172,29 @@ def fit_transform(self, X, y=None):
165172
ngram_range=self.ngram_range,
166173
analyzer=self.analyzer,
167174
stop_words=self.stop_words,
175+
vocabulary=self.vocabulary,
168176
)
169177
elif self.vectorizer == "hashing":
170-
self.vectorizer_ = Pipeline(
171-
[
172-
(
173-
"hashing",
174-
HashingVectorizer(
175-
ngram_range=self.ngram_range,
176-
analyzer=self.analyzer,
177-
stop_words=self.stop_words,
178+
if self.vocabulary is not None:
179+
raise ValueError(
180+
"Custom vocabulary passed to StringEncoder, unsupported by "
181+
"HashingVectorizer. Rerun without a 'vocabulary' parameter."
182+
)
183+
else:
184+
self.vectorizer_ = Pipeline(
185+
[
186+
(
187+
"hashing",
188+
HashingVectorizer(
189+
ngram_range=self.ngram_range,
190+
analyzer=self.analyzer,
191+
stop_words=self.stop_words,
192+
),
178193
),
179-
),
180-
("tfidf", TfidfTransformer()),
181-
]
182-
)
194+
("tfidf", TfidfTransformer()),
195+
]
196+
)
197+
183198
else:
184199
raise ValueError(
185200
f"Unknown vectorizer {self.vectorizer}. Options are 'tfidf' or"

skrub/tests/test_deduplicate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from sklearn.utils._testing import assert_array_equal, skip_if_no_parallel
88

99
from skrub._deduplicate import (
10+
_compute_ngram_distance,
1011
_create_spelling_correction,
1112
_guess_clusters,
12-
compute_ngram_distance,
1313
deduplicate,
1414
)
1515
from skrub.datasets import make_deduplication_data
@@ -60,7 +60,7 @@ def test_deduplicate(
6060

6161
def test_compute_ngram_distance():
6262
words = np.array(["aac", "aaa", "aaab", "aaa", "aaab", "aaa", "aaab", "aaa"])
63-
distance = compute_ngram_distance(words)
63+
distance = _compute_ngram_distance(words)
6464
distance = squareform(distance)
6565
assert distance.shape[0] == words.shape[0]
6666
assert np.allclose(np.diag(distance), 0)
@@ -70,7 +70,7 @@ def test_compute_ngram_distance():
7070

7171
def test__guess_clusters():
7272
words = np.array(["aac", "aaa", "aaab", "aaa", "aaab", "aaa", "aaab", "aaa"])
73-
distance = compute_ngram_distance(words)
73+
distance = _compute_ngram_distance(words)
7474
Z = linkage(distance, method="average")
7575
n_clusters = _guess_clusters(Z, distance)
7676
assert n_clusters == len(np.unique(words))

skrub/tests/test_string_encoder.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,49 @@ def test_zero_padding_in_feature_names_out(df_module, n_components, expected_col
310310
feature_names = encoder.get_feature_names_out()
311311

312312
assert feature_names[: len(expected_columns)] == expected_columns
313+
314+
315+
def test_vocabulary_parameter(df_module):
316+
voc = {
317+
"this": 5,
318+
"is": 1,
319+
"simple": 3,
320+
"example": 0,
321+
"this is": 6,
322+
"is simple": 2,
323+
"simple example": 4,
324+
}
325+
encoder = StringEncoder(n_components=2, vocabulary=voc)
326+
pipeline = Pipeline(
327+
[
328+
(
329+
"tfidf",
330+
TfidfVectorizer(ngram_range=(3, 4), analyzer="char_wb", vocabulary=voc),
331+
),
332+
("tsvd", TruncatedSVD()),
333+
]
334+
)
335+
X = df_module.make_column(
336+
"col",
337+
["this is a sentence", "this simple example is simple", "other words", ""],
338+
)
339+
340+
enc_out = encoder.fit_transform(X)
341+
pipe_out = pipeline.fit_transform(X)
342+
pipe_out /= scaling_factor(pipe_out)
343+
344+
assert encoder.vectorizer_.vocabulary_ == voc
345+
assert_almost_equal(enc_out, pipe_out)
346+
347+
348+
def test_vocabulary_on_hashing_vectorizer(df_module):
349+
voc = {
350+
"this": 5,
351+
}
352+
encoder = StringEncoder(vocabulary=voc, vectorizer="hashing")
353+
with pytest.raises(ValueError, match="Custom vocabulary passed to StringEncoder"):
354+
X = df_module.make_column(
355+
"col",
356+
["this is a sentence", "this simple example is simple", "other words", ""],
357+
)
358+
encoder.fit_transform(X)

0 commit comments

Comments
 (0)