Skip to content

Commit

Permalink
Moves subtask_index based weighting to movement_phases based
Browse files Browse the repository at this point in the history
  • Loading branch information
vinamarora8 committed Nov 27, 2024
1 parent 909e23e commit 1c1502e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 20 deletions.
10 changes: 4 additions & 6 deletions examples/poyo/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ num_workers: 4
seed: 42

subtask_weights:
- 1.0 # RANDOM
- 0.1 # HOLD
- 5.0 # REACH
- 1.0 # RETURN
- 0.1 # INVALID
- 0.0 # OUTLIER
random_period: 1.0
hold_period: 0.1
reach_period: 5.0
return_period: 1.0

optim:
base_lr: 3.125e-5 # scaled linearly by batch size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
- m_20150625_center_out_reaching
- m_20150626_center_out_reaching
config:
eval_subtask_index: 2
eval_movement_phase: reach_period


# Random target reaching
Expand Down Expand Up @@ -110,4 +110,4 @@
- m_20140221_random_target_reaching
- m_20140224_random_target_reaching
config:
eval_subtask_index: 0
eval_movement_phase: random_period
2 changes: 0 additions & 2 deletions examples/poyo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def validation_step(self, batch, batch_idx):
batch.pop("target_weights")
absolute_starts = batch.pop("absolute_start")
session_ids = batch.pop("session_id")
output_subtask_index = batch.pop("output_subtask_index")
output_mask = batch.pop("output_mask")

# forward pass
Expand All @@ -222,7 +221,6 @@ def validation_step(self, batch, batch_idx):
batch["target_values"] = target_values
batch["absolute_start"] = absolute_starts
batch["session_id"] = session_ids
batch["output_subtask_index"] = output_subtask_index
batch["output_mask"] = output_mask

return output_values
Expand Down
52 changes: 42 additions & 10 deletions torch_brain/models/poyo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Callable, Dict, Iterable, List, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch.nn as nn
from torchtyping import TensorType
from temporaldata import Data
from temporaldata import Data, Interval

from torch_brain.data import chain, pad8, track_mask8
from torch_brain.nn import (
Expand Down Expand Up @@ -259,10 +259,7 @@ def __init__(
self.session_tokenizer = session_tokenizer

self.modality_spec = modality_spec

self.subtask_weights = subtask_weights
if self.subtask_weights is not None:
self.subtask_weights = np.array(self.subtask_weights, dtype=np.float32)

self.latent_step = latent_step
self.num_latents_per_step = num_latents_per_step
Expand Down Expand Up @@ -314,19 +311,30 @@ def __call__(self, data: Data) -> Dict:
session_index = self.session_tokenizer(data.session)
session_index = np.repeat(session_index, len(output_timestamps))

# Assign each output timestamp a movement_phase
movement_phase_assignment = dict()
for k in data.movement_phases.keys():
movement_phase_assignment[k] = interval_contains(
data.movement_phases.__dict__[k],
output_timestamps,
)

# Weights for the output predictions (used in the loss function)
output_subtask_index = data.get_nested_attribute(self.modality_spec.context_key)
if self.subtask_weights is None:
output_weights = np.ones(len(output_values), dtype=np.float32)
else:
output_weights = self.subtask_weights[output_subtask_index]
output_weights = np.zeros_like(output_timestamps, dtype=np.float32)
for k in data.movement_phases.keys():
output_weights = output_weights + (
movement_phase_assignment[k].astype(float) * self.subtask_weights[k]
)

# Mask for the output predictions
output_mask = np.ones(len(output_values), dtype=bool)
if self.eval:
# During eval, only evaluate on the subtask specified in the config
target_subtask_index = data.config["eval_subtask_index"]
output_mask = output_mask & (output_subtask_index == target_subtask_index)
eval_movement_phase = data.config["eval_movement_phase"]
output_mask = output_mask & (movement_phase_assignment[eval_movement_phase])

batch = {
# input sequence
Expand All @@ -349,6 +357,30 @@ def __call__(self, data: Data) -> Dict:
if self.eval:
batch["session_id"] = data.session
batch["absolute_start"] = data.absolute_start
batch["output_subtask_index"] = pad8(output_subtask_index)

return batch


def interval_contains(
interval: Interval, x: Union[float, np.ndarray]
) -> Union[bool, np.ndarray]:
r"""If x is a single number, returns True if x is in the interval.
If x is a 1d numpy array, return a 1d bool numpy array.
"""

if isinstance(x, float):
if len(interval) == 0:
return False

return np.logical_and(x >= interval.start, x < interval.end).any()
elif isinstance(x, np.ndarray):
if len(interval) == 0:
return np.zeros_like(x, dtype=bool)

x_expanded = x[:, None]
y = np.logical_and(
x_expanded >= interval.start[None, :], x_expanded < interval.end[None, :]
)
y = np.logical_or.reduce(y, axis=1)
assert y.shape == x.shape
return y

0 comments on commit 1c1502e

Please sign in to comment.