Simba SAC #59

Draft · wants to merge 3 commits into base: master
20 changes: 19 additions & 1 deletion sbx/common/jax_layers.py
@@ -1,5 +1,6 @@
-from typing import Any, Callable, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.linen.module import Module, compact, merge_param
@@ -204,3 +205,20 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
self.bias_init,
self.scale_init,
)


# Adapted from simba: https://github.com/SonyResearch/simba
class SimbaResidualBlock(nn.Module):
hidden_dim: int
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4
norm_layer: Type[nn.Module] = nn.LayerNorm

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
residual = x
x = self.norm_layer()(x)
x = nn.Dense(self.hidden_dim * self.scale_factor, kernel_init=nn.initializers.he_normal())(x)
x = self.activation_fn(x)
x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x)
return residual + x
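
For context, the block is a pre-norm residual MLP, so it preserves the feature dimension and can be stacked freely. A minimal standalone sketch of exercising it (batch size and hidden_dim are illustrative, not taken from this PR):

import jax
import jax.numpy as jnp

# Illustrative only: exercises SimbaResidualBlock outside the policy classes.
block = SimbaResidualBlock(hidden_dim=256)
x = jnp.zeros((32, 256))  # (batch, hidden_dim); last dim must equal hidden_dim for the skip connection
variables = block.init(jax.random.PRNGKey(0), x)
y = block.apply(variables, x)
assert y.shape == x.shape  # the residual connection preserves the feature dimension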
53 changes: 53 additions & 0 deletions sbx/common/policies.py
@@ -10,6 +10,8 @@
from stable_baselines3.common.preprocessing import is_image_space, maybe_transpose
from stable_baselines3.common.utils import is_vectorized_observation

from sbx.common.jax_layers import SimbaResidualBlock


class Flatten(nn.Module):
"""
@@ -143,6 +145,29 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
return x


class SimbaContinuousCritic(nn.Module):
net_arch: Sequence[int]
dropout_rate: Optional[float] = None
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
# Note: simba was using kernel_init=orthogonal_init(1)
x = nn.Dense(self.net_arch[0])(x)
for n_units in self.net_arch:
x = SimbaResidualBlock(n_units, self.activation_fn, self.scale_factor)(x)
# TODO: double check where to put the dropout
if self.dropout_rate is not None and self.dropout_rate > 0:
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
x = nn.LayerNorm()(x)

x = nn.Dense(1)(x)
return x


class VectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
@@ -169,3 +194,31 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
activation_fn=self.activation_fn,
)(obs, action)
return q_values


class SimbaVectorCritic(nn.Module):
net_arch: Sequence[int]
# Note: use_layer_norm is kept for API consistency but is not used here (layer norm is always on)
use_layer_norm: bool = True
dropout_rate: Optional[float] = None
n_critics: int = 1  # only one critic by default
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu

@nn.compact
def __call__(self, obs: jnp.ndarray, action: jnp.ndarray):
# Idea taken from https://github.com/perrin-isir/xpag
# Similar to https://github.com/tinkoff-ai/CORL for PyTorch
vmap_critic = nn.vmap(
SimbaContinuousCritic,
variable_axes={"params": 0}, # parameters not shared between the critics
split_rngs={"params": True, "dropout": True}, # different initializations
in_axes=None,
out_axes=0,
axis_size=self.n_critics,
)
q_values = vmap_critic(
dropout_rate=self.dropout_rate,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
)(obs, action)
return q_values
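
As a quick shape check: nn.vmap with out_axes=0 stacks the critics along a leading axis. A minimal sketch, assuming illustrative observation/action dimensions of 4 and 2 (not taken from the PR):

import jax
import jax.numpy as jnp

critic = SimbaVectorCritic(net_arch=[256, 256], n_critics=2)
obs = jnp.zeros((32, 4))
action = jnp.zeros((32, 2))
variables = critic.init(jax.random.PRNGKey(0), obs, action)  # separate params per critic via nn.vmap
q_values = critic.apply(variables, obs, action)
# out_axes=0 stacks the ensemble: q_values.shape == (2, 32, 1)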
3 changes: 2 additions & 1 deletion sbx/crossq/crossq.py
@@ -16,7 +16,7 @@

from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import BatchNormTrainState, ReplayBufferSamplesNp
-from sbx.crossq.policies import CrossQPolicy
+from sbx.crossq.policies import CrossQPolicy, SimbaCrossQPolicy


class EntropyCoef(nn.Module):
@@ -42,6 +42,7 @@ def __call__(self) -> float:
class CrossQ(OffPolicyAlgorithmJax):
policy_aliases: ClassVar[Dict[str, Type[CrossQPolicy]]] = { # type: ignore[assignment]
"MlpPolicy": CrossQPolicy,
"SimbaPolicy": SimbaCrossQPolicy,
# Minimal dict support using flatten()
"MultiInputPolicy": CrossQPolicy,
}
198 changes: 192 additions & 6 deletions sbx/crossq/policies.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import flax.linen as nn
import jax
@@ -10,7 +11,7 @@
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.distributions import TanhTransformedDistribution
-from sbx.common.jax_layers import BatchRenorm
+from sbx.common.jax_layers import BatchRenorm, SimbaResidualBlock
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import BatchNormTrainState

@@ -48,12 +49,52 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) ->
x = nn.LayerNorm()(x)
x = self.activation_fn(x)
if self.use_batch_norm:
-x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+x = BatchRenorm(
+    use_running_average=not train,
+    momentum=self.batch_norm_momentum,
+    warmup_steps=self.renorm_warmup_steps,
+)(x)

x = nn.Dense(1)(x)
return x


class SimbaCritic(nn.Module):
net_arch: Sequence[int]
dropout_rate: Optional[float] = None
batch_norm_momentum: float = 0.99
renorm_warmup_steps: int = 100_000
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

@nn.compact
def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
x = Flatten()(x)
x = jnp.concatenate([x, action], -1)
norm_layer = partial(
BatchRenorm,
use_running_average=not train,
momentum=self.batch_norm_momentum,
warmup_steps=self.renorm_warmup_steps,
)
x = norm_layer()(x)
x = nn.Dense(self.net_arch[0])(x)

for n_units in self.net_arch:
x = SimbaResidualBlock(
n_units,
self.activation_fn,
self.scale_factor,
norm_layer, # type: ignore[arg-type]
)(x)
# TODO: double check where to put the dropout
if self.dropout_rate is not None and self.dropout_rate > 0:
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)
x = norm_layer()(x)
x = nn.Dense(1)(x)
return x


class VectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False
@@ -88,6 +129,87 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False):
return q_values


class SimbaVectorCritic(nn.Module):
net_arch: Sequence[int]
use_layer_norm: bool = False # ignored
use_batch_norm: bool = True
batch_norm_momentum: float = 0.99
renorm_warmup_steps: int = 100_000
dropout_rate: Optional[float] = None
n_critics: int = 2
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

@nn.compact
def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False):
# Idea taken from https://github.com/perrin-isir/xpag
# Similar to https://github.com/tinkoff-ai/CORL for PyTorch
vmap_critic = nn.vmap(
SimbaCritic,
variable_axes={"params": 0, "batch_stats": 0}, # parameters not shared between the critics
split_rngs={"params": True, "dropout": True, "batch_stats": True}, # different initializations
in_axes=None,
out_axes=0,
axis_size=self.n_critics,
)
q_values = vmap_critic(
# use_layer_norm=self.use_layer_norm,
# use_batch_norm=self.use_batch_norm,
batch_norm_momentum=self.batch_norm_momentum,
renorm_warmup_steps=self.renorm_warmup_steps,
dropout_rate=self.dropout_rate,
net_arch=self.net_arch,
activation_fn=self.activation_fn,
scale_factor=self.scale_factor,
)(obs, action, train)
return q_values
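
Because this CrossQ variant carries BatchRenorm running statistics, its variables include a batch_stats collection alongside params. A minimal sketch of calling it standalone (dimensions illustrative, not from the PR):

import jax
import jax.numpy as jnp

critic = SimbaVectorCritic(net_arch=[256], n_critics=2)
obs = jnp.zeros((8, 4))
action = jnp.zeros((8, 2))
variables = critic.init(jax.random.PRNGKey(0), obs, action)  # contains "params" and "batch_stats"
# Evaluation (train=False) reads the running statistics and mutates nothing:
q_values = critic.apply(variables, obs, action, False)
# Training (train=True) updates the running statistics, so batch_stats must be declared mutable:
q_values, updates = critic.apply(variables, obs, action, True, mutable=["batch_stats"])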


class SimbaActor(nn.Module):
net_arch: Sequence[int]
action_dim: int
log_std_min: float = -20
log_std_max: float = 2
use_batch_norm: bool = True
batch_norm_momentum: float = 0.99
renorm_warmup_steps: int = 100_000
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4

def get_std(self):
# Make it work with gSDE
return jnp.array(0.0)

@nn.compact
def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: # type: ignore[name-defined]
x = Flatten()(x)
norm_layer = partial(
BatchRenorm,
use_running_average=not train,
momentum=self.batch_norm_momentum,
warmup_steps=self.renorm_warmup_steps,
)
x = norm_layer()(x)
x = nn.Dense(self.net_arch[0])(x)

for n_units in self.net_arch:
x = SimbaResidualBlock(
n_units,
self.activation_fn,
self.scale_factor,
norm_layer, # type: ignore[arg-type]
)(x)
x = norm_layer()(x)

mean = nn.Dense(self.action_dim)(x)
log_std = nn.Dense(self.action_dim)(x)
log_std = jnp.clip(log_std, self.log_std_min, self.log_std_max)
dist = TanhTransformedDistribution(
tfd.MultivariateNormalDiag(loc=mean, scale_diag=jnp.exp(log_std)),
)
return dist
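
The actor returns a tanh-squashed distribution rather than raw actions. A minimal sampling sketch (observation size and action_dim are illustrative):

import jax
import jax.numpy as jnp

actor = SimbaActor(net_arch=[128], action_dim=2)
obs = jnp.zeros((1, 4))
variables = actor.init(jax.random.PRNGKey(0), obs)  # params plus batch_stats (from BatchRenorm)
dist = actor.apply(variables, obs)  # train=False -> running statistics, nothing is mutated
action = dist.sample(seed=jax.random.PRNGKey(1))  # squashed into (-1, 1) by the tanh bijector
log_prob = dist.log_prob(action)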


class Actor(nn.Module):
net_arch: Sequence[int]
action_dim: int
@@ -119,7 +241,11 @@ def __call__(self, x: jnp.ndarray, train: bool = False) -> tfd.Distribution: #
x = nn.Dense(n_units)(x)
x = self.activation_fn(x)
if self.use_batch_norm:
-x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
+x = BatchRenorm(
+    use_running_average=not train,
+    momentum=self.batch_norm_momentum,
+    warmup_steps=self.renorm_warmup_steps,
+)(x)

mean = nn.Dense(self.action_dim)(x)
log_std = nn.Dense(self.action_dim)(x)
@@ -159,6 +285,8 @@ def __init__(
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
actor_class: Type[nn.Module] = Actor,
vector_critic_class: Type[nn.Module] = VectorCritic,
):
if optimizer_kwargs is None:
# Note: the default value for b1 is 0.9 in Adam.
@@ -183,6 +311,8 @@ def __init__(
self.batch_norm_momentum = batch_norm_momentum
self.batch_norm_actor = batch_norm_actor
self.renorm_warmup_steps = renorm_warmup_steps
self.actor_class = actor_class
self.vector_critic_class = vector_critic_class

if net_arch is not None:
if isinstance(net_arch, list):
Expand Down Expand Up @@ -216,7 +346,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
obs = jnp.array([self.observation_space.sample()])
action = jnp.array([self.action_space.sample()])

-self.actor = Actor(
+self.actor = self.actor_class(
action_dim=int(np.prod(self.action_space.shape)),
net_arch=self.net_arch_pi,
use_batch_norm=self.batch_norm_actor,
@@ -244,7 +374,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
),
)

-self.qf = VectorCritic(
+self.qf = self.vector_critic_class(
dropout_rate=self.dropout_rate,
use_layer_norm=self.layer_norm,
use_batch_norm=self.batch_norm,
@@ -319,3 +449,59 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.n
if not self.use_sde:
self.reset_noise()
return self.sample_action(self.actor_state, observation, self.noise_key)


class SimbaCrossQPolicy(CrossQPolicy):
def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Box,
lr_schedule: Schedule,
net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
dropout_rate: float = 0,
layer_norm: bool = False,
batch_norm: bool = True,
batch_norm_actor: bool = True,
batch_norm_momentum: float = 0.99,
renorm_warmup_steps: int = 100000,
use_sde: bool = False,
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
log_std_init: float = -3,
use_expln: bool = False,
clip_mean: float = 2,
features_extractor_class=None,
features_extractor_kwargs: Optional[Dict[str, Any]] = None,
normalize_images: bool = True,
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
actor_class: Type[nn.Module] = SimbaActor,
vector_critic_class: Type[nn.Module] = SimbaVectorCritic,
):
super().__init__(
observation_space,
action_space,
lr_schedule,
net_arch,
dropout_rate,
layer_norm,
batch_norm,
batch_norm_actor,
batch_norm_momentum,
renorm_warmup_steps,
use_sde,
activation_fn,
log_std_init,
use_expln,
clip_mean,
features_extractor_class,
features_extractor_kwargs,
normalize_images,
optimizer_class,
optimizer_kwargs,
n_critics,
share_features_extractor,
actor_class,
vector_critic_class,
)
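
With the "SimbaPolicy" alias registered in crossq.py, the new policy can be selected by name. A hedged usage sketch (environment and hyperparameters are illustrative only, not defaults from this PR):

from sbx import CrossQ

# Illustrative: any Box-action env works; Pendulum-v1 keeps the example light.
model = CrossQ(
    "SimbaPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(net_arch=dict(pi=[128], qf=[256, 256])),
    verbose=1,
)
model.learn(total_timesteps=10_000)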