Skip to content

Commit

Permalink
[Formatting] Black+Ruff 🎨
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Nov 5, 2023
1 parent 74c48ca commit 0ac16cd
Show file tree
Hide file tree
Showing 17 changed files with 100 additions and 74 deletions.
5 changes: 2 additions & 3 deletions docs/_theme/rl4co/extensions/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@
# limitations under the License.
from docutils import nodes
from docutils.statemachine import StringList
from sphinx.util.docutils import SphinxDirective

from pt_lightning_sphinx_theme.extensions.pytorch_tutorials import (
cardnode,
CustomCalloutItemDirective,
CustomCardItemDirective,
DisplayItemDirective,
LikeButtonWithTitle,
ReactGreeter,
SlackButton,
TwoColumns,
cardnode,
)
from sphinx.util.docutils import SphinxDirective


class tutoriallistnode(nodes.General, nodes.Element):
Expand Down
29 changes: 23 additions & 6 deletions docs/_theme/rl4co/extensions/pytorch_tutorials.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@
from docutils import nodes
from docutils.parsers.rst import Directive, directives
from docutils.statemachine import StringList
from sphinx.util.docutils import SphinxDirective

from pt_lightning_sphinx_theme.extensions.react import get_react_component_rst
from sphinx.util.docutils import SphinxDirective

try:
FileNotFoundError
Expand Down Expand Up @@ -272,11 +271,23 @@ def run(self):

image_class = ""
if "image_center" in self.options:
image = "<img src='" + self.options["image_center"] + "' style=height:" + image_height + " >"
image = (
"<img src='"
+ self.options["image_center"]
+ "' style=height:"
+ image_height
+ " >"
)
image_class = "image-center"

elif "image_right" in self.options:
image = "<img src='" + self.options["image_right"] + "' style=height:" + image_height + " >"
image = (
"<img src='"
+ self.options["image_right"]
+ "' style=height:"
+ image_height
+ " >"
)
image_class = "image-right"
else:
image = ""
Expand Down Expand Up @@ -371,7 +382,11 @@ def run(self):
raise
# return []
callout_rst = get_react_component_rst(
"LikeButtonWithTitle", width=width, margin=margin, title=title, padding=padding
"LikeButtonWithTitle",
width=width,
margin=margin,
title=title,
padding=padding,
)
callout_list = StringList(callout_rst.split("\n"))
callout = nodes.paragraph()
Expand Down Expand Up @@ -427,7 +442,9 @@ def run(self):
print(e)
raise
return []
callout_rst = SLACK_TEMPLATE.format(align=align, title=title, margin=margin, width=width)
callout_rst = SLACK_TEMPLATE.format(
align=align, title=title, margin=margin, width=width
)
callout_list = StringList(callout_rst.split("\n"))
callout = nodes.paragraph()
self.state.nested_parse(callout_list, self.content_offset, callout)
Expand Down
2 changes: 1 addition & 1 deletion rl4co/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.4.dev1"
__version__ = "0.3.0dev0"
1 change: 0 additions & 1 deletion rl4co/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from rl4co.models.zoo.active_search import ActiveSearch
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy

