
Code Explanation of AutoRegressivePolicy class #9

Closed
a510721 opened this issue Nov 14, 2024 · 10 comments
Labels
documentation Improvements or additions to documentation

Comments


a510721 commented Nov 14, 2024

Thank you for sharing your good research results.
In the AutoRegressivePolicy class, there are inputs called tks (tk_ids, tk_vals) and chk_ids. What do they mean and how should I use them?


mlzxy commented Nov 14, 2024

Explanation of tk_ids, tk_vals, chk_ids

Thanks for reaching out. Suppose we have an action sequence, [a1, a2, a3, a4, a5].

  • If the chk_ids are [0, 0, 1, 2, 2], the policy will generate a1 and a2 in one chunk, then a3, then a4 and a5.
  • If the tk_ids are [0, 1, 1, 2, 3], it means a1 is token type 0, a2 and a3 belong to token type 1, a4 is token type 2, and a5 is token type 3. The token "types" correspond to the actions you define; for example in ALOHA (code link), we have state, low-level joint values, and 2d visual guide points in the action sequence. The ids can be obtained from policy.token_name_2_ids (like here).
  • Suppose a1 = [0.1, 0.1] (a 2d action), a2 = [2] (a discrete action), a3 = [0.3, 0.3, 0.3], a4 = [0.4, 0.4, 0.4, 0.4], and a5 = [0.5]. Then tk_vals is a tensor of shape 5x4 that combines a1 to a5 through zero-padding (shown transposed below, one column per action):
    [a1]  [a2]  [a3]  [a4]  [a5]
    0.1    2    0.3   0.4   0.5 
    0.1    0    0.3   0.4   0   
    0      0    0.3   0.4   0   
    0      0    0     0.4   0   
    

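For concreteness, here is a minimal sketch (plain PyTorch, not code from the repo) of how this example could be packed into tensors; the token values and the token id are combined into the last channel of tks, as explained in the compute_loss description later in this thread.

    import torch

    actions = [[0.1, 0.1], [2], [0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4], [0.5]]
    tk_ids  = torch.tensor([0, 1, 1, 2, 3])      # token type of each action
    chk_ids = torch.tensor([[0, 0, 1, 2, 2]])    # a1,a2 -> chunk 0; a3 -> chunk 1; a4,a5 -> chunk 2

    d = max(len(a) for a in actions)             # maximum action dimension, 4 here
    tk_vals = torch.zeros(len(actions), d)       # zero-padded values, shape (5, 4)
    for i, a in enumerate(actions):
        tk_vals[i, :len(a)] = torch.tensor(a, dtype=torch.float)

    # tks as consumed by compute_loss: the values plus the token id in the
    # last channel, shape (batch, L, d+1) with batch = 1
    tks = torch.cat([tk_vals, tk_ids.unsqueeze(-1).float()], dim=-1).unsqueeze(0)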
mlzxy pinned this issue Nov 14, 2024

a510721 commented Nov 15, 2024

Thank you for your response. I have been looking at it, but the code is quite complex and difficult to follow. Could you please explain the 'generate' and 'compute_loss' code in 'arp.py' in detail? It would be great if you could provide an example based on 'aloha'.


mlzxy commented Nov 15, 2024

@a510721 Sure! I have some more urgent stuff to attend to during the day, and I will add explanations tonight.

WORKING ON IT.

mlzxy changed the title Can you provide a description of the AutoRegressivePolicy class input? Code Explanation of AutoRegressivePolicy class? Nov 15, 2024
mlzxy changed the title Code Explanation of AutoRegressivePolicy class? Code Explanation of AutoRegressivePolicy class Nov 15, 2024

mlzxy commented Nov 16, 2024

Explanation of ModelConfig

Line: L319

Basic parameters

  • n_embd: token embedding size
  • embd_pdrop: dropout ratio of input embeddings
  • layer_norm_every_block: whether to apply layer norm at each transformer layer
  • max_seq_len: the maximum sequence length (used to initialize the position embedding vectors; you can just set this to a large enough value)
  • max_chunk_size: maximum chunk size (like max_seq_len, just set this to a value that is large enough)
  • layers: a list of LayerType, the created model will have len(layers) transformer layers
  • tokens: a list of TokenType, the definitions of all action tokens

LayerType

Line: L291

  • name (optional): the name of this layer
  • n_head: number of heads in attention
  • attn_kwargs: parameter dict passed to the Attention class for causal self-attention; a common choice is dict(attn_pdrop=0.1, resid_pdrop=0.1)
  • cond_attn_kwargs: same as attn_kwargs but for cross-attention
  • mlp_dropout: dropout ratio for the MLP in the transformer
  • mlp_ratio: relative hidden size of the MLP, e.g., for an embedding size of 128 and an mlp_ratio of 4, the MLP has a hidden size of 512
  • AdaLN: whether to apply modulation in cross attention. We redesign the modulation technique from Diffusion Transformer for a general transformer: the original technique fuses information between a fixed-length noise and a condition signal, and we modify it to support fusion between sequences of arbitrary length.
  • norm_before_AdaLN: whether to apply an extra layer normalization before the modulation. Turn this on if training produces NaNs (usually rare).
  • condition_on: the name of the features used for cross attention in this layer. During forward, we store all input features in a dict, so different layers can attend to different features.

TokenType

Line: L258

  • id: You don't need to set this; it is assigned automatically.

  • name: name of this token

  • is_control: whether this token is a "control token". Here a control token means it is given as input and does not require generation.

  • dim: size of this action, e.g., ALOHA has joint position as actions with dim 14.

  • embedding: a string that denotes the type of embedding for this action, available embeddings are:

    • zero: embeddings that always return 0. L473
    • discrete: works like nn.Embedding, except that it can embed discrete tokens not only from internal weights but also from input features (configured through the "embed_from" parameter in embedding_kwargs) L479
    • position_1d / position_2d: triangular position embedding for actions with dim=1, dim=2 L507, L525
    • linear: embed by a linear layer projection, mostly used for continuous vectors like actions L546.
    • feat_grid_2d: embed a 2d pixel coordinate action by sampling from an NCHW-format spatial feature map (configured by the "sampling_from" parameter in embedding_kwargs) L558
  • embedding_kwargs: the parameter dict for the embedding class.

  • is_continuous: whether this action is continuous.

  • bounds: bounds of this action [min_of_dim1, max_of_dim1, min_of_dim2, max...]; only required when continuous actions are input with a discrete embedding. In this case, the model will quantize the continuous actions.

  • dict_sizes: the vocabulary size for discrete actions (only needed for discrete embedding), [dict_size_of_dim1, dict_size_of_dim2, ...], for example, if an action takes the possible values from 0 to 99, then dict_sizes shall be [100]

  • predictor: a string that denotes the type of decoder for this action (decoding embeddings into actions), available choices are:

    • gmm: a linear layer that regresses continuous values by predicting the parameters of a GMM distribution, inspired by the GMM implementation in robomimic L743
    • class: predicts categorical (discrete) action values L830
    • upsample_from_2d_attn: predicts a 2d pixel location by computing the feature correlation (dot product) between the action embedding and a given NCHW spatial feature map (configured by attn_with in predictor_kwargs), and then upscaling the correlation map. In doing so, we can predict a highly precise 2d pixel location without using a huge linear layer + softmax. L926
  • predictor_kwargs: the parameter dict for the predictor class.
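To tie ModelConfig, LayerType, and TokenType together, below is a rough, hedged sketch of a tiny configuration. The field names follow the bullets above, but the constructor calls (LayerType.make, ModelConfig(...), AutoRegressivePolicy(cfg)) are assumptions and should be checked against arp.py.

    # Rough sketch only: field names follow the explanation above; the exact
    # constructor signatures and defaults are assumed, not copied from the repo.
    from arp import ModelConfig, LayerType, TokenType, AutoRegressivePolicy

    state_tk = TokenType.make(
        name='state', dim=14, is_continuous=True, is_control=True,  # given as input, never generated
        embedding='linear', predictor='gmm')
    action_tk = TokenType.make(
        name='action', dim=14, is_continuous=True,
        embedding='linear', predictor='gmm')

    layer = LayerType.make(
        n_head=8,
        attn_kwargs=dict(attn_pdrop=0.1, resid_pdrop=0.1),       # causal self-attention
        cond_attn_kwargs=dict(attn_pdrop=0.1, resid_pdrop=0.1),  # cross-attention
        mlp_dropout=0.1, mlp_ratio=4,                            # MLP hidden size = 4 * n_embd
        AdaLN=True,
        condition_on='visual-tokens')                            # cross-attend to contexts['visual-tokens']

    cfg = ModelConfig(
        n_embd=128, embd_pdrop=0.1,
        max_seq_len=256, max_chunk_size=256,                     # just needs to be large enough
        layers=[layer] * 4,                                      # four transformer layers
        tokens=[state_tk, action_tk])

    policy = AutoRegressivePolicy(cfg)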

📓 I will add more explanation for each embedding and predictor class later on; before that, you can use the existing usage in ALOHA etc. as examples.

🤝 I admit this is not a trivial implementation, despite the simple idea behind it, because there is no silver bullet for robot imitation learning at this moment. Details matter for good performance. I will keep documenting in the following.

mlzxy added the documentation label Nov 16, 2024

mlzxy commented Nov 16, 2024

Explanation of compute_loss

The input to this function is your action sequences and additional features (stored in a dict called contexts). This function computes the loss (maximizing the conditional likelihood of all tokens through teacher forcing) and returns the loss dictionary to you. A simple analogy is the forward function in minGPT.

Note that I do not use ARP to generate long action sequences that cover the task from start to finish, but just follow the standard imitation learning setting. For example, if I need 100 actions at each step / inference (like in ALOHA), then I focus on how to generate these 100 actions using autoregression.

These 100 actions will be executed sequentially. When the execution finishes, we generate the next 100 actions from the updated observation.

  • tks: all action tokens in shape (bs, L, d+1), where
    • L is the sequence length
    • d is the maximum dimension of the given actions. For example, if your action sequence has three 2-d pixel actions and ten 7-d low-level robot actions, then d is 7, and the extra entries are zero-padded. A useful utility to concat uneven-sized sequences is cat_uneven_blc_tensors
    • the last dimension is the token id, which denotes the type of action at each position. The mapping from token name to token id can be obtained from token_name_2_ids, as in L1024.
  • chk_ids: chunk id for the action tokens, in shape (1, L), (bs, L), or None. If None, the model will assume a chunk size of 1 (next-single-token prediction). Tokens with the same chunk id will be predicted at the same step. More explanation / example in this comment.
  • valid_tk_mask: boolean tensor in shape (bs, L) or None, denoting whether the corresponding token should be considered in the loss computation. L1184
  • skip_tokens: a list of int (token ids). The loss of these skipped tokens will not be computed. L1181
  • match_layer: a string used to filter layers (since each layer can be given a name) L1168
  • block_attn_directions: a list of (int, int) or (str, str). For example, block_attn_directions=[("a", "b")] means we don't want action "a" to attend to action "b". An attention mask will be generated if this parameter is set. L1162
  • log_prob (bool): whether to return the conditional log likelihood of the given tks. The likelihood has shape (bs, L). L1199
  • contexts (dict of tensors): used to provide features for cross attention (see the condition_on parameter in layer type). I will explain the use case with an example in the following (next reply).

Additional things:

  • In L1160, tk_codes = self.token_coder.encode(tks, tk_ids). This will encode the given tokens into discrete values if necessary, using the bounds and dict_sizes from TokenType.
  • In L1166, embs_hat = self.chunk_embedder(chk_ids, tk_ids). This will create the embeddings for empty tokens, based on the chunk ids you provide.
  • In L1174, we call the chunking causal transformer for an interleaved forward with [embs_star, embs_hat]; embs_star are the original action sequence embeddings, and embs_hat are the embeddings for all empty tokens.
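As a compact illustration of these inputs (before the real ALOHA example in the next reply), here is a hedged sketch for a hypothetical sequence of one 2-d pixel action followed by two 7-d joint actions; the token names 'pixel' and 'action', the visual_tokens feature, and the policy object are placeholders, not the real ALOHA names.

    import torch
    import torch.nn.functional as F

    bs, d = 8, 7                                 # batch size, maximum action dimension
    name2id = policy.token_name_2_ids            # token name -> token id

    pixel  = torch.rand(bs, 1, 2)                # one 2-d action per sample
    joints = torch.rand(bs, 2, 7)                # two 7-d actions per sample
    tk_vals = torch.cat([F.pad(pixel, (0, d - 2)), joints], dim=1)     # zero-pad to (bs, 3, 7)

    tk_ids = torch.tensor([name2id['pixel'], name2id['action'], name2id['action']])
    tk_ids = tk_ids[None, :, None].expand(bs, -1, -1).float()          # (bs, 3, 1)
    tks = torch.cat([tk_vals, tk_ids], dim=-1)   # (bs, L=3, d+1), token id in the last channel

    chk_ids = torch.tensor([[0, 1, 1]])          # pixel first, then both joint actions in one chunk
    loss_dict = policy.compute_loss(tks, chk_ids,
                                    contexts={'visual-tokens': visual_tokens})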


mlzxy commented Nov 16, 2024

Example of ALOHA with compute_loss:

Regarding our policy for ALOHA

  1. At L190, we define our action sequences as ['state', 'action', 'action', ..., 'guide-pt-left', .. , 'guide-pt-right'], where guide-pt-* are 2D pixel waypoints and action is 14-dim joint positions.

  2. You can add a breakpoint at L199 to check the value of chk_ids. The default chunk configuration means it will first predict all low-level actions, then predict the 2D pixel waypoints (high-level) one at a time. (Putting the high-level first gives better results, but is slower.)

  3. From Line 200 to Line 228, we create the 2D pixel waypoints by sampling a sparse trajectory from the dense data with F.interpolate; then we create a smooth heatmap from each 2D pixel coordinate (label smoothing), which will later be used as training signals ('smooth-heatmap-right', 'smooth-heatmap-left').

  4. In L239 of our policy for ALOHA, we have

    loss_dict = self.policy.compute_loss(
        tks, chk_ids,
        contexts={
            'visual-tokens': encoder_out,   # used for cross attention (see condition_on in LayerType)
            # see below
            'visual-featmap': visual_featmap,
            'smooth-heatmap-right': heatmap_right.flatten(0, 1),
            'smooth-heatmap-left': heatmap_left.flatten(0, 1),
        },
        valid_tk_mask=~tk_is_pad_mask)      # exclude padded actions from loss computation

    Note that besides visual-tokens, used as the visual condition signal for cross attention, we have visual-featmap and smooth-heatmap-right/left. They correspond to the predictor_kwargs in L117:

    guide_token_right = arp.TokenType.make(
        name='guide-pt-right', is_continuous=True, dim=2,
        embedding='position_2d', predictor="upsample_from_2d_attn",
        predictor_kwargs={'attn_with': 'visual-featmap',
                          'label_name': 'smooth-heatmap-right',
                          'upscale_ratio': stride // self.guide_pts_downsample})
    # I exclude the corr_dim parameter as it is not used and shall actually be removed

    The attn_with parameter is visual-featmap, which means the output token embeddings will be multiplied with this visual-featmap, and the resulting correlation map will be upsampled by upscale_ratio.

    The upsampled result is treated as a 2d logit heatmap that describes the 2d pixel location and is trained with a cross-entropy loss. By default the label is the input 2d coordinate values, but here we override it by providing label_name in the predictor_kwargs and the corresponding smooth-heatmap-right entry in contexts, which contains smoothed labels.

    This is partially illustrated in the appendix figures about action embedder and decoders.

    BTW, if you set visual-tokens to None, cross-attention with visual tokens will be skipped and only self-attention within the action sequences will be applied.


mlzxy commented Nov 16, 2024

Explanation of generate

The input to this function includes (1) action sequence prompt, (2) additional features (stored in a dict called contexts), and (3) the token id (for token type) + chunk id (tokens with the same chunk ids will be predicted in a single chunk) of future sequences. It will return the generated action sequences. A simple analogy of this function is the generate in minGPT.

  • prompt_tks (tensor with shape B, L, d+1): Just like tks for the compute_loss function, but much shorter, and it only includes the prompt sequence. For example:

    • In ALOHA, the prompt_tk_vals is the state.
    • If you don't have prompt tokens, just pass in an empty tensor, like here.
  • future_tk_chk_ids (a list of IncompleteToken): The format of IncompleteToken is a dict {chk_id: int, tk_id: int, tk_val: int} defined at L334. For example, suppose the future_tk_chk_ids is:

    name2id = policy.token_name_2_ids
    future_tk_chk_ids = [
        {'chk_id': 2, 'tk_id': name2id["action-1"]},
        {'chk_id': 2, 'tk_id': name2id["action-1"]},
        {'chk_id': 3, 'tk_id': name2id["hint"], 'tk_val': 3},
        {'chk_id': 4, 'tk_id': name2id["action-2"]},
    ]

    Then, the model will generate two "action-1" tokens in the first chunk, then append one "hint" token with a value of 3 to the sequence without running inference, then generate an "action-2" token in the third chunk.

    • Note that if you have a prompt_tks of size (B, L, d+1), then the chk_id must start from L.

    • One use case of tk_val (as in the "hint" example here) is to inject features into the autoregressive generation. For example, in our policy for RLBench L343, we have a token called prompt-features:

      TokenType.make(name='prompt-features', dim=1, embedding='discrete', 
          is_control=True, embedding_kwargs={'embed_from': "prompt-features"})

      It has an embed_from parameter. In generate (and compute_loss as well), we provide this token and the corresponding feature in L709:

      contexts = { 'prompt-features': prompt_features }

      The prompt_features has a shape of (batch_size, dict_size, channel_size), where dict_size is the largest tk_val plus 1

  • sample (boolean): whether to run sampling or just select the most likely value when predicting each token. In some cases like ALOHA it does not matter, but in Push-T, where there is more multi-modality involved, sample=True shall give better performance.

  • contexts (dict of tensors): like in compute_loss, also see examples given in this page.

  • block_attn_directions (list of (int, int) or (str, str)): used to configure attention masking; e.g., [("a", "b")] prevents token "a" from attending to token "b", similar to the same parameter in compute_loss.

  • match_layer (str): a string used to filter layers, as each transformer layer can be given a name

    • mostly not used, unless you want to squeeze efficiency / accuracy, e.g., only running self-attention without cross-attention for certain sequences by filtering to some specific layers.
  • sample_function (a function or a dict of functions, explained below):

    By default, the predicted actions are sampled by the predictor; for example, the "class" predictor (L864) samples from the predicted categorical distribution with torch.multinomial. However, in some cases we may want to override the default sampling behavior, at least for certain chunks.

    For example, in RLBench we follow the RVT practice of predicting a 2d pixel coordinate distribution (for the gripper position) for each view, as a 2d logit heatmap; but when sampling or determining the actual 2d pixel coordinate, we need to consider the 3D consistency between the multiple views. This parameter helps you do that. I will explain it with the example from RLBench.

    In L684 of the policy for RLBench, we set sample_function to self.sample_callback, which is defined at L323. What it does is simple: it just stores the 2d heatmap in an external list, self.spatial_logits_buffer. The actual 2d pixel coordinate for each view is obtained by self.renderer.get_most_likely_point_3d in L596.

    In this example, we predict all 2d heatmaps separately, each in a single chunk, by calling generate repeatedly, as shown by chk_id=0 in L682. So we just set sample_function=self.sample_callback, and the sample_callback does not have to return anything meaningful (it is not required, since there is no following sequence).

    But predicting them all together with other tokens is also fine (the current setup is due to some legacy reason). The sample_function can be configured at the per-chunk level, e.g.,

    sample_function = {frozenset({0, 1}): func1, frozenset({2}): func2}  # use func1 for chk-ids 0 and 1; func2 for chk-id 2

    The sample_function shall return the sampled tokens in shape B, L, d, where d is the dim size of the token, e.g., 2 for 2d pixel coordinate.
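To put the pieces together, here is a hedged sketch of a minimal generate call in the style of the future_tk_chk_ids example above; the 'state' / 'action' token names, state_vals, and encoder_out are placeholders, and the exact argument names should be double-checked against generate in arp.py.

    import torch

    name2id = policy.token_name_2_ids
    bs, d = state_vals.shape[0], state_vals.shape[-1]    # state_vals: (bs, 1, d), already zero-padded

    # prompt: one 'state' control token per sample, with its token id in the last channel
    prompt_tks = torch.cat(
        [state_vals, torch.full((bs, 1, 1), float(name2id['state']))], dim=-1)   # (bs, 1, d+1)

    L = prompt_tks.shape[1]                               # future chk_ids must start from L
    future_tk_chk_ids = [{'chk_id': L, 'tk_id': name2id['action']} for _ in range(10)]  # one chunk of 10

    out_tks = policy.generate(prompt_tks, future_tk_chk_ids,
                              sample=True,
                              contexts={'visual-tokens': encoder_out})
    # assuming the returned sequence contains the prompt followed by the generated tokens
    actions = out_tks[:, L:, :d]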


mlzxy commented Nov 16, 2024

@a510721 I have added some descriptions. Hope this helps a bit. I suggest you step through the code with the vscode debugger or another debugging tool you prefer, using these descriptions as references.

Let me know if I can be of more help.

a510721 closed this as completed Nov 20, 2024

a510721 commented Nov 20, 2024

Thank you for your detailed explanation.


mlzxy commented Nov 20, 2024

> Thank you for your detailed explanation.

My pleasure. I will add more explanation on the Predictor & Embedder for different actions later this week.
