diff --git a/Cargo.lock b/Cargo.lock index 140ac24..736a238 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -706,6 +706,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -838,6 +847,16 @@ dependencies = [ "cc", ] +[[package]] +name = "idna" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "2.2.3" @@ -1895,6 +1914,21 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +[[package]] +name = "ring" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -1914,6 +1948,37 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8" + +[[package]] +name = "rustls-webpki" +version = "0.102.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.14" @@ -2074,6 +2139,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + [[package]] name = "sqlparser" version = "0.39.0" @@ -2238,6 +2309,21 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "torch-sys" version = "0.15.0" @@ -2246,6 +2332,9 @@ dependencies = [ "anyhow", "cc", "libc", + "serde", + "serde_json", + "ureq", "zip", ] @@ -2255,12 +2344,27 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-normalization" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] + [[package]] name = "unicode-reverse" version = "1.0.8" @@ -2282,6 +2386,42 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35" +dependencies = [ + "base64", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "rustls-webpki", + "serde", + "serde_json", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "utf8parse" version = "0.2.1" @@ -2381,6 +2521,15 @@ version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" +[[package]] +name = "webpki-roots" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "4.4.2" @@ -2592,6 +2741,12 @@ dependencies = [ "syn 2.0.48", ] +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 316506b..fcaf7e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ rand = "0.8.5" rand_distr = "0.4.3" serde = "1.0.196" serde_json = "1.0.113" -tch = { git = "https://github.com/Jark5455/tch-rs" } +tch = { git = "https://github.com/Jark5455/tch-rs", features = ["download-libtorch"] } clap = { version = "4.5.0", features = ["derive"] } polars = { version = "0.38.1", features = ["cross_join", "cum_agg", "json", "lazy", "ndarray", "regex", "strings"] } diff --git a/src/main.rs b/src/main.rs index aa0fc11..fe2e621 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,31 +12,49 @@ mod tests; mod viewer; mod wrappers; -use crate::environment::{Environment, Terminate}; +use crate::environment::{Environment, MujocoEnvironment, Terminate}; use crate::halfcheetahenv::HalfCheetahEnv; use crate::replay_buffer::ReplayBuffer; +use crate::stockenv::StockEnv; use crate::td3::TD3; use crate::viewer::Viewer; lazy_static::lazy_static! { - static ref device: std::sync::Arc = std::sync::Arc::new(tch::Device::cuda_if_available()); + static ref device: std::sync::Arc = std::sync::Arc::new(tch::Device::Cpu); } #[derive(clap::Parser)] struct Args { #[arg(long)] - max_timesteps: Option, - #[arg(long)] - start_timesteps: Option, - #[arg(long)] - expl_noise: Option, - #[arg(long)] - eval_freq: Option, - #[arg(long)] - save_policy: Option, - #[arg(long)] - load_td3: Option, + env: String, + #[command(subcommand)] + command: Commands, +} + +#[derive(clap::Subcommand)] +enum Commands { + Train { + #[arg(long)] + actor_opt: String, + #[arg(long)] + critic_opt: String, + #[arg(long)] + max_timesteps: Option, + #[arg(long)] + start_timesteps: Option, + #[arg(long)] + expl_noise: Option, + #[arg(long)] + eval_freq: Option, + #[arg(long)] + save_policy: Option, + }, + + Run { + #[arg(long)] + savefile: String, + }, } fn eval_td3(policy: &TD3, env: &mut Box, eval_episodes: Option) -> f64 { @@ -58,14 +76,16 @@ fn eval_td3(policy: &TD3, env: &mut Box, eval_episodes: Option< } fn run_td3( + env: &str, + filename: &str, expl_noise: f64, max_timesteps: u32, start_timesteps: u32, eval_freq: u32, save_policy: bool, + actor_opt: &str, + critic_opt: &str, ) { - let filename = "td3_halfcheetah"; - if !std::path::Path::new("./results").exists() { std::fs::create_dir_all("./results").expect("Failed to create results directory"); } @@ -74,27 +94,39 @@ fn run_td3( std::fs::create_dir_all("./models").expect("Failed to create models directory"); } - /* + let envs: (Box, Box) = match env { + "halfcheetah" => { + let train_env = Box::new(HalfCheetahEnv::new( + None, None, None, None, None, None, None, + )); - let end = polars::export::chrono::Utc::now() - .date_naive() - .and_hms_micro_opt(0, 0, 0, 0) - .unwrap(); - let _start = end - polars::export::chrono::Duration::days(15); + let eval_env = Box::new(HalfCheetahEnv::new( + None, None, None, None, None, None, None, + )); - // let ref_env = HalfCheetahEnv::new(None, None, None, None, None, None, None); + (train_env, eval_env) + } - // let mut train_env: Box = Box::new(ref_env.clone()); - // let mut eval_env: Box = Box::new(ref_env.clone()); + "stockenv" => { + let end = polars::export::chrono::Utc::now() + .date_naive() + .and_hms_micro_opt(0, 0, 0, 0) + .unwrap(); + let start = end - polars::export::chrono::Duration::days(15); - */ + let train_env = Box::new(StockEnv::new(start, end)); + let eval_env = train_env.clone(); - let mut train_env: Box = Box::new(HalfCheetahEnv::new( - None, None, None, None, None, None, None, - )); - let mut eval_env: Box = Box::new(HalfCheetahEnv::new( - None, None, None, None, None, None, None, - )); + (train_env, eval_env) + } + + &_ => { + panic!("Invalid Environment Selection") + } + }; + + let mut train_env = envs.0; + let mut eval_env = envs.1; let state_dim = train_env.observation_spec().shape; let action_dim = train_env.action_spec().shape; @@ -104,8 +136,8 @@ fn run_td3( state_dim as i64, action_dim as i64, max_action, - None, - None, + actor_opt, + critic_opt, None, None, None, @@ -234,27 +266,60 @@ fn load_td3(filename: String) -> TD3 { } fn main() { - let args = ::parse(); + println!("Cuda Enabled: {}", device.is_cuda()); - let expl_noise = args.expl_noise.unwrap_or(0.1); - let max_timesteps = args.max_timesteps.unwrap_or(1000000); - let start_timesteps = args.start_timesteps.unwrap_or(25000); - let eval_freq = args.eval_freq.unwrap_or(5000); - let save_policy = args.save_policy.unwrap_or(false); - - if args.load_td3.is_some() { - let td3 = load_td3(args.load_td3.unwrap()); - let env = HalfCheetahEnv::new(None, None, None, None, None, None, None); - let mut viewer = Viewer::new(Box::new(env), td3, None, None); + let args = ::parse(); - viewer.render(); - } else { - run_td3( + match args.command { + Commands::Train { + actor_opt, + critic_opt, expl_noise, max_timesteps, start_timesteps, eval_freq, save_policy, - ); + } => { + let expl_noise = expl_noise.unwrap_or(0.1); + let max_timesteps = max_timesteps.unwrap_or(100000); + let start_timesteps = start_timesteps.unwrap_or(5000); + let eval_freq = eval_freq.unwrap_or(5000); + let save_policy = save_policy.unwrap_or(false); + + let filename = format!( + "td3_{}_{}_{}", + args.env, + actor_opt.to_lowercase(), + critic_opt.to_lowercase() + ); + + run_td3( + args.env.as_str(), + filename.as_str(), + expl_noise, + max_timesteps, + start_timesteps, + eval_freq, + save_policy, + actor_opt.as_str(), + critic_opt.as_str(), + ); + } + + Commands::Run { savefile } => { + let td3 = load_td3(savefile); + + let env: Box = match args.env.as_str() { + "halfcheetah" => Box::new(HalfCheetahEnv::new( + None, None, None, None, None, None, None, + )), + &_ => { + panic!("Selected Environment is not renderable") + } + }; + + let mut viewer = Viewer::new(env, td3, None, None); + viewer.render(); + } } } diff --git a/src/optimizer/cmaes.rs b/src/optimizer/cmaes.rs index cac058e..2659476 100644 --- a/src/optimizer/cmaes.rs +++ b/src/optimizer/cmaes.rs @@ -8,26 +8,33 @@ use tch::IndexOp; pub struct CMAES { pub vs: RefVs, + + pub cc: f64, + pub cs: f64, + pub c1: f64, + pub cmu: f64, + + pub sigma: f64, pub xmean: tch::Tensor, - pub z: tch::Tensor, - pub s: tch::Tensor, + pub variation: tch::Tensor, + pub newpop: Vec>>, pub N: i64, - pub sigma: f64, pub lambda: i64, pub mu: i64, pub weights: tch::Tensor, pub mueff: f64, - pub cc: f64, - pub cs: f64, - pub c1: f64, - pub cmu: f64, pub damps: f64, - pub chiN: f64, pub B: tch::Tensor, pub D: tch::Tensor, + pub Dinv: tch::Tensor, pub C: tch::Tensor, + pub Cold: tch::Tensor, + pub invsqrtC: tch::Tensor, pub pc: tch::Tensor, pub ps: tch::Tensor, + + pub counteval: i64, + pub eigeneval: i64, pub gen: i64, } @@ -40,16 +47,18 @@ impl CMAES { let N = xmean.size()[0]; let lambda = popsize.unwrap_or(4 + (3f64 * (N as f64).ln()).floor() as i64); - let mu = lambda as f64 / 2f64; + let mu = lambda / 2; - let weights = tch::Tensor::from_slice(&[(mu + 0.5f64).log2()]).to_device(**device) - - tch::Tensor::linspace(1f64, mu, mu.floor() as i64, (tch::Kind::Float, **device)); - let weights = weights.copy() / weights.sum(Some(tch::Kind::Float)); + let mut weights = vec![0f64; mu as usize]; + + for i in 0..weights.len() { + weights[i] = (mu as f64 + 0.5f64).log2() - (i as f64 + 1f64).log2(); + } + + let weights = tch::Tensor::from_slice(weights.as_slice()); let weights = weights.copy() / weights.sum(Some(tch::Kind::Float)); let weights = weights.totype(tch::Kind::Float); - let mu = mu.floor() as i64; - let mut mueff = [0f64; 1]; (weights.sum(Some(tch::Kind::Float)).pow_(2) @@ -62,50 +71,64 @@ impl CMAES { let cc = (4f64 + mueff / N as f64) / (N as f64 + 4f64 + 2f64 * mueff / N as f64); let cs = (mueff + 2f64) / (N as f64 + mueff + 5f64); let c1 = 2f64 / ((N as f64 + 1.3f64).powi(2) + mueff); - let cmu = f64::min( - 1f64 - c1, - 2f64 * (mueff - 2f64 + 1f64 / mueff) / ((N as f64 + 2f64).powi(2) + mueff), - ); + let cmu = 2f64 * (mueff - 2f64 + 1f64 / mueff) / ((N as f64 + 2f64).powi(2) + mueff); - let damps = f64::min(0f64, ((mu as f64 - 1f64) / (N as f64 + 1f64)).sqrt() - 1f64) + cs; + let damps = + 1f64 + 2f64 * f64::max(0f64, ((mueff - 1f64) / (N as f64 + 1f64)).sqrt() - 1f64) + cs; - let chiN = (N as f64).sqrt() - * (1f64 - 1f64 / (4f64 * N as f64) + 1f64 / (21f64 * (N as f64).sqrt())); + let variation = tch::Tensor::zeros([N], (tch::Kind::Float, **device)); let B = tch::Tensor::eye(N, (tch::Kind::Float, **device)); - let D = tch::Tensor::eye(N, (tch::Kind::Float, **device)); - let C = tch::Tensor::matmul(&B.matmul(&D), &B.matmul(&D).t_()); + let D = tch::Tensor::from_slice(vec![10f32.powi(-6); N as usize].as_slice()) + .diag(0) + .to_device(**device); + let C = tch::Tensor::matmul(&D, &D); + let invsqrtC = tch::Tensor::from_slice(vec![10f32.powi(6); N as usize].as_slice()) + .diag(0) + .to_device(**device); - let z = tch::Tensor::randn([N, lambda], (tch::Kind::Float, **device)); - let s = xmean.view([-1, 1]) + sigma * B.matmul(&D.matmul(&z)); + let Cold = tch::Tensor::eye(N, (tch::Kind::Float, **device)); + let Dinv = tch::Tensor::eye(N, (tch::Kind::Float, **device)); let pc = tch::Tensor::zeros([N], (tch::Kind::Float, **device)); let ps = tch::Tensor::zeros([N], (tch::Kind::Float, **device)); + let newpop = + vec![std::rc::Rc::new(std::cell::RefCell::new(tch::Tensor::new())); lambda as usize]; + + let counteval = 0; + let eigeneval = 0; let gen = 0; Self { vs, - N, - z, - s, - xmean, + + cc, + cs, + c1, + cmu, + sigma, + xmean, + variation, + newpop, + N, lambda, mu, weights, mueff, - cc, - cs, - c1, - cmu, damps, - chiN, - pc, - ps, B, D, + Dinv, C, + Cold, + invsqrtC, + pc, + ps, + + counteval, + eigeneval, gen, } } @@ -113,39 +136,32 @@ impl CMAES { impl CMAES { fn vs_to_flattensor(vs: RefVs) -> tch::Tensor { - let flatlist: Vec = vs - .borrow() - .trainable_variables() + let binding = vs.borrow().variables(); + let mut names_sorted = binding.keys().collect::>(); + names_sorted.sort(); + + let flatlist = names_sorted .iter() - .map(|var| var.flatten(0, (var.dim() - 1) as i64)) - .collect(); + .map(|name| vs.borrow().variables().get(*name).unwrap().flatten(0, -1)) + .collect::>(); tch::Tensor::concat(&flatlist, 0) } fn flattensor_to_vs(layout: RefVs, tensor: tch::Tensor) -> tch::nn::VarStore { let newvs = tch::nn::VarStore::new(**device); - for (name, tensor) in &layout.borrow().variables_.lock().unwrap().named_variables { - newvs - .root() - .var( - &*name, - tensor.size().as_slice(), - tch::nn::init::Init::Const(0f64), - ) - .copy_(tensor); - } - - let mut startindex = 0; - for mut var in newvs.trainable_variables() { - let len = var.flatten(0, -1).size()[0]; + let binding = layout.borrow().variables(); + let mut names_sorted = binding.keys().collect::>(); + names_sorted.sort(); - let val = tensor - .i(startindex..startindex + len) - .unflatten(0, var.size()); - var.copy_(&val); + let mut start_index = 0; + for name in names_sorted { + let layoutvar = binding.get(name).unwrap(); + let len = layoutvar.flatten(0, -1).size()[0]; + let val = tensor.i(start_index..start_index+len).unflatten(0, layoutvar.size()); + newvs.root().var(name, layoutvar.size().as_slice(), tch::nn::init::Init::Const(0f64)).copy_(&val); - startindex = startindex + len; + start_index += len; } newvs @@ -154,19 +170,19 @@ impl CMAES { impl MilkshakeOptimizer for CMAES { fn ask(&mut self) -> Vec { - let mut z = tch::Tensor::randn([self.N, self.lambda], (tch::Kind::Float, **device)); - let mut s = self.xmean.view([-1, 1]) + self.sigma * self.B.matmul(&self.D.matmul(&z)); - self.z = z.t_(); - self.s = s.t_(); - - let candidates = tch::Tensor::unbind(&self.s, 0); + for i in 0..self.lambda { + let noise = tch::Tensor::randn([self.N], (tch::Kind::Float, **device)); + self.newpop[i as usize] = std::rc::Rc::new(std::cell::RefCell::new( + &self.xmean + (self.sigma * &self.B * &self.D).mv(&noise), + )); + } let mut res = vec![]; - for candidate in candidates { + for candidate in &self.newpop { res.push(std::rc::Rc::new(std::cell::RefCell::new( - Self::flattensor_to_vs(self.vs.clone(), candidate), + Self::flattensor_to_vs(self.vs.clone(), candidate.borrow().copy()), ))); } @@ -174,6 +190,75 @@ impl MilkshakeOptimizer for CMAES { } fn tell(&mut self, solutions: Vec, losses: Vec) { + self.counteval += self.lambda; + + let fitvals = tch::Tensor::stack(losses.as_slice(), 0).sort(0, false); + + let arIndexLocal = fitvals.1.copy().to_device(tch::Device::Cpu); + + let arindex = unsafe { + std::slice::from_raw_parts( + arIndexLocal.data_ptr() as *const i64, + arIndexLocal.size()[0] as usize, + ) + }; + + let elite_indices = &arindex[0..self.mu as usize]; + + let elite_solutions: Vec = elite_indices + .iter() + .map(|i| self.newpop[*i as usize].borrow().copy()) + .collect(); + + let meanold = self.xmean.copy(); + + self.xmean = elite_solutions[0].copy() * self.weights.get(0); + + for i in 1..self.mu { + self.xmean += elite_solutions[i as usize].copy() * self.weights.get(i); + } + + self.weights.print(); + + let zscore = (&self.xmean - &meanold) / self.sigma; + + self.ps = (1f64 - self.cs) * &self.ps + (self.cs * (2f64 - self.cs) * self.mueff).sqrt() * &self.invsqrtC.mv(&zscore); + + let correlation = self.ps.norm().pow_tensor_scalar(2) / self.N / (1f64 - (1f64 - self.cs).powf(2f64 * self.counteval as f64 / self.lambda as f64)); + let hsig = unsafe { *(correlation.totype(tch::Kind::Float).to_device(tch::Device::Cpu).data_ptr() as *const f32) } < (2f64 + 4f64 / (self.N as f64 + 1f64)) as f32; + let hsig = match hsig { + true => { 1f64 } + false => { 0f64 } + }; + + self.pc = (1f64 - self.cc) * &self.pc + hsig * (self.cc * (2. - self.cc) * self.mueff).sqrt() * &zscore; + + self.Cold = self.C.copy(); + self.C = (elite_solutions[0].copy() - &meanold) * (elite_solutions[0].copy() - &meanold).t_() * self.weights.get(0); + + for i in 1..self.mu { + self.C += (elite_solutions[0].copy() - &meanold) * (elite_solutions[0].copy() - &meanold).t_() * self.weights.get(i); + } + + self.C /= self.sigma.powi(2); + self.C = (1f64 - self.c1 - self.cmu) * &self.Cold + self.cmu * &self.C + self.c1 * ((&self.pc * &self.pc.copy().t_()) + (1f64 - hsig) * self.cc * (2f64 - self.cc) * &self.Cold); + + if (self.counteval - self.eigeneval) as f64 > self.lambda as f64 / (self.c1 + self.cmu) / self.N as f64 / 10f64 { + self.eigeneval = self.counteval; + + let db = self.C.linalg_eigh("L"); + + db.0.print(); + db.1.print(); + + self.D = db.0.sqrt().diag_embed(0, -2, -1); + self.B = db.1.copy(); + self.Dinv = self.D.pow_(-1f64); + self.invsqrtC = &self.B * &self.Dinv * self.B.copy().t_(); + } + + /* + let fitvals = tch::Tensor::stack(losses.as_slice(), 0).sort(0, false); let arIndexLocal = fitvals.1.copy().to_device(tch::Device::Cpu); @@ -249,20 +334,34 @@ impl MilkshakeOptimizer for CMAES { self.sigma = self.sigma * ((self.cs / self.damps) * (correlation - 1.0)).exp(); + println!("STEP DIRECTION ======================================"); + self.pc.print(); + self.pc = (1f64 - self.cc) * &self.pc + hsig * (self.cc * (2f64 - self.cc) * self.mueff).sqrt() * self.B.matmul(&self.D).matmul(&zmean); + println!("STEP DIRECTION ======================================"); + self.pc.print(); + let pc_cov = self.pc.unsqueeze(1).matmul(&self.pc.unsqueeze(1).t_()); + + // let pc_cov = self.pc.unsqueeze(1).matmul(&self.pc.unsqueeze(1).t_()); + + // pc_cov.print(); + let pc_cov = pc_cov + (1f64 - hsig) * self.cc * (2f64 - self.cc) * &self.C; + // pc_cov.print(); + let bdz = self.B.matmul(&self.D).matmul(&z.copy().t_()); let cmu_cov = tch::Tensor::matmul(&bdz, &self.weights.diag_embed(0, -2, -1)); let cmu_cov = cmu_cov.matmul(&bdz.copy().t_()); - self.C = (1.0 - self.c1 - self.cmu) * &self.C + (self.c1 * pc_cov) + (self.cmu * cmu_cov); + // (&self.C + (self.c1 * &pc_cov)).print(); + self.C = (1.0 - self.c1 - self.cmu) * self.C.copy() + (self.c1 * pc_cov) + (self.cmu * cmu_cov); let eig = self.C.linalg_eigh("L"); self.D = eig.0; @@ -270,6 +369,10 @@ impl MilkshakeOptimizer for CMAES { self.D = self.D.sqrt().diag_embed(0, -2, -1); self.gen += 1; + + println!("{}", self.gen); + + */ } fn result(&mut self) -> RefVs { diff --git a/src/td3.rs b/src/td3.rs index 9677b37..f90b612 100644 --- a/src/td3.rs +++ b/src/td3.rs @@ -220,6 +220,8 @@ impl TD3 { state_dim: i64, action_dim: i64, max_action: f64, + actor_opt: &str, + critic_opt: &str, actor_shape: Option>, q1_shape: Option>, q2_shape: Option>, @@ -228,12 +230,7 @@ impl TD3 { policy_noise: Option, noise_clip: Option, policy_freq: Option, - actor_opt: Option<&str>, - critic_opt: Option<&str>, ) -> anyhow::Result { - let actor_opt_str = actor_opt.unwrap_or("ADAM"); - let critic_opt_str = critic_opt.unwrap_or("ADAM"); - let actor_shape = actor_shape.unwrap_or(vec![256]); let q1_shape = q1_shape.unwrap_or(vec![256]); let q2_shape = q2_shape.unwrap_or(vec![256]); @@ -250,7 +247,7 @@ impl TD3 { let critic = Critic::new(state_dim, action_dim, q1_shape.clone(), q2_shape.clone()); let critic_target = Critic::new(state_dim, action_dim, q1_shape.clone(), q2_shape.clone()); - let actor_opt: anyhow::Result> = match actor_opt_str { + let actor_opt: anyhow::Result> = match actor_opt { "ADAM" => Ok(Box::new(ADAM::new(3e-4, actor.vs.clone()))), "CMAES" => Ok(Box::new(CMAES::new(actor.vs.clone(), None, None))), &_ => { @@ -258,7 +255,7 @@ impl TD3 { } }; - let critic_opt: anyhow::Result> = match critic_opt_str { + let critic_opt: anyhow::Result> = match critic_opt { "ADAM" => Ok(Box::new(ADAM::new(3e-4, critic.vs.clone()))), "CMAES" => Ok(Box::new(CMAES::new(critic.vs.clone(), None, None))), &_ => {