diff --git a/examples/poyo/configs/base.yaml b/examples/poyo/configs/base.yaml index 93dbb09..86d61ac 100644 --- a/examples/poyo/configs/base.yaml +++ b/examples/poyo/configs/base.yaml @@ -16,12 +16,10 @@ num_workers: 4 seed: 42 subtask_weights: - - 1.0 # RANDOM - - 0.1 # HOLD - - 5.0 # REACH - - 1.0 # RETURN - - 0.1 # INVALID - - 0.0 # OUTLIER + random_period: 1.0 + hold_period: 0.1 + reach_period: 5.0 + return_period: 1.0 optim: base_lr: 3.125e-5 # scaled linearly by batch size diff --git a/examples/poyo/configs/dataset/perich_miller_population_2018.yaml b/examples/poyo/configs/dataset/perich_miller_population_2018.yaml index 9b0b439..97b4444 100644 --- a/examples/poyo/configs/dataset/perich_miller_population_2018.yaml +++ b/examples/poyo/configs/dataset/perich_miller_population_2018.yaml @@ -81,7 +81,7 @@ - m_20150625_center_out_reaching - m_20150626_center_out_reaching config: - eval_subtask_index: 2 + eval_movement_phase: reach_period # Random target reaching @@ -110,4 +110,4 @@ - m_20140221_random_target_reaching - m_20140224_random_target_reaching config: - eval_subtask_index: 0 \ No newline at end of file + eval_movement_phase: random_period \ No newline at end of file diff --git a/examples/poyo/train.py b/examples/poyo/train.py index 981f3d4..e503bbb 100644 --- a/examples/poyo/train.py +++ b/examples/poyo/train.py @@ -212,7 +212,6 @@ def validation_step(self, batch, batch_idx): batch.pop("target_weights") absolute_starts = batch.pop("absolute_start") session_ids = batch.pop("session_id") - output_subtask_index = batch.pop("output_subtask_index") output_mask = batch.pop("output_mask") # forward pass @@ -222,7 +221,6 @@ def validation_step(self, batch, batch_idx): batch["target_values"] = target_values batch["absolute_start"] = absolute_starts batch["session_id"] = session_ids - batch["output_subtask_index"] = output_subtask_index batch["output_mask"] = output_mask return output_values diff --git a/torch_brain/models/poyo.py b/torch_brain/models/poyo.py index ee554d0..3cd67e5 100644 --- a/torch_brain/models/poyo.py +++ b/torch_brain/models/poyo.py @@ -1,9 +1,9 @@ -from typing import Callable, Dict, Iterable, List, Optional, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch.nn as nn from torchtyping import TensorType -from temporaldata import Data +from temporaldata import Data, Interval from torch_brain.data import chain, pad8, track_mask8 from torch_brain.nn import ( @@ -259,10 +259,7 @@ def __init__( self.session_tokenizer = session_tokenizer self.modality_spec = modality_spec - self.subtask_weights = subtask_weights - if self.subtask_weights is not None: - self.subtask_weights = np.array(self.subtask_weights, dtype=np.float32) self.latent_step = latent_step self.num_latents_per_step = num_latents_per_step @@ -314,19 +311,30 @@ def __call__(self, data: Data) -> Dict: session_index = self.session_tokenizer(data.session) session_index = np.repeat(session_index, len(output_timestamps)) + # Assign each output timestamp a movement_phase + movement_phase_assignment = dict() + for k in data.movement_phases.keys(): + movement_phase_assignment[k] = interval_contains( + data.movement_phases.__dict__[k], + output_timestamps, + ) + # Weights for the output predictions (used in the loss function) - output_subtask_index = data.get_nested_attribute(self.modality_spec.context_key) if self.subtask_weights is None: output_weights = np.ones(len(output_values), dtype=np.float32) else: - output_weights = self.subtask_weights[output_subtask_index] + output_weights = np.zeros_like(output_timestamps, dtype=np.float32) + for k in data.movement_phases.keys(): + output_weights = output_weights + ( + movement_phase_assignment[k].astype(float) * self.subtask_weights[k] + ) # Mask for the output predictions output_mask = np.ones(len(output_values), dtype=bool) if self.eval: # During eval, only evaluate on the subtask specified in the config - target_subtask_index = data.config["eval_subtask_index"] - output_mask = output_mask & (output_subtask_index == target_subtask_index) + eval_movement_phase = data.config["eval_movement_phase"] + output_mask = output_mask & (movement_phase_assignment[eval_movement_phase]) batch = { # input sequence @@ -349,6 +357,30 @@ def __call__(self, data: Data) -> Dict: if self.eval: batch["session_id"] = data.session batch["absolute_start"] = data.absolute_start - batch["output_subtask_index"] = pad8(output_subtask_index) return batch + + +def interval_contains( + interval: Interval, x: Union[float, np.ndarray] +) -> Union[bool, np.ndarray]: + r"""If x is a single number, returns True if x is in the interval. + If x is a 1d numpy array, return a 1d bool numpy array. + """ + + if isinstance(x, float): + if len(interval) == 0: + return False + + return np.logical_and(x >= interval.start, x < interval.end).any() + elif isinstance(x, np.ndarray): + if len(interval) == 0: + return np.zeros_like(x, dtype=bool) + + x_expanded = x[:, None] + y = np.logical_and( + x_expanded >= interval.start[None, :], x_expanded < interval.end[None, :] + ) + y = np.logical_or.reduce(y, axis=1) + assert y.shape == x.shape + return y