Description
Using the timegan and timevae plugins on time-series (sequential) data fails with the following error:
File ~/miniconda3/envs/syncity/lib/python3.10/site-packages/synthcity/plugins/core/models/ts_model.py:251, in TimeSeriesModel.forward(self, static_data, temporal_data, observation_times)
240 @validate_arguments(config=dict(arbitrary_types_allowed=True))
241 def forward(
242 self,
(...)
247 # x shape (batch, time_step, input_size)
248 # r_out shape (batch, time_step, output_size)
250 if torch.isnan(static_data).sum() != 0:
--> 251 raise ValueError("NaNs detected in the static data")
252 if torch.isnan(temporal_data).sum() != 0:
253 raise ValueError("NaNs detected in the temporal data")
ValueError: NaNs detected in the static data
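The check that fails (torch.isnan(static_data).sum() != 0) rejects any NaN anywhere in the static tensor. Assuming the NaNs originate in the input data rather than appearing during training, a quick way to locate them before calling fit is sketched below with pandas. report_nans is a hypothetical helper, not part of synthcity, and it assumes the static data is a single DataFrame and the temporal data is a list of per-sequence DataFrames.

```python
from typing import Dict, List

import numpy as np
import pandas as pd


def report_nans(static_data: pd.DataFrame, temporal_data: List[pd.DataFrame]) -> Dict[str, List[str]]:
    """Hypothetical helper (not part of synthcity): list the columns that contain NaNs.

    Mirrors the two NaN checks in TimeSeriesModel.forward, but reports *where*
    the NaNs are instead of raising.
    """
    report: Dict[str, List[str]] = {}

    # Columns of the static table with at least one missing value.
    static_mask = static_data.isna().any().to_numpy()
    static_cols = static_data.columns[static_mask].tolist()
    if static_cols:
        report["static"] = static_cols

    # Same check for every temporal sequence, keyed by its position in the list.
    for i, seq in enumerate(temporal_data):
        seq_cols = seq.columns[seq.isna().any().to_numpy()].tolist()
        if seq_cols:
            report[f"temporal[{i}]"] = seq_cols

    return report


# Toy example with one NaN in the static data and one in the second sequence.
static = pd.DataFrame({"age": [34.0, np.nan], "sex": [0, 1]})
temporal = [pd.DataFrame({"hr": [70.0, 72.0]}), pd.DataFrame({"hr": [np.nan, 80.0]})]
print(report_nans(static, temporal))  # {'static': ['age'], 'temporal[1]': ['hr']}
```

If this reports nothing for the real data, the NaNs are presumably introduced later (for example during preprocessing or by a diverging training run), which an input-side check like this cannot detect.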
How to Reproduce
from synthcity.plugins import Plugins
model_type = "timevae"  # the same error also occurs with "timegan"
trn_n_iter = 20
trn_batch_size = 12
trn_n_hidden_units = 100
ts_syn_model = Plugins().get(
    model_type,
    n_iter=trn_n_iter,  # number of training iterations
    batch_size=trn_batch_size,  # batch size
    generator_n_units_hidden=trn_n_hidden_units,  # units in the generator's hidden layers; default: 500
    discriminator_n_units_hidden=trn_n_hidden_units,  # units in the discriminator's hidden layers; default: 500
)
# `data` is the reporter's time-series data loader (not shown in the original report)
ts_syn_model.fit(data)
System Information
- OS: ubuntu-noble-24.04-arm64-server
- OS Version: 24.04
- Language Version: Python 3.10