Skip to content

Commit

Permalink
Minor fixes in docs
Browse files Browse the repository at this point in the history
  • Loading branch information
mats-claassen committed May 6, 2024
1 parent 743d848 commit c00d25d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/tf_tabular/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def add_inputs_list(
)

def build_input_layers(self) -> tuple[List[tf.keras.layers.Layer], tf.Tensor]:
"""Build input layer stack and return the input layers and the output layer for building the model.
"""Builds input layer stack and return the input layers and the output layer for building the model.
:return tuple[List[tf.keras.layers.Layer], tf.Tensor]: Tuple containing the input layers and the output layer
"""
input_layers = []
Expand Down
2 changes: 1 addition & 1 deletion src/tf_tabular/numeric_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(
self.builder = builder
self.projection_activation = projection_activation

def project(self, x: List[tf.keras.layers.Layer]):
def project(self, x: List[tf.Tensor]) -> List[tf.Tensor]:
"""If num_projection is not None, project the numerical features to a lower or higher dimension.
If a builder is provided, use that to build the projection layer. Otherwise, use the default projection layer.
"""
Expand Down
23 changes: 18 additions & 5 deletions src/tf_tabular/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from itertools import combinations
from typing import List

import numpy as np
import pandas as pd
import tensorflow as tf
Expand All @@ -17,7 +19,7 @@


def _input_layer(name: str, is_multi_hot: bool = False, is_string: bool = False):
"""Build input layer for a column"""
"""Builds input layer for a column"""
shape: tuple[None] | tuple[int]
if is_multi_hot:
shape = (None,)
Expand Down Expand Up @@ -47,7 +49,18 @@ def get_combiner(combiner: str, is_list: bool):
raise NotImplementedError(f"Unknown combiner: {combiner}")


def build_projection_layer(cont_layers, num_projection, l2_reg, activation="relu", cross_features=True):
def build_projection_layer(cont_layers: List[tf.Tensor], num_projection: int, l2_reg: float,
activation: str = "relu", cross_features: bool = True):
"""Builds a projection layer for continuous features. If cross_features is True, it will also include the
multiplication of all pairs of continuous features.
:param List[tf.Tensor] cont_layers: List of continuous layers
:param int num_projection: number of output neurons in the projection layer
:param float l2_reg: L2 regularization coefficient
:param str activation: activation to use in projection layer, defaults to "relu"
:param bool cross_features: Whether to build cross features or not, defaults to True
:return Tensor: output of the projection layer
"""
if cross_features:
cont_layers = list(cont_layers)
pairs = list(combinations(cont_layers, 2))
Expand Down Expand Up @@ -123,7 +136,7 @@ def get_embedding_layer(


def build_continuous_input(name, mean: float | None = None, variance: float | None = None, sample=None):
"""Build th e input layer stack for continuous features
"""Builds the input layer stack for continuous features
:param str name: Layer name
:param float mean: mean of the feature values, defaults to None
Expand All @@ -142,7 +155,7 @@ def build_continuous_input(name, mean: float | None = None, variance: float | No


def build_categorical_input(name, embedding_dim, vocab, is_multi_hot, embedding_df=None):
"""Build input for categorical columns.
"""Builds input for categorical columns.
This function supports many cases because of the different trials we have done for different columns
:param str name: Layer name
Expand Down Expand Up @@ -174,7 +187,7 @@ def build_categorical_input(name, embedding_dim, vocab, is_multi_hot, embedding_


def get_vocab(series: pd.Series, max_size: int | None = None):
"""Get the vocabulary of a series"""
"""Gets the vocabulary (unique items) of a series"""
if isinstance(series.iloc[0], list) or isinstance(series.iloc[0], np.ndarray):
series = series.explode()
series = series.dropna()
Expand Down

0 comments on commit c00d25d

Please sign in to comment.