from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.zoo.eas import EAS, EASEmb, EASLay
from rl4co.models.zoo.ham import (
Expand Down
4 changes: 3 additions & 1 deletion rl4co/models/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def get_log_likelihood(log_p, actions, mask, return_sum: bool = True):
if mask is not None:
log_p[~mask] = 0

assert (log_p > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!"
assert (
log_p > -1000
).data.all(), "Logprobs should not be -inf, check sampling procedure!"

# Calculate log_likelihood
if return_sum:
Expand Down
18 changes: 12 additions & 6 deletions rl4co/models/rl/common/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, Union, Iterable
from typing import Any, Iterable, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -149,7 +149,6 @@ def setup(self, stage="fit"):
self.data_cfg["test_data_size"], phase="test"
)
self.dataloader_names = None

self.setup_loggers()
self.post_setup_hook()

Expand Down Expand Up @@ -214,12 +213,16 @@ def configure_optimizers(self, parameters=None):

def log_metrics(self, metric_dict: dict, phase: str, dataloader_idx: int = None):
"""Log metrics to logger and progress bar"""
metrics = getattr(self, f"{phase}_metrics")
metrics = getattr(self, f"{phase}_metrics")
dataloader_name = ""
if dataloader_idx is not None and self.dataloader_names is not None:
dataloader_name = "/" + self.dataloader_names[dataloader_idx]
dataloader_name = "/" + self.dataloader_names[dataloader_idx]
metrics = {
f"{phase}/{k}{dataloader_name}": v.mean() if isinstance(v, torch.Tensor) else v for k, v in metric_dict.items() if k in metrics
f"{phase}/{k}{dataloader_name}": v.mean()
if isinstance(v, torch.Tensor)
else v
for k, v in metric_dict.items()
if k in metrics
}
log_on_step = self.log_on_step if phase == "train" else False
on_epoch = False if phase == "train" else True
Expand Down Expand Up @@ -292,7 +295,10 @@ def _dataloader(self, dataset, batch_size, shuffle=False):
self.dataloader_names = list(dataset.keys())
else:
self.dataloader_names = [f"{i}" for i in range(len(dataset))]
return [self._dataloader_single(ds, batch_size, shuffle) for ds in dataset.values()]
return [
self._dataloader_single(ds, batch_size, shuffle)
for ds in dataset.values()
]
else:
return self._dataloader_single(dataset, batch_size, shuffle)

Expand Down
1 change: 0 additions & 1 deletion rl4co/models/zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from rl4co.models.zoo.active_search import ActiveSearch
from rl4co.models.zoo.am import AttentionModel, AttentionModelPolicy
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy

from rl4co.models.zoo.common.search import SearchBase
from rl4co.models.zoo.eas import EAS, EASEmb, EASLay
from rl4co.models.zoo.ham import (
Expand Down
4 changes: 2 additions & 2 deletions rl4co/models/zoo/ham/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class HeterogeneousAttentionModel(REINFORCE):
"""Heterogenous Attention Model for solving the Pickup and Delivery Problem based on
"""Heterogenous Attention Model for solving the Pickup and Delivery Problem based on
REINFORCE: https://arxiv.org/abs/2110.02634.
Args:
Expand All @@ -20,7 +20,7 @@ class HeterogeneousAttentionModel(REINFORCE):
"""

def __init__(
self,
self,
env: RL4COEnvBase,
policy: HeterogeneousAttentionModelPolicy = None,
baseline: Union[REINFORCEBaseline, str] = "rollout",
Expand Down
3 changes: 1 addition & 2 deletions rl4co/models/zoo/ham/policy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch.nn as nn
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.models.zoo.ham.encoder import GraphHeterogeneousAttentionEncoder

Expand Down Expand Up @@ -41,4 +40,4 @@ def __init__(
num_heads=num_heads,
normalization=normalization,
**kwargs,
)
)
2 changes: 1 addition & 1 deletion rl4co/models/zoo/mdam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .model import MDAM
from .policy import MDAMPolicy
from .model import MDAM
33 changes: 16 additions & 17 deletions rl4co/models/zoo/mdam/decoder.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import math
from typing import Union

from dataclasses import dataclass
from tensordict import TensorDict
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from tensordict import TensorDict

from rl4co.envs import RL4COEnvBase

from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.env_embeddings import env_context_embedding, env_dynamic_embedding
from rl4co.models.nn.utils import decode_probs, get_log_likelihood
Expand Down Expand Up @@ -67,8 +66,7 @@ def __init__(
self.project_node_embeddings = nn.ModuleList(self.project_node_embeddings)

self.project_fixed_context = [
nn.Linear(embedding_dim, embedding_dim, bias=False)
for _ in range(num_paths)
nn.Linear(embedding_dim, embedding_dim, bias=False) for _ in range(num_paths)
]
self.project_fixed_context = nn.ModuleList(self.project_fixed_context)

Expand All @@ -79,8 +77,7 @@ def __init__(
self.project_step_context = nn.ModuleList(self.project_step_context)

self.project_out = [
nn.Linear(embedding_dim, embedding_dim, bias=False)
for _ in range(num_paths)
nn.Linear(embedding_dim, embedding_dim, bias=False) for _ in range(num_paths)
]
self.project_out = nn.ModuleList(self.project_out)

Expand Down Expand Up @@ -108,15 +105,15 @@ def __init__(
self.shrink_size = shrink_size

def forward(
self,
td: TensorDict,
encoded_inputs: torch.Tensor,
env: Union[str, RL4COEnvBase],
attn,
V,
h_old,
**decoder_kwargs
):
self,
td: TensorDict,
encoded_inputs: torch.Tensor,
env: Union[str, RL4COEnvBase],
attn,
V,
h_old,
**decoder_kwargs,
):
# SECTION: Decoder first step: calculate for the decoder divergence loss
# Cost list and log likelihood list along with path
output_list = []
Expand Down Expand Up @@ -261,7 +258,9 @@ def _get_log_p(self, fixed, td, path_index, normalize=True):
step_context = self.context[path_index](
fixed.node_embeddings, td
) # [batch, embed_dim]
glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to(fixed.graph_context.device)
glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to(
fixed.graph_context.device
)

# Compute keys and values for the nodes
(
Expand Down
27 changes: 13 additions & 14 deletions rl4co/models/zoo/mdam/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from typing import Union

from rl4co.envs.common.base import RL4COEnvBase
Expand All @@ -8,10 +7,10 @@


class MDAM(REINFORCE):
""" Multi-Decoder Attention Model (MDAM) is a model
to train multiple diverse policies, which effectively increases the chance of finding
"""Multi-Decoder Attention Model (MDAM) is a model
to train multiple diverse policies, which effectively increases the chance of finding
good solutions compared with existing methods that train only one policy.
Reference link: https://arxiv.org/abs/2012.10638;
Reference link: https://arxiv.org/abs/2012.10638;
Implementation reference: https://github.com/liangxinedu/MDAM.
Args:
Expand All @@ -24,15 +23,15 @@ class MDAM(REINFORCE):
"""

def __init__(
self,
env: RL4COEnvBase,
policy: MDAMPolicy = None,
baseline: Union[REINFORCEBaseline, str] = "rollout",
policy_kwargs={},
baseline_kwargs={},
**kwargs
):
self,
env: RL4COEnvBase,
policy: MDAMPolicy = None,
baseline: Union[REINFORCEBaseline, str] = "rollout",
policy_kwargs={},
baseline_kwargs={},
**kwargs,
):
if policy is None:
policy = MDAMPolicy(env.name, **policy_kwargs)
policy = MDAMPolicy(env.name, **policy_kwargs)

super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)
super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)
17 changes: 8 additions & 9 deletions rl4co/models/zoo/mdam/policy.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
import torch.nn as nn
from typing import Union

