
Commit

Add unit tests
mats-claassen committed May 8, 2024
1 parent 4ab8c7b commit feb43df
Showing 8 changed files with 215 additions and 20 deletions.
6 changes: 2 additions & 4 deletions pyproject.toml
@@ -7,15 +7,13 @@ name = "tf-tabular"
authors = [
{name = "Mathias Claassen", email = "[email protected]"},
]
description = "TODO"
description = "TF Tabular simplifies the experimentation and preprocessing of tabular datsets for TensorFlow models."
readme = "README.md"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11"
]
@@ -110,7 +108,7 @@ line-ending = "auto"


[tool.pytest.ini_options]
addopts = "--cov-report xml:coverage.xml --cov src --cov-fail-under 0 --cov-append -m 'not integration'"
addopts = "--cov-report html:coverage.html --cov src --cov-fail-under 0 --cov-append -m 'not integration'"
pythonpath = [
"src"
]
2 changes: 1 addition & 1 deletion src/tf_tabular/builder.py
@@ -34,7 +34,7 @@ def add_inputs(self, input_specs: List[InputSpec]):

def add_inputs_list(
self,
-categoricals: List[str],
+categoricals: List[str] = [],
numericals: List[str] = [],
normalization_params: Dict = {},
vocabs: Dict = {},
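With categoricals now defaulting to an empty list, add_inputs_list can be called with only numerical columns. A minimal usage sketch (the column names are hypothetical; this mirrors test_add_numericals_no_norm added further down in this commit):

from tf_tabular.builder import InputBuilder

builder = InputBuilder()
# Numerical-only models no longer need to pass categoricals explicitly.
builder.add_inputs_list(numericals=["price", "age"])
inputs, output = builder.build_input_layers()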
10 changes: 8 additions & 2 deletions src/tf_tabular/sequence_processor.py
@@ -3,17 +3,21 @@


class SequenceProcessor:
-def __init__(self, attn_heads: int = 4, key_dim: int = 256, attention_builder=None):
+def __init__(
+    self, attn_heads: int = 4, key_dim: int = 256, attention_builder=None, attention_name: str = "seq_attention"
+):
"""The SequenceProcessor concatenates sequential input layers and applies attention on top.
:param int attn_heads: How many attention heads to use, defaults to 4
:param int key_dim: key_dim passed to attention layer, defaults to 256
:param function attention_builder: Optional function that takes a tf.keras.Layer and builds an attention (or any
other) layer on top, defaults to None
+:param str attention_name: Name of the attention layer, defaults to "seq_attention"
"""
self.attn_heads = attn_heads
self.key_dim = key_dim
self.attention_builder = attention_builder
+self.attention_name = attention_name

def _combine(self, x: List[tf.keras.layers.Layer]):
# Make sure numerical elements can be concatenated to embeddings
@@ -23,7 +27,9 @@ def _combine(self, x: List[tf.keras.layers.Layer]):
def _attention(self, x: tf.keras.layers.Layer):
if self.attention_builder is not None:
return self.attention_builder(x)
-return tf.keras.layers.MultiHeadAttention(num_heads=self.attn_heads, key_dim=self.key_dim)(x, x)
+return tf.keras.layers.MultiHeadAttention(
+    name=self.attention_name, num_heads=self.attn_heads, key_dim=self.key_dim
+)(x, x)

def process_layers(self, x: List[tf.keras.layers.Layer]):
"""Processes a list of layers, concatenating them and applying an attention layer."""
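The attention_builder hook documented above replaces the default MultiHeadAttention entirely. A hedged sketch of a custom builder, assuming only the constructor shown in this diff (the single-head layer and its "custom_attn" name are illustrative, not part of the package):

import tensorflow as tf
from tf_tabular.sequence_processor import SequenceProcessor

def single_head_attention(x):
    # Self-attention over the concatenated sequence inputs, replacing the default layer.
    return tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=64, name="custom_attn")(x, x)

processor = SequenceProcessor(attention_builder=single_head_attention)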
40 changes: 27 additions & 13 deletions src/tf_tabular/utils.py
@@ -1,4 +1,5 @@
from itertools import combinations
+import logging
from typing import List

import numpy as np
@@ -17,6 +18,8 @@
)
from tensorflow.keras.regularizers import L2

+logger = logging.getLogger(__name__)


def _input_layer(name: str, is_multi_hot: bool = False, is_string: bool = False):
"""Builds input layer for a column"""
@@ -49,8 +52,13 @@ def get_combiner(combiner: str, is_list: bool):
raise NotImplementedError(f"Unknown combiner: {combiner}")


-def build_projection_layer(cont_layers: List[tf.Tensor], num_projection: int, l2_reg: float,
-                           activation: str = "relu", cross_features: bool = True):
+def build_projection_layer(
+    cont_layers: List[tf.Tensor],
+    num_projection: int,
+    l2_reg: float,
+    activation: str = "relu",
+    cross_features: bool = True,
+):
"""Builds a projection layer for continuous features. If cross_features is True, it will also include the
multiplication of all pairs of continuous features.
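As an illustration of the cross_features behaviour this docstring describes, a sketch of pairwise feature crossing; the itertools.combinations import at the top of this file suggests the pairs are enumerated this way, though the function body itself is collapsed in this diff:

from itertools import combinations
import tensorflow as tf

cont_layers = [tf.keras.Input(shape=(1,)) for _ in range(3)]
# One multiplicative interaction per unordered pair of continuous features.
crosses = [tf.keras.layers.Multiply()([a, b]) for a, b in combinations(cont_layers, 2)]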
@@ -88,11 +96,15 @@ def get_embedding_matrix(lookup, embedding_df: pd.DataFrame):
om = lookup.output_mode
lookup.output_mode = "int"
out_of_vocab = embedding_df[~embedding_df.id.isin(lookup.get_vocabulary())]
+embedding_df = embedding_df[embedding_df.id.isin(lookup.get_vocabulary())]
+if out_of_vocab.shape[0] == 0:
+    # If there is no OOV embedding, we compute the mean of the existing embeddings
+    oov_embedding = np.mean(embedding_df.embedding.values, axis=0).reshape(1, -1)
+else:
+    oov_embedding = np.mean(np.stack(out_of_vocab.embedding.values), axis=0).reshape(1, -1)
-embedding_df = embedding_df[embedding_df.id.isin(lookup.get_vocabulary())]
embedding_df["vocab_id"] = batch_run_lookup_on_df(embedding_df, lookup)
embedding_df = embedding_df.sort_values("vocab_id")
matrix = np.stack(embedding_df.embedding.values).astype(np.float32)
-oov_embedding = np.mean(np.stack(out_of_vocab.embedding.values), axis=0).reshape(1, -1)
matrix = np.concatenate([oov_embedding, matrix], axis=0)
lookup.output_mode = om
return matrix
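A self-contained sketch of the reordered OOV logic above, using a hypothetical two-token vocabulary: when out-of-vocabulary rows exist their mean becomes row 0 of the matrix, otherwise the mean of the in-vocabulary embeddings is used:

import numpy as np
import pandas as pd

vocab = ["a", "b"]
emb = pd.DataFrame({"id": ["a", "b", "z"],
                    "embedding": [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([3.0, 3.0])]})
out_of_vocab = emb[~emb.id.isin(vocab)]
in_vocab = emb[emb.id.isin(vocab)]
if out_of_vocab.shape[0] == 0:
    oov_embedding = np.mean(np.stack(in_vocab.embedding.values), axis=0).reshape(1, -1)
else:
    oov_embedding = np.mean(np.stack(out_of_vocab.embedding.values), axis=0).reshape(1, -1)
# The OOV row is prepended, so matrix[0] == [3., 3.] here.
matrix = np.concatenate([oov_embedding, np.stack(in_vocab.embedding.values)], axis=0)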
@@ -104,7 +116,6 @@ def get_embedding_layer(
name: str,
lookup: StringLookup | IntegerLookup | None = None,
embedding_df: pd.DataFrame | None = None,
-verbose=False,
):
"""Builds the embedding layer for a categorical column. If embedding_df is provided, it will use the precomputed
embeddings. Otherwise, it will create a trainable embedding layer.
@@ -114,17 +125,18 @@ def get_embedding_layer(
:param str name: Name of the layer.
:param StringLookup | IntegerLookup | None lookup: Optional lookup layer needed when passing precomputed embeddings.
:param pd.DataFrame | None embedding_df: Precomputed embeddings in a dataframe containing 'id' and 'embedding' columns, defaults to None
-:param bool verbose: When set to True prints attributes of the embedding matrix, defaults to False. Only applies when embedding_df is not None.
:return Embedding: Embedding layer
"""
if embedding_df is None:
return Embedding(num_tokens, embedding_dim, name=name)
+if embedding_df.shape[0] == 0:
+    raise ValueError("Empty embedding dataframe is invalid. Either pass embeddings or set embedding_df to None.")
embedding_matrix = get_embedding_matrix(lookup, embedding_df)
-if verbose:
-    print("Num tokens:", num_tokens, ", embedding matrix:", embedding_matrix.shape)
-    print("Size of the matrix: ", embedding_matrix.size)
-    print("Memory size of one array element in bytes: ", embedding_matrix.itemsize)
-    print("Total size in kb: ", embedding_matrix.itemsize * embedding_matrix.size / 1024)

logger.debug("Num tokens:", num_tokens, ", embedding matrix:", embedding_matrix.shape)
logger.debug("Size of the matrix: ", embedding_matrix.size)
logger.debug("Memory size of one array element in bytes: ", embedding_matrix.itemsize)
logger.debug("Total size in kb: ", embedding_matrix.itemsize * embedding_matrix.size / 1024)

return Embedding(
num_tokens,
@@ -135,13 +147,15 @@ def get_embedding_layer(
)


-def build_continuous_input(name, mean: float | None = None, variance: float | None = None, sample=None):
+def build_continuous_input(
+    name, mean: float | None = None, variance: float | None = None, sample: tf.data.Dataset | np.ndarray | None = None
+):
"""Builds the input layer stack for continuous features
:param str name: Layer name
:param float mean: mean of the feature values, defaults to None
:param float variance: variance of the feature values, defaults to None
-:param _type_ sample: A sample of features to adapt the layer. You must specify either mean + variance or sample, not both
+:param tf.data.Dataset | np.ndarray | None sample: A sample of features to adapt the layer. You must specify either mean + variance or sample, not both
:return tuple: preprocessed input and inputs
"""
inp = _input_layer(name)
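The sample-versus-statistics contract on build_continuous_input matches the two ways a Keras Normalization layer can be configured. A brief sketch of both modes, assuming the layer stack is built on tf.keras.layers.Normalization:

import numpy as np
import tensorflow as tf

# Mode 1: adapt the layer from a sample of feature values.
norm_from_sample = tf.keras.layers.Normalization(axis=None)
norm_from_sample.adapt(np.array([10.0, 4.0, 12.0]))

# Mode 2: supply precomputed statistics instead of a sample.
norm_from_stats = tf.keras.layers.Normalization(axis=None, mean=3.1, variance=1.0)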
32 changes: 32 additions & 0 deletions tests/test_get_vocab.py
@@ -0,0 +1,32 @@
import pandas as pd
from tf_tabular.utils import get_vocab


def test_get_vocab_string():
df = pd.Series(["a", "b", "c", "a", "b", "c"])
vocab = get_vocab(df)
assert set(vocab) == set(["a", "b", "c"])


def test_get_vocab_max_size():
df = pd.Series(["a", "b", "c", "a", "b"])
vocab = get_vocab(df, max_size=2)
assert set(vocab) == set(["a", "b"])


def test_get_vocab_int():
df = pd.Series([1, 2, 3, 1, 2, 3])
vocab = get_vocab(df)
assert set(vocab) == set([1, 2, 3])


def test_exclude_none():
df = pd.Series(["a", "b", "_none_"])
vocab = get_vocab(df)
assert set(vocab) == set(["a", "b"])


def test_vocab_lists():
df = pd.Series([["a", "b"], ["c", "d"], ["a", "b"], ["c", "b"]])
vocab = get_vocab(df)
assert set(vocab) == set(["a", "b", "c", "d"])
59 changes: 59 additions & 0 deletions tests/test_input_categoricals.py
@@ -0,0 +1,59 @@
import pytest
import pandas as pd
import numpy as np
from tensorflow.keras import Model
from tf_tabular.builder import InputBuilder


def test_input_builder_defaults():
builder = InputBuilder()
assert builder.input_specs == []
assert builder.sequence_processor is None
assert builder.combiner == "mean"
assert builder.numeric_processor.num_projection is None


def test_add_categoricals_missing_params():
builder = InputBuilder()
pytest.raises(KeyError, builder.add_inputs_list, categoricals=["a", "b"])
pytest.raises(KeyError, builder.add_inputs_list, categoricals=["a", "b"], vocabs={"a": [], "b": []})


def test_add_categoricals_with_embedding():
builder = InputBuilder()
builder.add_inputs_list(
categoricals=["a", "b"], embedding_dims={"a": 10, "b": 20}, vocabs={"a": [1, 2, 3], "b": [4, 5, 6]}
)
assert len(builder.input_specs) == 2
assert builder.input_specs[0].name == "a"
assert builder.input_specs[1].name == "b"
assert builder.input_specs[0].embedding_dim == 10
assert builder.input_specs[1].embedding_dim == 20
assert not builder.input_specs[0].is_sequence
assert not builder.input_specs[1].is_sequence
assert not builder.input_specs[0].is_multi_hot
assert not builder.input_specs[1].is_multi_hot
assert builder.input_specs[0].vocab == [1, 2, 3]
assert builder.input_specs[1].vocab == [4, 5, 6]


def test_add_categoricals_with_embedding_df():
builder = InputBuilder()
emb_a = pd.DataFrame({"id": [1, 2, 3], "embedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]})
emb_a["embedding"] = emb_a["embedding"].apply(np.array)
builder.add_inputs_list(
categoricals=["a"], embedding_dims={"a": 10}, vocabs={"a": [1, 2, 3]}, embedding_df={"a": emb_a}
)
inputs, output = builder.build_input_layers()
model = Model(inputs=inputs, outputs=output)
emb_layer = model.get_layer("a_emb")
assert not emb_layer.trainable

layer_embs = emb_layer.get_weights()[0]
expected = np.stack(emb_a.embedding.values).astype(np.float32)

assert layer_embs.shape == (4, 3)

assert np.array_equal(layer_embs[1:], expected)
# assert that the OOV embedding is the mean of the others
assert np.allclose(layer_embs[0], emb_a.embedding.mean())
37 changes: 37 additions & 0 deletions tests/test_input_numericals.py
@@ -0,0 +1,37 @@
import pytest
import numpy as np
from tensorflow.keras import Model
from tf_tabular.builder import InputBuilder


def test_add_numericals_with_normalization():
builder = InputBuilder()
params = {"a": {"sample": np.array([10, 4, 12])}, "b": {"mean": 3.1, "var": 1.0}}
builder.add_inputs_list(numericals=["a", "b"], normalization_params=params)
assert len(builder.input_specs) == 2
assert builder.input_specs[0].name == "a"
assert builder.input_specs[1].name == "b"
assert not builder.input_specs[0].is_sequence
assert not builder.input_specs[1].is_sequence
assert np.array_equal(builder.input_specs[0].sample, params["a"]["sample"])
assert builder.input_specs[1].mean == params["b"]["mean"]
assert builder.input_specs[1].variance == params["b"]["var"]

inputs, output = builder.build_input_layers()
model = Model(inputs=inputs, outputs=output)
assert model.get_layer("a_norm") is not None
assert model.get_layer("b_norm") is not None


def test_add_numericals_no_norm():
builder = InputBuilder()
builder.add_inputs_list(numericals=["a"])
assert len(builder.input_specs) == 1
assert builder.input_specs[0].name == "a"
assert builder.input_specs[0].sample is None
assert builder.input_specs[0].mean is None
assert builder.input_specs[0].variance is None

inputs, output = builder.build_input_layers()
model = Model(inputs=inputs, outputs=output)
pytest.raises(ValueError, model.get_layer, "a_norm")
49 changes: 49 additions & 0 deletions tests/test_sequentials.py
@@ -0,0 +1,49 @@
from tensorflow.keras import Model
from tf_tabular.builder import InputBuilder
from tf_tabular.sequence_processor import SequenceProcessor


def test_add_sequential_columns():
builder = InputBuilder(sequence_processor=SequenceProcessor(attention_name="test_attn"))
builder.add_inputs_list(
categoricals=["a", "b"],
embedding_dims={"a": 10, "b": 20},
vocabs={"a": [1, 2, 3], "b": [4, 5, 6]},
sequentials=["a"],
)
inputs, output = builder.build_input_layers()
model = Model(inputs=inputs, outputs=output)
assert model.get_layer("test_attn") is not None


def test_add_multihot_combiner_default():
builder = InputBuilder()
builder.add_inputs_list(categoricals=["a"], embedding_dims={"a": 10}, vocabs={"a": [1, 2, 3]}, multi_hots=["a"])
inputs, output = builder.build_input_layers()
model = Model(inputs=inputs, outputs=output)
assert model.get_layer("a_emb").output_shape == (None, None, 10)
assert model.get_layer("a_emb").trainable
assert output.shape == (None, 10)
assert model.get_layer("global_average_pooling1d_1") is not None


def test_add_multihot_combiner_max():
builder = InputBuilder(combiner="max")
builder.add_inputs_list(categoricals=["a"], embedding_dims={"a": 10}, vocabs={"a": [1, 2, 3]}, multi_hots=["a"])
inputs, output = builder.build_input_layers()
model = Model(inputs=inputs, outputs=output)
assert model.get_layer("a_emb").output_shape == (None, None, 10)
assert output.shape == (None, 10)
assert model.get_layer("global_max_pooling1d") is not None
assert model.get_layer("global_max_pooling1d").output_shape == (None, 10)


def test_add_multihot_combiner_sum():
builder = InputBuilder(combiner="sum")
builder.add_inputs_list(categoricals=["a"], embedding_dims={"a": 10}, vocabs={"a": [1, 2, 3]}, multi_hots=["a"])
inputs, output = builder.build_input_layers()
model = Model(inputs=inputs, outputs=output)
assert model.get_layer("a_emb").output_shape == (None, None, 10)
assert output.shape == (None, 10)
assert model.get_layer("lambda") is not None
assert model.get_layer("lambda").output_shape == (None, 10)
