def readout_features(self, graphTensor: tfgnn.GraphTensor, feature_name: str) -> tf.Tensor:
    """Read out `feature_name` from the root node of every graph component.

    The root node of a component is the first node of that component in the
    node set named by the string context feature ``"root_node_type"``.

    ``tfgnn.gather_first_node`` only accepts a static Python string as the
    node set name, so a per-component string tensor cannot pick the node set
    directly. Instead, this method:

    1. gathers the first node of every node set that carries
       ``feature_name`` (vectorized over all components);
    2. for each component, keeps the value from the node set whose name
       equals that component's ``root_node_type``.

    Args:
      graphTensor: A scalar GraphTensor (e.g. after
        ``merge_batch_to_components()``), presumably with one component per
        example -- TODO confirm against the caller.
      feature_name: Name of the node feature to read out. Node sets without
        this feature are skipped. All node sets that do have it must share
        the same dtype and feature shape.

    Returns:
      A tensor of shape ``[num_components, *feature_dims]``. Row ``i`` is the
      feature of the first node of component ``i`` in the node set named by
      ``root_node_type[i]``.

    Raises:
      ValueError: If no node set carries ``feature_name``.
      tf.errors.InvalidArgumentError: If some component's ``root_node_type``
        names no node set that carries ``feature_name``, or that node set has
        no nodes in the component.
    """
    # The original comment gives the context feature shape as [batch_size, 1].
    # Flatten it to one string per component (assumes exactly one per
    # component -- TODO confirm).
    root_node_type = tf.reshape(graphTensor.context["root_node_type"], [-1])

    readout = None
    # Tracks, per component, whether its root node set was found and non-empty.
    found = tf.fill(tf.shape(root_node_type), False)

    for node_set_name in sorted(graphTensor.node_sets.keys()):
        node_set = graphTensor.node_sets[node_set_name]
        if feature_name not in node_set.features:
            continue
        values = node_set[feature_name]  # assumes a dense feature -- TODO confirm
        rank = values.shape.rank
        sizes = node_set.sizes  # number of nodes per component

        # Index of each component's first node. A component with no nodes in
        # this node set points at an appended zero row, so it does not read
        # another component's node or go out of range.
        starts = tf.math.cumsum(sizes, exclusive=True)
        total = tf.shape(values, out_type=sizes.dtype)[0]
        has_nodes = sizes > 0
        indices = tf.where(has_nodes, starts, total)
        padded = tf.pad(values, [[0, 1]] + [[0, 0]] * (rank - 1))
        first_values = tf.gather(padded, indices)  # [num_components, *dims]

        is_root = tf.equal(root_node_type, node_set_name)
        found = tf.logical_or(found, tf.logical_and(is_root, has_nodes))

        if readout is None:
            readout = tf.zeros_like(first_values)
        # Broadcast the per-component mask over the feature dimensions.
        mask = tf.reshape(is_root, [-1] + [1] * (rank - 1))
        readout = tf.where(mask, first_values, readout)

    if readout is None:
        raise ValueError(f"No node set has feature {feature_name!r}.")

    assert_found = tf.debugging.assert_equal(
        tf.reduce_all(found),
        True,
        message=(
            "readout_features: some component's root_node_type names no node "
            f"set with feature {feature_name!r}, or that node set is empty "
            "in the component."
        ),
    )
    with tf.control_dependencies([assert_found]):
        return tf.identity(readout)
It is very inconvenient to get the embeddings of the first nodes of multiple node sets at once.