from tensordict import TensorDict
from rl4co.envs import RL4COEnvBase, get_env

from rl4co.envs import RL4COEnvBase, get_env
from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.models.zoo.mdam.decoder import Decoder
from rl4co.models.zoo.mdam.encoder import GraphAttentionEncoder
from rl4co.models.zoo.common.autoregressive import AutoregressivePolicy
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)


class MDAMPolicy(AutoregressivePolicy):
""" Multi-Decoder Attention Model (MDAM) policy.
"""Multi-Decoder Attention Model (MDAM) policy.
Args:
"""

def __init__(
self,
self,
env_name: str,
embedding_dim: int = 128,
num_encoder_layers: int = 3,
Expand All @@ -35,13 +34,13 @@ def __init__(
embed_dim=embedding_dim,
num_layers=num_encoder_layers,
normalization=normalization,
**kwargs
**kwargs,
),
decoder=Decoder(
env_name=env_name,
embedding_dim=embedding_dim,
num_heads=num_heads,
**kwargs
**kwargs,
),
embedding_dim=embedding_dim,
num_encoder_layers=num_encoder_layers,
Expand Down Expand Up @@ -84,4 +83,4 @@ def forward(
"entropy": kl_divergence,
"actions": actions if return_actions else None,
}
return out
return out
10 changes: 6 additions & 4 deletions rl4co/models/zoo/pomo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def __init__(
for phase in ["train", "val", "test"]:
self.set_decode_type_multistart(phase)

def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None):
def shared_step(
self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None
):
td = self.env.reset(batch)
n_aug, n_start = self.num_augment, self.num_starts
n_start = get_num_starts(td) if n_start is None else n_start
Expand Down Expand Up @@ -102,10 +104,10 @@ def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: in
out.update({"max_aug_reward": max_aug_reward})

if out.get("actions", None) is not None:
actions_ = out["best_multistart_actions"] if n_start > 1 else out["actions"]
out.update(
{"best_aug_actions": gather_by_index(actions_, max_idxs)}
actions_ = (
out["best_multistart_actions"] if n_start > 1 else out["actions"]
)
out.update({"best_aug_actions": gather_by_index(actions_, max_idxs)})

metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
return {"loss": out.get("loss", None), **metrics}
Loading

0 comments on commit 0ac16cd

Please sign in to comment.