diff --git a/examples/poyo/configs/base.yaml b/examples/poyo/configs/base.yaml index 86d61ac..3775002 100644 --- a/examples/poyo/configs/base.yaml +++ b/examples/poyo/configs/base.yaml @@ -15,12 +15,6 @@ eval_batch_size: null # if null, will use batch_size num_workers: 4 seed: 42 -subtask_weights: - random_period: 1.0 - hold_period: 0.1 - reach_period: 5.0 - return_period: 1.0 - optim: base_lr: 3.125e-5 # scaled linearly by batch size weight_decay: 1e-4 diff --git a/examples/poyo/configs/dataset/perich_miller_population_2018.yaml b/examples/poyo/configs/dataset/perich_miller_population_2018.yaml index 97b4444..5c51362 100644 --- a/examples/poyo/configs/dataset/perich_miller_population_2018.yaml +++ b/examples/poyo/configs/dataset/perich_miller_population_2018.yaml @@ -81,6 +81,12 @@ - m_20150625_center_out_reaching - m_20150626_center_out_reaching config: + weights: + movement_periods.random_period: 1.0 + movement_periods.hold_period: 0.1 + movement_periods.reach_period: 5.0 + movement_periods.return_period: 1.0 + cursor_outlier_segments: 0.0 eval_movement_phase: reach_period @@ -110,4 +116,10 @@ - m_20140221_random_target_reaching - m_20140224_random_target_reaching config: + weights: + movement_periods.random_period: 1.0 + movement_periods.hold_period: 0.1 + movement_periods.reach_period: 5.0 + movement_periods.return_period: 1.0 + cursor_outlier_segments: 0.0 eval_movement_phase: random_period \ No newline at end of file diff --git a/torch_brain/models/poyo.py b/torch_brain/models/poyo.py index 3cd67e5..362a5b1 100644 --- a/torch_brain/models/poyo.py +++ b/torch_brain/models/poyo.py @@ -253,13 +253,11 @@ def __init__( modality_spec: ModalitySpec, sequence_length: float = 1.0, eval: bool = False, - subtask_weights: Optional[Iterable[float]] = None, ): self.unit_tokenizer = unit_tokenizer self.session_tokenizer = session_tokenizer self.modality_spec = modality_spec - self.subtask_weights = subtask_weights self.latent_step = latent_step self.num_latents_per_step = num_latents_per_step @@ -308,33 +306,18 @@ def __call__(self, data: Data) -> Dict: if output_values.dtype == np.float64: output_values = output_values.astype(np.float32) - session_index = self.session_tokenizer(data.session) - session_index = np.repeat(session_index, len(output_timestamps)) + output_session_index = self.session_tokenizer(data.session) + output_session_index = np.repeat(output_session_index, len(output_timestamps)) - # Assign each output timestamp a movement_phase - movement_phase_assignment = dict() - for k in data.movement_phases.keys(): - movement_phase_assignment[k] = interval_contains( - data.movement_phases.__dict__[k], - output_timestamps, - ) - - # Weights for the output predictions (used in the loss function) - if self.subtask_weights is None: - output_weights = np.ones(len(output_values), dtype=np.float32) - else: - output_weights = np.zeros_like(output_timestamps, dtype=np.float32) - for k in data.movement_phases.keys(): - output_weights = output_weights + ( - movement_phase_assignment[k].astype(float) * self.subtask_weights[k] - ) - - # Mask for the output predictions - output_mask = np.ones(len(output_values), dtype=bool) - if self.eval: - # During eval, only evaluate on the subtask specified in the config - eval_movement_phase = data.config["eval_movement_phase"] - output_mask = output_mask & (movement_phase_assignment[eval_movement_phase]) + output_weights = np.ones_like(output_timestamps, dtype=np.float32) + if "weights" in data.config: + weights = data.config["weights"] + for weight_key, weight_value in weights.items(): + # extract the interval from the weight key + weight = data.get_nested_attribute(weight_key) + if not isinstance(weight, Interval): + raise ValueError(f"Weight {weight_key} is not an Interval") + output_weights[isin_interval(output_timestamps, weight)] *= weight_value batch = { # input sequence @@ -346,9 +329,9 @@ def __call__(self, data: Data) -> Dict: "latent_index": latent_index, "latent_timestamps": latent_timestamps, # output sequence - "output_session_index": pad8(session_index), + "output_session_index": pad8(output_session_index), "output_timestamps": pad8(output_timestamps), - "output_mask": pad8(output_mask), + "output_mask": track_mask8(output_session_index), # ground truth targets "target_values": pad8(output_values), "target_weights": pad8(output_weights), @@ -361,26 +344,22 @@ def __call__(self, data: Data) -> Dict: return batch -def interval_contains( - interval: Interval, x: Union[float, np.ndarray] -) -> Union[bool, np.ndarray]: - r"""If x is a single number, returns True if x is in the interval. - If x is a 1d numpy array, return a 1d bool numpy array. - """ +def isin_interval(timestamps: np.ndarray, interval: Interval) -> np.ndarray: + r"""Check if timestamps are in any of the intervals in the `Interval` object. - if isinstance(x, float): - if len(interval) == 0: - return False + Args: + timestamps: Timestamps to check. + interval: Interval to check against. - return np.logical_and(x >= interval.start, x < interval.end).any() - elif isinstance(x, np.ndarray): - if len(interval) == 0: - return np.zeros_like(x, dtype=bool) + Returns: + Boolean mask of the same shape as `timestamps`. + """ + if len(interval) == 0: + return np.zeros_like(timestamps, dtype=bool) - x_expanded = x[:, None] - y = np.logical_and( - x_expanded >= interval.start[None, :], x_expanded < interval.end[None, :] - ) - y = np.logical_or.reduce(y, axis=1) - assert y.shape == x.shape - return y + timestamps_expanded = timestamps[:, None] + mask = np.any( + (timestamps_expanded >= interval.start) & (timestamps_expanded < interval.end), + axis=1, + ) + return mask