Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

When I compute gradients and perform updates using the same values in Torch and JAX, I find that after multiple iterations, the inference results differ significantly. #23646

Open
CZXIANGOvO opened this issue Sep 15, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@CZXIANGOvO
Copy link

Description

Please set the device to `cuda:0` at the very beginning of the script.

import torch
import numpy as np
import os
from network.cv.SSD.backbone_mobilenetv1_pytorch import SSDWithMobileNetV1 as SSD_torch
import jax
import jax
import jax.numpy as jnp
from jax import ops as jops
from jax.nn import one_hot, sigmoid
from jax import lax
import jax.scipy.special as sc
import optax

def select_device(environ=None):
    """Return the torch device string this script should run on.

    The result is ``'cpu'`` unless ``CONTEXT_DEVICE_TARGET`` is ``'GPU'``. On
    GPU, the second-to-last entry of ``CUDA_VISIBLE_DEVICES`` is chosen (the
    script's original policy). CUDA renumbers the visible devices from 0 in
    list order, so the torch ordinal is that entry's *position* in the list,
    not the physical id written in it. For example, ``"4,5"`` gives
    ``"cuda:0"``, not the invalid ``"cuda:4"``.

    Args:
        environ: Mapping to read the variables from; defaults to ``os.environ``.

    Returns:
        ``'cpu'`` or ``'cuda:<index>'``.
    """
    env = os.environ if environ is None else environ
    if env.get('CONTEXT_DEVICE_TARGET') != 'GPU':
        return 'cpu'
    raw = env.get('CUDA_VISIBLE_DEVICES')
    if raw is None:
        # Variable unset: every GPU is visible, so use the first one.
        return 'cuda:0'
    visible = [entry.strip() for entry in raw.split(',') if entry.strip()]
    if not visible:
        # An explicitly empty list hides all GPUs from CUDA.
        return 'cpu'
    # Second-to-last visible device. With a single visible device, fall back
    # to it instead of raising IndexError.
    return f'cuda:{max(len(visible) - 2, 0)}'


final_device = select_device()


def class_loss_jax(logits, label, gamma=2.0, alpha=0.75):
    """Compute the per-element sigmoid focal loss used for SSD classification.

    Args:
        logits: Raw class scores; the last axis indexes classes.
        label: Integer class ids whose shape equals ``logits.shape[:-1]``.
            They are one-hot encoded over ``logits.shape[-1]`` classes and are
            assumed to lie in ``[0, logits.shape[-1])``.
        gamma: Focusing exponent applied to ``1 - p_t`` (default 2.0).
        alpha: Weight for entries whose one-hot target is 1; entries with
            target 0 get ``1 - alpha`` (default 0.75).

    Returns:
        Array with the same shape as ``logits`` holding the focal loss of every
        (element, class) pair. Reduction is left to the caller.
    """
    target = jnp.eye(logits.shape[-1])[label]

    # Numerically stable binary cross-entropy on logits:
    #   -[t * log(sigmoid(x)) + (1 - t) * log(1 - sigmoid(x))],
    # using log(1 - sigmoid(x)) == log_sigmoid(-x). Unlike log(sigmoid + eps),
    # this neither saturates nor loses its gradient for large |x|.
    bce = -(target * jax.nn.log_sigmoid(logits)
            + (1.0 - target) * jax.nn.log_sigmoid(-logits))

    prob = sc.expit(logits)
    p_t = target * prob + (1.0 - target) * (1.0 - prob)
    modulating_factor = jnp.power(1.0 - p_t, gamma)
    alpha_weight_factor = target * alpha + (1.0 - target) * (1.0 - alpha)
    # Keep the cross-entropy element-wise. Reducing it first (e.g. with a
    # mean) would scale every entry by the average loss instead of its own,
    # which is not the focal-loss definition.
    return modulating_factor * alpha_weight_factor * bce
    

