
Question about the CRF in chapter8 #5

Open
hjing100 opened this issue Jan 2, 2023 · 1 comment


hjing100 commented Jan 2, 2023

```python
import tensorflow as tf  # TF 1.x API (tf.get_variable, tf.contrib)
from tensorflow.contrib.crf import crf_log_likelihood


def loss_layer(self, project_logits, lengths, name=None):
    """
    Compute the CRF loss.
    :param project_logits: [batch_size, num_steps, relation_num]
    :return: scalar loss
    """
    with tf.name_scope("crf_loss" if not name else name):
        small = -1000.0
        # pad logits for crf loss
        start_logits = tf.concat(
            [small * tf.ones(shape=[self.batch_size, 1, self.relation_num]), tf.zeros(shape=[self.batch_size, 1, 1])],
            axis=-1)
        pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32)
        logits = tf.concat([project_logits, pad_logits], axis=-1)
        logits = tf.concat([start_logits, logits], axis=1)
        targets = tf.concat(
            [tf.cast(self.relation_num * tf.ones([self.batch_size, 1]), tf.int32), self.input_relation], axis=-1)

        self.trans = tf.get_variable(
            name="transitions",
            shape=[self.relation_num + 1, self.relation_num + 1],  # 1
            # shape=[self.relation_num, self.relation_num],  # 1
            initializer=self.initializer)
        log_likelihood, self.trans = crf_log_likelihood(
            inputs=logits,
            tag_indices=targets,
            # tag_indices=self.input_relation,
            transition_params=self.trans,
            # sequence_lengths=lengths
            sequence_lengths=lengths + 1
        )  # + 1
        return tf.reduce_mean(-log_likelihood, name='loss')
```
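To make the shapes concrete, here is a minimal pure-Python sketch that mirrors the padding in `loss_layer` on nested lists instead of tensors. The helper `pad_batch` and the toy numbers are hypothetical, not taken from the repository. It prepends one extra time step (`start_logits`) in which only a new tag with index `relation_num` scores 0, adds a column that scores that tag -1000 at every real step (`pad_logits`), and prepends that tag to every gold path (`targets`). Because each sequence now has one more time step, the valid lengths also grow by 1; passing the original `lengths` would presumably cut off the last real token, since `crf_log_likelihood` ignores steps beyond `sequence_lengths`.

```python
SMALL = -1000.0  # same value as `small` in loss_layer


def pad_batch(project_logits, input_relation, lengths, num_tags):
    """Mirror loss_layer's padding on nested lists instead of tensors.

    project_logits: [batch_size][num_steps][num_tags] emission scores
    input_relation: [batch_size][num_steps] gold tag ids in [0, num_tags)
    lengths:        [batch_size] true sentence lengths
    """
    start_tag = num_tags  # index of the extra tag, i.e. relation_num
    # start_logits: one extra step where only the extra tag scores 0
    start_row = [SMALL] * num_tags + [0.0]
    logits, targets = [], []
    for seq_logits, seq_tags in zip(project_logits, input_relation):
        # pad_logits: the extra tag scores -1000 at every real step
        body = [row + [SMALL] for row in seq_logits]
        logits.append([list(start_row)] + body)  # concat on the time axis
        targets.append([start_tag] + seq_tags)   # gold path begins at the extra tag
    # one step was prepended, so every valid length grows by 1
    return logits, targets, [n + 1 for n in lengths]


# toy batch: batch_size = 2, num_steps = 3, relation_num = 2
project_logits = [[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
                  [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]]
input_relation = [[0, 1, 1], [1, 0, 0]]
logits, targets, lengths = pad_batch(project_logits, input_relation, [3, 2], num_tags=2)
print(len(logits[0]), len(logits[0][0]))  # 4 3  (num_steps + 1, relation_num + 1)
print(targets[1])  # [2, 1, 0, 0]
print(lengths)     # [4, 3]
```

This also explains why `transitions` is `[relation_num + 1, relation_num + 1]`: the extra tag needs its own row and column.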

Hi,
why is the `+ 1` needed here?
And what is `start_logits`?


hjing100 commented Jan 2, 2023

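The version I've seen elsewhere looks like this:

```python
log_likelihood, self.transition_params = crf_log_likelihood(inputs=self.logits,
                                                            tag_indices=self.labels,
                                                            sequence_lengths=self.sequence_lengths)
```

There, `self.logits` is the DNN output used directly, and `self.labels` holds the gold labels with shape [batch_size, seq_len].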
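For comparison with that reference call: as far as I can tell, TF 1.x `crf_log_likelihood` scores a tag path only by its per-step emission scores plus the transitions between consecutive tags, with no separate "start" transition. The brute-force sketch below (tiny inputs only; all function names are hypothetical) checks numerically that the padded `loss_layer` version equals an unpadded CRF whose first-step emissions get the extra row `trans[relation_num][:relation_num]` added. Under this reading, `start_logits` and the extra row of `transitions` let the model learn which tag a sentence tends to start with, and `lengths + 1` accounts for the prepended step.

```python
import itertools
import math
import random

SMALL = -1000.0  # same value as `small` in loss_layer


def path_score(logits, tags, trans):
    # emission scores plus transitions between consecutive tags only
    # (no separate start/end transition term)
    emit = sum(logits[t][y] for t, y in enumerate(tags))
    return emit + sum(trans[a][b] for a, b in zip(tags, tags[1:]))


def crf_ll(logits, tags, trans):
    # brute-force CRF log-likelihood: enumerate every tag path (tiny inputs only)
    k = len(logits[0])
    scores = [path_score(logits, p, trans)
              for p in itertools.product(range(k), repeat=len(logits))]
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return path_score(logits, tags, trans) - log_z


def padded_ll(logits, tags, trans_big):
    # the loss_layer trick for one sequence: prepend a step where only the
    # extra tag (index k) is allowed and forbid that tag at every real step;
    # the padded sequence has n + 1 steps, hence `lengths + 1`
    k = len(logits[0])
    pad = [[SMALL] * k + [0.0]] + [row + [SMALL] for row in logits]
    return crf_ll(pad, [k] + tags, trans_big)


def unpadded_ll_with_start(logits, tags, trans_big):
    # unpadded CRF (as in the reference call) whose first-step emissions get
    # row k of the big transition matrix added: an explicit start score
    k = len(logits[0])
    first = [logits[0][y] + trans_big[k][y] for y in range(k)]
    trans = [row[:k] for row in trans_big[:k]]
    return crf_ll([first] + logits[1:], tags, trans)


random.seed(0)
k, n = 3, 4  # toy sizes: relation_num = 3, sentence length = 4
logits = [[random.uniform(-1, 1) for _ in range(k)] for _ in range(n)]
trans_big = [[random.uniform(-1, 1) for _ in range(k + 1)] for _ in range(k + 1)]
gold = [0, 2, 1, 1]
print(abs(padded_ll(logits, gold, trans_big)
          - unpadded_ll_with_start(logits, gold, trans_big)) < 1e-9)  # True
```

Paths that do not start at the extra tag, or that use it at a real step, pick up a -1000 term, so their contribution to the partition function underflows to zero; that is what makes the two versions agree.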
