Skip to content

Commit

Permalink
included code to visualize logs in wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
ngastzepeda committed Jan 22, 2024
1 parent 1a2da37 commit a490aa1
Show file tree
Hide file tree
Showing 2 changed files with 1,158 additions and 0 deletions.
1,065 changes: 1,065 additions & 0 deletions notebooks/cvrptw random policy/temp_dev.ipynb

Large diffs are not rendered by default.

93 changes: 93 additions & 0 deletions notebooks/cvrptw random policy/temp_dev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import torch

from rl4co.envs import CVRPEnv, CVRPTWEnv
from rl4co.models.nn.utils import rollout, random_policy
from rl4co.models.zoo.am import AttentionModel
from rl4co.utils.trainer import RL4COTrainer

# --- Environment setup --- #
# Alternative environments kept for quick switching during development:
# env_cvrp = CVRPEnv()
# env_short = CVRPTWEnv(num_loc=20)
# Capacitated VRP with Time Windows environment used throughout this script.
env_cvrptw = CVRPTWEnv(
    num_loc=30,  # number of customer locations per instance
    min_loc=0,  # coordinate range for sampled locations
    max_loc=150,
    min_demand=1,  # per-customer demand range
    max_demand=10,
    vehicle_capacity=1,  # NOTE(review): 1 here vs capacity=10 below — presumably
    capacity=10,  # normalized vs. raw capacity; confirm against CVRPTWEnv docs
    min_time=0,  # time-window horizon range
    max_time=480,
    scale=True,  # presumably scales features to [0, 1] — TODO confirm
)

# Single alias so the rest of the script is env-agnostic.
env = env_cvrptw

# batch size (number of instances rolled out / rendered at once)
batch_size = 3


### --- random policy --- ###
# Smoke test: roll out a random policy on a fresh batch to check that the
# environment steps, rewards, and solution validation all work end to end.
# try random policy
reward, td, actions = rollout(
    env=env,
    td=env.reset(batch_size=[batch_size]),
    policy=random_policy,
    max_steps=1000,  # hard cap on decoding steps per rollout
)
# One scalar reward per instance in the batch.
assert reward.shape == (batch_size,)

# Recompute the reward and assert the solutions are feasible.
# NOTE(review): get_reward's return value is discarded — this line is only a
# "does it run" check; check_solution_validity is expected to raise on
# infeasible tours.
env.get_reward(td, actions)
CVRPTWEnv.check_solution_validity(td, actions)

# Visualize the random-policy tours.
env.render(td, actions)


### --- AM --- ###
# Model: default is AM with REINFORCE and greedy rollout baseline
model = AttentionModel(
    env,
    baseline="rollout",  # greedy-rollout REINFORCE baseline
    train_data_size=100_000,  # instances generated per training epoch
    val_data_size=10_000,  # instances used for validation
)

# Greedy rollouts over untrained model (baseline before training).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fixed batch of 3 instances, reused below for before/after comparison plots.
td_init = env.reset(batch_size=[3]).to(device)
model = model.to(device)
out = model(td_init.clone(), phase="test", decode_type="greedy", return_actions=True)

# Plotting (disabled): tour lengths and renders of the untrained rollouts.
# print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
# for td, actions in zip(td_init, out["actions"].cpu()):
#     env.render(td, actions)

### --- Logging --- ###
# Log training metrics to Weights & Biases via Lightning's WandbLogger.
import wandb
from lightning.pytorch.loggers import WandbLogger

# Interactive/credential-based login; may prompt for an API key.
wandb.login()
logger = WandbLogger(project="routefinder", name="cvrptw-am")

### --- Training --- ###
# The RL4CO trainer is a wrapper around PyTorch Lightning's `Trainer` class which adds some functionality and more efficient defaults
trainer = RL4COTrainer(
    max_epochs=100,
    accelerator="auto",  # let Lightning pick GPU/CPU
    devices=1,  # single device
    logger=logger,  # wandb logger configured above
)

# fit model (runs the full REINFORCE training loop; logs to wandb)
trainer.fit(model)

### --- Testing --- ###

# Greedy rollouts over trained model (same states as previous plot)
# model = model.to(device)
# out = model(td_init.clone(), phase="test", decode_type="greedy", return_actions=True)

# Plotting
# print(f"Tour lengths: {[f'{-r.item():.2f}' for r in out['reward']]}")
# for td, actions in zip(td_init, out["actions"].cpu()):
# env.render(td, actions)

0 comments on commit a490aa1

Please sign in to comment.