def SSDmultibox_jax_cal(params_, pred_loc, pred_label, gt_loc, gt_label, num_matched_boxes):
    """Compute the SSD multibox loss (smooth-L1 localisation + focal classification).

    Mirrors ``loss_SSDmultibox_torch.forward``. Both terms are summed per
    sample, divided by the total number of matched boxes, and then summed to
    a scalar.

    Args:
        params_: Accepted so the function can be passed to
            ``jax.value_and_grad`` with parameters first, but it is never read.
            NOTE(review): because the result does not depend on ``params_``,
            its gradient with respect to ``params_`` is identically zero.
        pred_loc: Predicted box offsets, presumably (batch, anchors, 4).
        pred_label: Class logits forwarded to ``class_loss_jax``.
        gt_loc: Target box offsets, same shape as ``pred_loc``.
        gt_label: Integer class ids; entries > 0 mark positive anchors.
        num_matched_boxes: Matched-box counts, summed into the normaliser.

    Returns:
        Scalar loss.
    """
    positive = jnp.greater(gt_label, 0).astype(jnp.float32)
    normalizer = jnp.sum(num_matched_boxes.astype(jnp.float32))

    # Smooth-L1 with beta = 1, restricted to positive anchors.
    abs_err = jnp.abs(pred_loc - gt_loc)
    per_coord = jnp.where(abs_err < 1, 0.5 * abs_err ** 2, abs_err - 0.5)
    coord_mask = jnp.tile(positive[..., None], (1, 1, 4))
    loc_term = jnp.sum(jnp.sum(per_coord * coord_mask, -1), -1)

    cls_term = jnp.sum(class_loss_jax(pred_label, gt_label), (1, 2))
    return jnp.sum((cls_term + loc_term) / normalizer)


