use weights as defined in dataset.yaml
mazabou committed Nov 27, 2024
1 parent 1c1502e commit a3d2329
Showing 3 changed files with 41 additions and 56 deletions.
6 changes: 0 additions & 6 deletions examples/poyo/configs/base.yaml
@@ -15,12 +15,6 @@ eval_batch_size: null # if null, will use batch_size
 num_workers: 4
 seed: 42
 
-subtask_weights:
-  random_period: 1.0
-  hold_period: 0.1
-  reach_period: 5.0
-  return_period: 1.0
-
 optim:
   base_lr: 3.125e-5 # scaled linearly by batch size
   weight_decay: 1e-4
12 changes: 12 additions & 0 deletions examples/poyo/configs/dataset/perich_miller_population_2018.yaml
@@ -81,6 +81,12 @@
     - m_20150625_center_out_reaching
     - m_20150626_center_out_reaching
   config:
+    weights:
+      movement_periods.random_period: 1.0
+      movement_periods.hold_period: 0.1
+      movement_periods.reach_period: 5.0
+      movement_periods.return_period: 1.0
+      cursor_outlier_segments: 0.0
     eval_movement_phase: reach_period


@@ -110,4 +116,10 @@
     - m_20140221_random_target_reaching
     - m_20140224_random_target_reaching
   config:
+    weights:
+      movement_periods.random_period: 1.0
+      movement_periods.hold_period: 0.1
+      movement_periods.reach_period: 5.0
+      movement_periods.return_period: 1.0
+      cursor_outlier_segments: 0.0
     eval_movement_phase: random_period
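
Each key in the new `weights` block is a dotted attribute path resolved against the `Data` object (e.g. `movement_periods.reach_period`), and its value multiplies the loss weight of every output timestamp falling inside that interval. A minimal sketch of the intended composition, using toy (start, end) arrays in place of the real `Interval` objects that `data.get_nested_attribute(key)` would return (the numeric values and the `intervals` lookup are illustrative, not from the repo):

import numpy as np

weights_cfg = {
    "movement_periods.hold_period": 0.1,
    "movement_periods.reach_period": 5.0,
    "cursor_outlier_segments": 0.0,
}

# toy (start, end) arrays standing in for the Interval objects that
# data.get_nested_attribute(key) would return
intervals = {
    "movement_periods.hold_period": (np.array([0.0]), np.array([0.5])),
    "movement_periods.reach_period": (np.array([1.0]), np.array([2.0])),
    "cursor_outlier_segments": (np.array([2.5]), np.array([3.0])),
}

timestamps = np.array([0.2, 1.4, 2.7])  # toy output timestamps (seconds)

output_weights = np.ones_like(timestamps, dtype=np.float32)
for key, value in weights_cfg.items():
    start, end = intervals[key]
    # half-open [start, end) membership, same convention as isin_interval below
    in_interval = np.any(
        (timestamps[:, None] >= start) & (timestamps[:, None] < end), axis=1
    )
    output_weights[in_interval] *= value

print(output_weights)  # [0.1 5.  0. ]

Because the weights compose multiplicatively, `cursor_outlier_segments: 0.0` zeroes out any timestamp that falls in an outlier segment, effectively removing those points from the loss.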
79 changes: 29 additions & 50 deletions torch_brain/models/poyo.py
@@ -253,13 +253,11 @@ def __init__(
         modality_spec: ModalitySpec,
         sequence_length: float = 1.0,
         eval: bool = False,
-        subtask_weights: Optional[Iterable[float]] = None,
     ):
         self.unit_tokenizer = unit_tokenizer
         self.session_tokenizer = session_tokenizer
 
         self.modality_spec = modality_spec
-        self.subtask_weights = subtask_weights
 
         self.latent_step = latent_step
         self.num_latents_per_step = num_latents_per_step
@@ -308,33 +306,18 @@ def __call__(self, data: Data) -> Dict:
         if output_values.dtype == np.float64:
             output_values = output_values.astype(np.float32)
 
-        session_index = self.session_tokenizer(data.session)
-        session_index = np.repeat(session_index, len(output_timestamps))
+        output_session_index = self.session_tokenizer(data.session)
+        output_session_index = np.repeat(output_session_index, len(output_timestamps))
 
-        # Assign each output timestamp a movement_phase
-        movement_phase_assignment = dict()
-        for k in data.movement_phases.keys():
-            movement_phase_assignment[k] = interval_contains(
-                data.movement_phases.__dict__[k],
-                output_timestamps,
-            )
-
-        # Weights for the output predictions (used in the loss function)
-        if self.subtask_weights is None:
-            output_weights = np.ones(len(output_values), dtype=np.float32)
-        else:
-            output_weights = np.zeros_like(output_timestamps, dtype=np.float32)
-            for k in data.movement_phases.keys():
-                output_weights = output_weights + (
-                    movement_phase_assignment[k].astype(float) * self.subtask_weights[k]
-                )
-
-        # Mask for the output predictions
-        output_mask = np.ones(len(output_values), dtype=bool)
-        if self.eval:
-            # During eval, only evaluate on the subtask specified in the config
-            eval_movement_phase = data.config["eval_movement_phase"]
-            output_mask = output_mask & (movement_phase_assignment[eval_movement_phase])
+        output_weights = np.ones_like(output_timestamps, dtype=np.float32)
+        if "weights" in data.config:
+            weights = data.config["weights"]
+            for weight_key, weight_value in weights.items():
+                # extract the interval from the weight key
+                weight = data.get_nested_attribute(weight_key)
+                if not isinstance(weight, Interval):
+                    raise ValueError(f"Weight {weight_key} is not an Interval")
+                output_weights[isin_interval(output_timestamps, weight)] *= weight_value
 
         batch = {
             # input sequence
@@ -346,9 +329,9 @@
             "latent_index": latent_index,
             "latent_timestamps": latent_timestamps,
             # output sequence
-            "output_session_index": pad8(session_index),
+            "output_session_index": pad8(output_session_index),
             "output_timestamps": pad8(output_timestamps),
-            "output_mask": pad8(output_mask),
+            "output_mask": track_mask8(output_session_index),
             # ground truth targets
             "target_values": pad8(output_values),
             "target_weights": pad8(output_weights),
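
The explicit `output_mask` is replaced by `track_mask8(output_session_index)`. A sketch of the assumed `pad8`/`track_mask8` pairing (the `_sketch` helpers below are assumptions about torch_brain's utilities, not the actual implementations): `pad8` pads a sequence to the next multiple of 8, and `track_mask8` presumably returns a boolean mask of the same padded length that is True on real positions and False on padding.

import numpy as np

def pad8_sketch(x: np.ndarray) -> np.ndarray:
    # pad with zeros up to the next multiple of 8
    pad = (-len(x)) % 8
    return np.concatenate([x, np.zeros(pad, dtype=x.dtype)])

def track_mask8_sketch(x: np.ndarray) -> np.ndarray:
    # True for original positions, False for the added padding
    pad = (-len(x)) % 8
    return np.concatenate([np.ones(len(x), dtype=bool), np.zeros(pad, dtype=bool)])

output_session_index = np.repeat(7, 5)  # session index repeated per timestamp
print(pad8_sketch(output_session_index))        # [7 7 7 7 7 0 0 0]
print(track_mask8_sketch(output_session_index)) # [ True  True  True  True  True False False False]

Under that reading, the mask no longer encodes eval-time subtask filtering (the `eval_movement_phase` branch is deleted above); every valid output position is unmasked, and subtask emphasis lives entirely in `target_weights`.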
@@ -361,26 +344,22 @@
         return batch
 
 
-def interval_contains(
-    interval: Interval, x: Union[float, np.ndarray]
-) -> Union[bool, np.ndarray]:
-    r"""If x is a single number, returns True if x is in the interval.
-    If x is a 1d numpy array, return a 1d bool numpy array.
-    """
-    if isinstance(x, float):
-        if len(interval) == 0:
-            return False
-
-        return np.logical_and(x >= interval.start, x < interval.end).any()
-    elif isinstance(x, np.ndarray):
-        if len(interval) == 0:
-            return np.zeros_like(x, dtype=bool)
-
-        x_expanded = x[:, None]
-        y = np.logical_and(
-            x_expanded >= interval.start[None, :], x_expanded < interval.end[None, :]
-        )
-        y = np.logical_or.reduce(y, axis=1)
-        assert y.shape == x.shape
-        return y
+def isin_interval(timestamps: np.ndarray, interval: Interval) -> np.ndarray:
+    r"""Check if timestamps are in any of the intervals in the `Interval` object.
+
+    Args:
+        timestamps: Timestamps to check.
+        interval: Interval to check against.
+
+    Returns:
+        Boolean mask of the same shape as `timestamps`.
+    """
+    if len(interval) == 0:
+        return np.zeros_like(timestamps, dtype=bool)
+
+    timestamps_expanded = timestamps[:, None]
+    mask = np.any(
+        (timestamps_expanded >= interval.start) & (timestamps_expanded < interval.end),
+        axis=1,
+    )
+    return mask
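
A usage sketch for the new helper, with a minimal stand-in for the real `Interval` type (assumed to expose parallel `start`/`end` arrays and `__len__`, which is how the function above indexes it):

import numpy as np
from dataclasses import dataclass

@dataclass
class Interval:  # minimal stand-in for the real Interval type
    start: np.ndarray
    end: np.ndarray

    def __len__(self) -> int:
        return len(self.start)

def isin_interval(timestamps: np.ndarray, interval: Interval) -> np.ndarray:
    # same logic as the function above
    if len(interval) == 0:
        return np.zeros_like(timestamps, dtype=bool)
    expanded = timestamps[:, None]
    return np.any((expanded >= interval.start) & (expanded < interval.end), axis=1)

reach = Interval(start=np.array([0.5, 2.0]), end=np.array([1.0, 3.0]))
ts = np.array([0.25, 0.75, 2.5, 3.0])
print(isin_interval(ts, reach))  # [False  True  True False] -- intervals are half-open

Compared to `interval_contains`, the scalar branch is gone: callers now always pass a timestamp array, which matches the single call site in `__call__`.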
