Commit

Initial E3B implementation. No inverse dynamics model, decayed corr matrix

Joseph Suarez committed Jan 17, 2025
1 parent 40827e0 commit 4418024
Showing 4 changed files with 44 additions and 28 deletions.
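
E3B (Exploration via Elliptical Episodic Bonuses; Henaff et al., 2022) rewards visits to states whose feature embedding is novel relative to what has already been seen: with embedding phi(s_t), the bonus is b(s_t) = phi(s_t)^T C^{-1} phi(s_t), where C accumulates the outer products phi phi^T (plus a ridge term lambda*I), so directions of feature space that are already well covered earn little bonus. In the paper, phi is trained with an inverse dynamics model and C is reset at each episode boundary. As the commit message says, this first pass does neither: it reuses the policy's detached hidden state as phi, and instead of an episodic reset it decays the inverse matrix by 0.95 per step (the "decayed corr matrix"), then clamps the bonus to [-1, 1] and scales it by 0.1 before adding it to the environment reward (see pufferlib/models.py and clean_pufferl.py below).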
clean_pufferl.py: 24 changes (15 additions, 9 deletions)
@@ -45,8 +45,8 @@ def create(config, vecenv, policy, optimizer=None, wandb=None):

lstm = policy.lstm if hasattr(policy, 'lstm') else None
experience = Experience(config.batch_size, config.bptt_horizon,
- config.minibatch_size, obs_shape, obs_dtype, atn_shape, atn_dtype,
- config.cpu_offload, config.device, lstm, total_agents)
+ config.minibatch_size, policy.hidden_size, obs_shape, obs_dtype,
+ atn_shape, atn_dtype, config.cpu_offload, config.device, lstm, total_agents)

uncompiled_policy = policy

@@ -82,6 +82,7 @@ def evaluate(data):
policy = data.policy
infos = defaultdict(list)
lstm_h, lstm_c = experience.lstm_h, experience.lstm_c
+ e3b_inv = experience.e3b_inv

while not experience.full:
with profile.env:
@@ -99,14 +100,18 @@
with profile.eval_forward, torch.no_grad():
# TODO: In place-update should be faster. Leaking 7% speed max
# Also should be using a cuda tensor to index
+ e3b = e3b_inv[env_id]
if lstm_h is not None:
h = lstm_h[:, env_id]
c = lstm_c[:, env_id]
- actions, logprob, _, value, (h, c) = policy(o_device, (h, c))
+ actions, logprob, _, value, (h, c), next_e3b, intrinsic_reward = policy(o_device, (h, c), e3b=e3b)
lstm_h[:, env_id] = h
lstm_c[:, env_id] = c
else:
- actions, logprob, _, value = policy(o_device)
+ actions, logprob, _, value, next_e3b, intrinsic_reward = policy(o_device, e3b=e3b)

+ e3b_inv[env_id] = next_e3b
+ r += intrinsic_reward.cpu()

if config.device == 'cuda':
torch.cuda.synchronize()
@@ -179,11 +184,11 @@ def train(data):

with profile.train_forward:
if experience.lstm_h is not None:
- _, newlogprob, entropy, newvalue, lstm_state = data.policy(
+ _, newlogprob, entropy, newvalue, lstm_state, _, _ = data.policy(
obs, state=lstm_state, action=atn)
lstm_state = (lstm_state[0].detach(), lstm_state[1].detach())
else:
- _, newlogprob, entropy, newvalue = data.policy(
+ _, newlogprob, entropy, newvalue, _, _ = data.policy(
obs.reshape(-1, *data.vecenv.single_observation_space.shape),
action=atn,
)
@@ -388,8 +393,9 @@ def make_losses():

class Experience:
'''Flat tensor storage and array views for faster indexing'''
- def __init__(self, batch_size, bptt_horizon, minibatch_size, obs_shape, obs_dtype, atn_shape, atn_dtype,
- cpu_offload=False, device='cuda', lstm=None, lstm_total_agents=0):
+ def __init__(self, batch_size, bptt_horizon, minibatch_size, hidden_size,
+ obs_shape, obs_dtype, atn_shape, atn_dtype, cpu_offload=False,
+ device='cuda', lstm=None, lstm_total_agents=0):
if minibatch_size is None:
minibatch_size = batch_size

@@ -405,8 +411,8 @@ def __init__(self, batch_size, bptt_horizon, minibatch_size, obs_shape, obs_dtyp
self.dones=torch.zeros(batch_size, pin_memory=pin)
self.truncateds=torch.zeros(batch_size, pin_memory=pin)
self.values=torch.zeros(batch_size, pin_memory=pin)
+ self.e3b_inv = 1*torch.eye(hidden_size).repeat(lstm_total_agents, 1, 1).to(device)

#self.obs_np = np.asarray(self.obs)
self.actions_np = np.asarray(self.actions)
self.logprobs_np = np.asarray(self.logprobs)
self.rewards_np = np.asarray(self.rewards)
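
A minimal sketch of the bookkeeping evaluate() does above, assuming the shapes used in this commit: Experience.e3b_inv holds one identity-initialized inverse matrix per agent, the rows for the environments stepped this iteration are gathered by env_id, the policy returns updated matrices plus an intrinsic reward, and the updates are scattered back while the intrinsic reward is added onto the extrinsic reward. The names num_agents, step() and the sizes below are illustrative, not from the repo; the actual bonus math lives in pufferlib/models.py further down.

import torch

def step(obs, e3b):
    # Stand-in for the policy forward: returns (updated matrices, intrinsic reward)
    return e3b, torch.zeros(obs.shape[0])

num_agents, hidden_size = 4, 128
e3b_inv = torch.eye(hidden_size).repeat(num_agents, 1, 1)   # as in Experience.__init__

env_id = torch.tensor([0, 2])                 # agents stepped this iteration
obs = torch.zeros(len(env_id), 8)             # placeholder observations
r = torch.zeros(len(env_id))                  # extrinsic rewards from the env

next_e3b, intrinsic_reward = step(obs, e3b_inv[env_id])   # gather per-agent matrices
e3b_inv[env_id] = next_e3b                                 # scatter the updates back
r += intrinsic_reward                                      # added directly to the reward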
config/ocean/pong.ini: 10 changes (5 additions, 5 deletions)
@@ -13,18 +13,18 @@ checkpoint_interval = 25
num_envs = 1
num_workers = 1
env_batch_size = 1
- batch_size = 131072
- update_epochs = 3
+ batch_size = 32768
+ update_epochs = 1
minibatch_size = 8192
bptt_horizon = 16
- ent_coef = 0.004602
+ ent_coef = 0.003
gae_lambda = 0.979
gamma = 0.9879
learning_rate = 0.001494
anneal_lr = False
device = cuda
- max_grad_norm = 3.592
- vf_coef = 0.4122
+ max_grad_norm = 0.5
+ vf_coef = 0.5

[sweep.metric]
goal = maximize
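
The pong.ini change cuts the rollout from 131072 to 32768 samples, drops to a single update epoch, and moves several swept hyperparameters back to round defaults (ent_coef 0.003, max_grad_norm 0.5, vf_coef 0.5). Under clean_pufferl's usual batching conventions (an assumption here, not stated in this diff), the new sizes relate as follows:

batch_size     = 32768
minibatch_size = 8192
bptt_horizon   = 16

print(batch_size // minibatch_size)   # 4 minibatches per update epoch
print(batch_size // bptt_horizon)     # 2048 BPTT segments stored per batch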
pufferlib/cleanrl.py: 20 changes (11 additions, 9 deletions)
@@ -62,18 +62,19 @@ def __init__(self, policy):
super().__init__()
self.policy = policy
self.is_continuous = hasattr(policy, 'is_continuous') and policy.is_continuous
+ self.hidden_size = policy.hidden_size

def get_value(self, x, state=None):
_, value = self.policy(x)
return value

def get_action_and_value(self, x, action=None):
- logits, value = self.policy(x)
+ logits, value, e3b, intrinsic_reward = self.policy(x, e3b=e3b)
action, logprob, entropy = sample_logits(logits, action, self.is_continuous)
- return action, logprob, entropy, value
+ return action, logprob, entropy, value, e3b, intrinsic_reward

- def forward(self, x, action=None):
- return self.get_action_and_value(x, action)
+ def forward(self, x, action=None, e3b=None):
+ return self.get_action_and_value(x, action, e3b)


class RecurrentPolicy(torch.nn.Module):
@@ -82,6 +83,7 @@ def __init__(self, policy):
super().__init__()
self.policy = policy
self.is_continuous = hasattr(policy.policy, 'is_continuous') and policy.policy.is_continuous
+ self.hidden_size = policy.hidden_size

@property
def lstm(self):
@@ -95,10 +97,10 @@ def lstm(self):
def get_value(self, x, state=None):
_, value, _ = self.policy(x, state)

- def get_action_and_value(self, x, state=None, action=None):
- logits, value, state = self.policy(x, state)
+ def get_action_and_value(self, x, state=None, action=None, e3b=None):
+ logits, value, state, e3b, intrinsic_reward = self.policy(x, state, e3b=e3b)
action, logprob, entropy = sample_logits(logits, action, self.is_continuous)
- return action, logprob, entropy, value, state
+ return action, logprob, entropy, value, state, e3b, intrinsic_reward

- def forward(self, x, state=None, action=None):
- return self.get_action_and_value(x, state, action)
+ def forward(self, x, state=None, action=None, e3b=None):
+ return self.get_action_and_value(x, state, action, e3b)
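
The wrapper changes are plumbing: both Policy and RecurrentPolicy now expose hidden_size, accept an optional e3b argument, and return the updated matrix and intrinsic reward alongside the usual outputs, so callers that do not care, like the train() path in clean_pufferl.py, can pass nothing and discard the extra return values. A toy stand-in, not the pufferlib classes, just to show the convention of threading an optional e3b through forward:

import torch

class ToyPolicy(torch.nn.Module):
    # Illustrative only: mirrors the "two extra return values" convention above
    hidden_size = 128
    def forward(self, x, e3b=None):
        logits = torch.randn(x.shape[0], 4)
        value = torch.zeros(x.shape[0], 1)
        intrinsic_reward = None if e3b is None else torch.zeros(x.shape[0])
        return logits, value, e3b, intrinsic_reward

policy = ToyPolicy()
obs = torch.zeros(8, 16)
logits, value, _, _ = policy(obs)                          # training path: extras ignored
e3b0 = torch.eye(128).repeat(8, 1, 1)
logits, value, e3b, r_int = policy(obs, e3b=e3b0)          # rollout path keeps both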
pufferlib/models.py: 18 changes (13 additions, 5 deletions)
@@ -71,7 +71,7 @@ def encode_observations(self, observations):
observations = observations.view(batch_size, -1)
return torch.relu(self.encoder(observations.float())), None

- def decode_actions(self, hidden, lookup, concat=True):
+ def decode_actions(self, hidden, lookup, concat=True, e3b=None):
'''Decodes a batch of hidden states into (multi)discrete actions.
Assumes no time dimension (handled by LSTM wrappers).'''
value = self.value_head(hidden)
@@ -86,8 +86,16 @@
batch = hidden.shape[0]
return probs, value

+ intrinsic_reward = None
+ if e3b is not None:
+ phi = hidden.detach()
+ intrinsic_reward = (phi.unsqueeze(1) @ e3b @ phi.unsqueeze(2))
+ e3b = 0.95*e3b - (phi.unsqueeze(2) @ phi.unsqueeze(1))/(1 + intrinsic_reward)
+ intrinsic_reward = intrinsic_reward.squeeze()
+ intrinsic_reward = 0.1*torch.clamp(intrinsic_reward, -1, 1)
+
actions = self.decoder(hidden)
- return actions, value
+ return actions, value, e3b, intrinsic_reward

class LSTMWrapper(nn.Module):
def __init__(self, env, policy, input_size=128, hidden_size=128, num_layers=1):
@@ -109,7 +117,7 @@ def __init__(self, env, policy, input_size=128, hidden_size=128, num_layers=1):
elif "weight" in name:
nn.init.orthogonal_(param, 1.0)

- def forward(self, x, state):
+ def forward(self, x, state, e3b=None):
x_shape, space_shape = x.shape, self.obs_shape
x_n, space_n = len(x_shape), len(space_shape)
if x_shape[-space_n:] != space_shape:
@@ -135,8 +143,8 @@ def forward(self, x, state):
hidden = hidden.transpose(0, 1)

hidden = hidden.reshape(B*TT, self.hidden_size)
- hidden, critic = self.policy.decode_actions(hidden, lookup)
- return hidden, critic, state
+ hidden, critic, e3b, intrinsic_reward = self.policy.decode_actions(hidden, lookup, e3b=e3b)
+ return hidden, critic, state, e3b, intrinsic_reward

class Convolutional(nn.Module):
def __init__(self, env, *args, framestack, flat_size,
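
The block added to decode_actions above is the entire E3B computation in this commit: the detached hidden state acts as the feature phi, the bonus is the quadratic form phi^T C^{-1} phi, and the inverse matrix gets a 0.95 decay plus a rank-one downdate instead of the paper's per-episode reset. A self-contained re-run of that arithmetic with illustrative sizes (batch 8, hidden 128), for readers who want to check the shapes:

import torch

B, d = 8, 128
phi = torch.randn(B, d)               # stands in for hidden.detach()
e3b = torch.eye(d).repeat(B, 1, 1)    # per-sample inverse matrices, identity at the start

# Elliptical bonus phi^T C^{-1} phi, computed as [B,1,d] @ [B,d,d] @ [B,d,1] -> [B,1,1]
bonus = phi.unsqueeze(1) @ e3b @ phi.unsqueeze(2)

# Decayed rank-one update of the inverse matrix (the "decayed corr matrix")
e3b = 0.95*e3b - (phi.unsqueeze(2) @ phi.unsqueeze(1))/(1 + bonus)

# Clamp and scale before the bonus is added to the environment reward
intrinsic_reward = 0.1*torch.clamp(bonus.squeeze(), -1, 1)
print(intrinsic_reward.shape)         # torch.Size([8])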
