diff --git a/README.md b/README.md index 7cfb151..671896b 100644 --- a/README.md +++ b/README.md @@ -44,11 +44,15 @@ labels = torch.tensor([0]) # 1 batch # focal loss focal_loss = Loss(loss_type="focal_loss") loss = focal_loss(logits, labels) +``` +```python # cross-entropy loss ce_loss = Loss(loss_type="cross_entropy") loss = ce_loss(logits, labels) +``` +```python # binary cross-entropy loss bce_loss = Loss(loss_type="binary_cross_entropy") loss = bce_loss(logits, labels) @@ -74,7 +78,9 @@ focal_loss = Loss( class_balanced=True ) loss = focal_loss(logits, labels) +``` +```python # class-balanced cross-entropy loss ce_loss = Loss( loss_type="cross_entropy", @@ -82,7 +88,9 @@ ce_loss = Loss( class_balanced=True ) loss = ce_loss(logits, labels) +``` +```python # class-balanced binary cross-entropy loss bce_loss = Loss( loss_type="binary_cross_entropy",