class loss_SSDmultibox_torch(torch.nn.Module):
    """PyTorch counterpart of ``SSDmultibox_jax_cal`` (SSD multibox loss).

    Combines a masked smooth-L1 localisation loss with the project's
    ``class_loss`` classification loss. Both are normalised by the total
    number of matched boxes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_loc, pred_label, gt_loc, gt_label, num_matched_boxes):
        """Return the scalar multibox loss for one batch.

        Args:
            pred_loc: Predicted box offsets, presumably (batch, anchors, 4).
            pred_label: Class logits passed to ``class_loss``.
            gt_loc: Target box offsets, same shape as ``pred_loc``.
            gt_label: Integer class ids; entries > 0 mark positive anchors.
            num_matched_boxes: Matched-box counts, summed into the normaliser.
        """
        # Imported inside forward, so the project module is only needed once
        # the loss is actually evaluated.
        from network.cv.SSD.ssd_utils_torch import class_loss

        positive = (gt_label > 0).float()
        normalizer = num_matched_boxes.float().sum()

        # Smooth-L1 (beta = 1, the default) per coordinate, kept only for
        # positive anchors.
        coord_mask = positive[..., None].repeat(1, 1, 4)
        loc_err = torch.nn.functional.smooth_l1_loss(pred_loc, gt_loc, reduction='none')
        loc_term = (loc_err * coord_mask).sum(dim=-1).sum(dim=-1)

        cls_term = class_loss(pred_label, gt_label).sum(dim=(1, 2))
        return ((cls_term + loc_term) / normalizer).sum()


# ---------------------------------------------------------------------------
# Comparison script. The torch SSD model takes 10 SGD steps, and in the same
# loop an optax SGD "JAX" update is applied. Each iteration prints
# loss_jax / loss_torch.
#
# NOTE(review): SSDmultibox_jax_cal never reads its first argument (params_).
# So jax.value_and_grad(...)(params_jax, ...) below differentiates a function
# that does not depend on params_jax. As a result:
#   - jax_grads is all zeros, and optax.sgd leaves params_jax at its initial
#     values on every iteration;
#   - loss_jax is always evaluated with the *initial* weights, while
#     loss_torch uses the weights torch has been training;
#   - the printed ratio drifts away from 1 as torch learns.
# A genuine comparison needs a JAX forward pass that computes pred_loc and
# pred_label from params_jax inside the differentiated function.
# ---------------------------------------------------------------------------
image_torch = np.load('./image_torch.npy')
image_torch = torch.from_numpy(image_torch).to(final_device)

# pred_loc_torch / pred_label_torch loaded here are overwritten by the first
# forward pass inside the loop before they are used.
pred_loc_torch = np.load('./pred_loc_torch.npy')
pred_loc_torch = torch.from_numpy(pred_loc_torch).to(final_device)
pred_label_torch = np.load('./pred_label_torch.npy')
pred_label_torch = torch.from_numpy(pred_label_torch).to(final_device)
box_torch = np.load('./box_torch.npy')
box_torch = torch.from_numpy(box_torch).to(final_device)
label_torch = np.load('./label_torch.npy')
label_torch = torch.from_numpy(label_torch).to(final_device)
num_match_torch = np.load('./num_match_torch.npy')
num_match_torch = torch.from_numpy(num_match_torch).to(final_device)

model_torch = SSD_torch()
model_torch.train()
model_torch.to(final_device)


learning_rate = 0.02
optimizer_torch = torch.optim.SGD
optimizer_torch = optimizer_torch(model_torch.parameters(), lr=learning_rate)
# Snapshot of the initial torch state as float32 JAX arrays.
# NOTE(review): state_dict() also contains any registered buffers (e.g.
# BatchNorm statistics, if the model has them). Torch's SGD only updates
# parameters, whereas optax.sgd below would update every entry here.
params_torch = {key: value.detach().cpu().numpy() for key, value in model_torch.state_dict().items()}
params_jax = {name: jnp.array(value, dtype=jnp.float32) for name, value in params_torch.items()}

optimizer_jax = optax.sgd
optimizer_jax = optimizer_jax(learning_rate)
loss_fun_torch = loss_SSDmultibox_torch()
opt_state = optimizer_jax.init(params_jax)

for i in range(0,10):

    # Torch step. loss_torch is the loss *before* this iteration's update.
    pred_loc_torch, pred_label_torch = model_torch(image_torch)

    loss_torch = loss_fun_torch(pred_loc_torch, pred_label_torch, box_torch, label_torch, num_match_torch)

    loss_torch.backward()
    optimizer_torch.step()

    optimizer_torch.zero_grad()
    # Save the torch-trained weights to disk so they can be restored after the
    # model is temporarily loaded with params_jax below.
    old_torch_state_dict = model_torch.state_dict()
    torch.save(old_torch_state_dict, './model_weights.pth')



    # Run the torch model with the JAX-side parameters to obtain the
    # predictions fed to the JAX loss.
    params_jax_numpy = {name: np.array(value) for name, value in params_jax.items()}
    params_torch_updated = {name: torch.from_numpy(value) for name, value in params_jax_numpy.items()}
    model_torch.load_state_dict(params_torch_updated)

    pred_loc_torch, pred_label_torch = model_torch(image_torch)

    # (Redundant: the same two conversions are repeated a few lines below.)
    pred_loc_jax = pred_loc_torch.detach().cpu().numpy()
    pred_label_jax = pred_label_torch.detach().cpu().numpy()

    loss_fun_jax = SSDmultibox_jax_cal

    pred_loc_jax = pred_loc_torch.detach().cpu().numpy()
    pred_label_jax = pred_label_torch.detach().cpu().numpy()
    box_jax = box_torch.detach().cpu().numpy()
    label_jax = label_torch.detach().cpu().numpy()
    num_match_jax = num_match_torch.detach().cpu().numpy()
    # Predictions enter as constant numpy arrays, and the loss ignores
    # params_jax, so jax_grads is all zeros here.
    loss_jax, jax_grads = jax.value_and_grad(loss_fun_jax)(params_jax, pred_loc_jax, pred_label_jax, box_jax,
                                                            label_jax, num_match_jax)



    # With zero gradients, the SGD update is zero and params_jax is unchanged.
    updates, opt_state = optimizer_jax.update(jax_grads, opt_state, params_jax)
    params_jax = optax.apply_updates(params_jax, updates)

    # jax_grads_distance = chebyshev_distance(old_jax_grads, jax_grads)
    # old_jax_grads = jax_grads
    # Despite its name, this is the model's current state dict (weights
    # loaded from params_jax), not gradients. It is not used afterwards.
    torch_grads = {key: value.detach().cpu().numpy() for key, value in model_torch.state_dict().items()}

    # Restore the torch-trained weights saved above.
    loaded_state_dict = torch.load('./model_weights.pth')
    model_torch.load_state_dict(loaded_state_dict)

    print('loss_jax/loss_torch:',np.array(loss_jax)/ loss_torch.cpu().detach().numpy())  # prints the ratio; ~1 only if both sides evaluate the same weights

[Screenshot 2024-09-15 191310]

System info (python version, jaxlib version, accelerator, etc.)

Download the code: https://drive.google.com/file/d/1H8uPgPdslVpizmSsif6oK4ey2e-oum9x/view?usp=sharing

!unzip issue3.zip
python issue1.py
@CZXIANGOvO CZXIANGOvO added the bug Something isn't working label Sep 15, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Sep 15, 2024

If it's truly the same model you're running, this is surprising. Given the magnitude of the difference, though, I suspect the models or optimizers differ in important ways: for example, maybe the precise definition of "learning rate" differs between the two implementations.

I don't know either optax or pytorch well enough to guess where that difference might lie, but if it's important to you to debug these differences in the implementations, that's probably where I'd start.

@CZXIANGOvO
Copy link
Author

> If it's truly the same model you're running, this is surprising. Given the magnitude of the difference, though, I suspect the models or optimizers differ in important ways: for example, maybe the precise definition of "learning rate" differs between the two implementations.
>
> I don't know either optax or pytorch well enough to guess where that difference might lie, but if it's important to you to debug these differences in the implementations, that's probably where I'd start.

We're using the same model.

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

> We're using the same model.

Sure, but what I'm suggesting is that you may not be using the same optimizer.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants