No Cross Attention between Seq Embedding and E_map #83

Open
haibao-yu opened this issue Sep 10, 2024 · 2 comments

@haibao-yu

Thanks for your great work! I have a question about your implementation:

  • In Figure 3, the embedded road map interacts with the Seq Embedding and is then added to the input feature. However, in your implementation the embedded road map is added directly to the input feature (the last line of the snippet below). Could you check whether this is the correct implementation?

# 0. camera
N_cam = camera_param.shape[1]
camera_emb = self._embed_camera(camera_param)
# (B, N_cam, max_len + 1, dim=768)
encoder_hidden_states_with_cam = self.add_cam_states(
    encoder_hidden_states, camera_emb
)
# we may drop the condition during training, but not drop controlnet
if (self.drop_cond_ratio > 0.0 and self.training):
    if encoder_hidden_states_uncond is not None:
        encoder_hidden_states_with_cam, uncond_mask = self._random_use_uncond_cam(
            encoder_hidden_states_with_cam, encoder_hidden_states_uncond)
    controlnet_cond = controlnet_cond.type(self.dtype)
    controlnet_cond = self._random_use_uncond_map(controlnet_cond)
else:
    uncond_mask = None

# 0.5. bbox embeddings
# bboxes data should follow the format of (B, N_cam or 1, max_len, ...)
# for each view
if bboxes_3d_data is not None:
    bbox_embedder_kwargs = {}
    for k, v in bboxes_3d_data.items():
        bbox_embedder_kwargs[k] = v.clone()
    if self.drop_cam_with_box and uncond_mask is not None:
        _, n_box = bboxes_3d_data["bboxes"].shape[:2]
        if n_box != N_cam:
            assert n_box == 1, "either N_cam or 1."
            for k in bboxes_3d_data.keys():
                ori_v = rearrange(
                    bbox_embedder_kwargs[k], 'b n ... -> (b n) ...')
                new_v = repeat(ori_v, 'b ... -> b n ...', n=N_cam)
                bbox_embedder_kwargs[k] = new_v
        # here we set mask for dropped boxes to all zero
        masks = bbox_embedder_kwargs['masks']
        masks[uncond_mask > 0] = 0
    # original flow
    b_box, n_box = bbox_embedder_kwargs["bboxes"].shape[:2]
    for k in bboxes_3d_data.keys():
        bbox_embedder_kwargs[k] = rearrange(
            bbox_embedder_kwargs[k], 'b n ... -> (b n) ...')
    bbox_emb = self.bbox_embedder(**bbox_embedder_kwargs)
    if n_box != N_cam:
        # n_box should be 1: all views share the same set of bboxes, we repeat
        bbox_emb = repeat(bbox_emb, 'b ... -> b n ...', n=N_cam)
    else:
        # each view already has its set of bboxes
        bbox_emb = rearrange(bbox_emb, '(b n) ... -> b n ...', n=N_cam)
    encoder_hidden_states_with_cam = torch.cat([
        encoder_hidden_states_with_cam, bbox_emb
    ], dim=2)

# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
    # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
    # This would be a good case for the `match` statement (Python 3.10+)
    is_mps = sample.device.type == "mps"
    if isinstance(timestep, float):
        dtype = torch.float32 if is_mps else torch.float64
    else:
        dtype = torch.int32 if is_mps else torch.int64
    timesteps = torch.tensor(
        [timesteps],
        dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
# timesteps = timesteps.expand(sample.shape[0])
timesteps = timesteps.reshape(-1)  # time_proj can only take 1-D input
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
    if class_labels is None:
        raise ValueError(
            "class_labels should be provided when num_class_embeds > 0"
        )
    if self.config.class_embed_type == "timestep":
        class_labels = self.time_proj(class_labels)
    class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
    emb = emb + class_emb

# BEV: we remap data to have (B n) as batch size
sample = rearrange(sample, 'b n ... -> (b n) ...')
encoder_hidden_states_with_cam = rearrange(
    encoder_hidden_states_with_cam, 'b n ... -> (b n) ...')
if len(emb) < len(sample):
    emb = repeat(emb, 'b ... -> (b repeat) ...', repeat=N_cam)
controlnet_cond = repeat(
    controlnet_cond, 'b ... -> (b repeat) ...', repeat=N_cam)

# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample += controlnet_cond
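For the addition on the last line to be meaningful, the `(b n)` flattening of `sample` and the per-camera `repeat` of `controlnet_cond` must produce the same batch ordering, so that every camera view of sample `b` is paired with that sample's road map. The sketch below checks this with plain PyTorch ops as stand-ins for the einops calls: `reshape` for `rearrange(..., 'b n ... -> (b n) ...')` and `repeat_interleave` for `repeat(..., 'b ... -> (b repeat) ...')`. The tensor sizes and the constant-valued "maps" are made up for illustration.

```python
import torch

# Hypothetical sizes: batch B, cameras N_cam, latent channels C, spatial H x W.
B, N_cam, C, H, W = 2, 3, 4, 8, 8

# Multi-view noisy latents (B, N_cam, C, H, W).
sample = torch.randn(B, N_cam, C, H, W)
# One BEV road map per sample (B, 1, H, W); map b is filled with the value b.
road_map = torch.arange(B, dtype=torch.float32).view(B, 1, 1, 1).expand(B, 1, H, W)

# Stand-in for rearrange(sample, 'b n ... -> (b n) ...'): row index = b * N_cam + n.
flat_sample = sample.reshape(B * N_cam, C, H, W)

# Stand-in for repeat(controlnet_cond, 'b ... -> (b repeat) ...', repeat=N_cam):
# each map is copied N_cam times consecutively, so row b * N_cam + n holds map b.
flat_map = road_map.repeat_interleave(N_cam, dim=0)

# Which sample each flattened row belongs to.
owner = torch.arange(B).repeat_interleave(N_cam)

# Counter-example: tiling the whole batch interleaves the maps, so rows would be misaligned.
tiled = road_map.repeat(N_cam, 1, 1, 1)

print(flat_map[:, 0, 0, 0].tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
print(owner.tolist())                 # [0, 0, 0, 1, 1, 1]
print(tiled[:, 0, 0, 0].tolist())     # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

The `(b repeat)` pattern keeps `b` as the outer axis, which is why it lines up with `(b n)`; a `'(repeat b)'` pattern (or `Tensor.repeat`) would pair views with the wrong map.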

@flymin
Member

flymin commented Sep 13, 2024

We follow the implementation of ControlNet: the control signal (the map) is first added to x (the noisy latent), and the sum then goes through the copied encoder blocks. The whole file you referred to is the encoder from Figure 3, not just the conv block.
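One reading of this reply, sketched under assumptions: as in ControlNet, the map embedding is added to the latent features first, and the interaction with the sequence embedding happens afterwards, inside the copied encoder blocks, through cross-attention. `ToyControlEncoder` below is a hypothetical, heavily simplified stand-in (one conv, one zero-initialized condition embedding, one `nn.MultiheadAttention` layer); it is not the repository's code, and all sizes are illustrative.

```python
import torch
import torch.nn as nn


class ToyControlEncoder(nn.Module):
    """Hypothetical ControlNet-style encoder: add the map first, cross-attend after."""

    def __init__(self, latent_ch=4, map_ch=1, dim=32, seq_dim=16):
        super().__init__()
        self.conv_in = nn.Conv2d(latent_ch, dim, 3, padding=1)
        # Stand-in for controlnet_cond_embedding; the last conv is zero-initialized,
        # so the map branch contributes nothing until it is trained (as in ControlNet).
        self.cond_embedding = nn.Sequential(
            nn.Conv2d(map_ch, dim, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(dim, dim, 3, padding=1),
        )
        nn.init.zeros_(self.cond_embedding[-1].weight)
        nn.init.zeros_(self.cond_embedding[-1].bias)
        # Stand-in for a copied encoder block: spatial tokens attend to the sequence embedding.
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads=4, kdim=seq_dim, vdim=seq_dim, batch_first=True
        )

    def forward(self, latent, road_map, seq_emb):
        x = self.conv_in(latent)
        x = x + self.cond_embedding(road_map)      # 1) map features added to the latent
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)       # (b, h*w, c)
        attn_out, _ = self.cross_attn(tokens, seq_emb, seq_emb)  # 2) then attend to the sequence
        tokens = tokens + attn_out
        return tokens.transpose(1, 2).reshape(b, c, h, w)


torch.manual_seed(0)
enc = ToyControlEncoder()
latent = torch.randn(2, 4, 8, 8)
seq_emb = torch.randn(2, 5, 16)  # e.g. text + camera + box tokens

# At initialization the zero conv hides the map, so different maps give identical outputs.
out_a = enc(latent, torch.zeros(2, 1, 8, 8), seq_emb)
out_b = enc(latent, torch.ones(2, 1, 8, 8), seq_emb)

# Once the zero conv has non-zero weights, the map changes what the attention layer sees.
with torch.no_grad():
    enc.cond_embedding[-1].weight.fill_(0.01)
out_c = enc(latent, torch.zeros(2, 1, 8, 8), seq_emb)
out_d = enc(latent, torch.ones(2, 1, 8, 8), seq_emb)
print(out_a.shape, torch.equal(out_a, out_b), torch.equal(out_c, out_d))
```

Because the map is already part of the features that act as attention queries, the map and the sequence embedding still interact, just after the addition rather than before it.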

@flymin flymin added answered and removed answered labels Sep 19, 2024

This issue is stale because it has been open for 7 days with no activity. If you do not have any follow-ups, the issue will be closed soon.

@github-actions github-actions bot added the stale label Sep 26, 2024