Skip to content
Draft
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions inverse_solver.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"import os, sys\n",
"\n",
"sys.path.append(os.getcwd())\n",
"from diffmpm.material import SimpleMaterial,LinearElastic\n",
"from diffmpm.particle import Particles\n",
"from diffmpm.element import Quadrilateral4Node\n",
"from diffmpm.constraint import Constraint\n",
"from diffmpm.mesh import Mesh2D\n",
"from diffmpm.solver import MPMExplicit\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"\n",
"mesh_config = {}\n",
"density = 1\n",
"# poisson_ratio = 0\n",
"youngs_modulus = 1000\n",
"material = LinearElastic(\n",
" {\n",
" \"id\":0,\n",
" \"youngs_modulus\": youngs_modulus,\n",
" \"density\": density,\n",
" \"poisson_ratio\": 0.0,\n",
" }\n",
")\n",
"particle_loc = jnp.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5], [0.5, 0.5]]).reshape(\n",
" 4, 1, 2\n",
")\n",
"particles = Particles(particle_loc, material, jnp.zeros(particle_loc.shape[0],dtype=jnp.int32))\n",
"particles.velocity=particles.velocity.at[:].set(0.0)\n",
"constraints = [(0, Constraint(1, 0.0))]\n",
"external_loading = jnp.array([0.0, -9.8]).reshape(1,2)\n",
"element = Quadrilateral4Node([1, 1], 1, [1,1], constraints)\n",
"mesh_config[\"particles\"] = [particles]\n",
"mesh_config[\"elements\"] = element\n",
"mesh_config[\"particle_surface_traction\"] = []\n",
"mesh = Mesh2D(mesh_config)\n",
"solver = MPMExplicit(mesh, 0.01,sim_steps=10)\n",
"\n",
"real_ans = solver.solve_jit(external_loading)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from jax import jit\n",
"\n",
"def compute_loss(E,solver,target_stress):\n",
" material_props=solver.mesh.particles[0].material.properties\n",
" material_props[\"youngs_modulus\"]=E\n",
" solver.mesh.particles[0].material=LinearElastic(material_props)\n",
" external_loading_local=jnp.array([0.0, -9.8]).reshape(1,2)\n",
" solver.mesh.particles[0].velocity = mesh.particles[0].velocity.at[:].set(0.0)\n",
" result = solver.solve_jit(external_loading_local)\n",
" stress = result[\"stress\"]\n",
" loss = jnp.linalg.norm(stress - target_stress)\n",
" return loss"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"E: 1000.3665161132812: 100%|██████████| 500/500 [00:11<00:00, 43.43it/s]\n"
]
}
],
"source": [
"import optax\n",
"from tqdm import tqdm\n",
"from jax import jit, value_and_grad\n",
"\n",
"def optax_adam(params,niter,mpm,target_vel):\n",
" start_alpha=0.1\n",
" optimizer=optax.adam(start_alpha)\n",
" opt_state=optimizer.init(params)\n",
" param_list=[]\n",
" loss_list=[]\n",
" t=tqdm(range(niter),desc=f\"E: {params}\")\n",
" for _ in t:\n",
" lo,grads=value_and_grad(compute_loss)(params,mpm,target_vel)\n",
" updates,opt_state=optimizer.update(grads,opt_state)\n",
" params=optax.apply_updates(params,updates)\n",
" t.set_description(f\"E: {params}\")\n",
" param_list.append(params)\n",
" loss_list.append(lo)\n",
" return param_list,loss_list\n",
"params=1050.0\n",
"parameter_list,loss_list=optax_adam(params,500,solver,real_ans[\"stress\"])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"#Trying via Bayesian Optimization\n",
"target_stress=real_ans[\"stress\"]\n",
"@jit\n",
"def fun(E,solver=solver,target_stress=target_stress):\n",
" material_props=solver.mesh.particles[0].material.properties\n",
" material_props[\"youngs_modulus\"]=E\n",
" solver.mesh.particles[0].material=LinearElastic(material_props)\n",
" external_loading_local=jnp.array([0.0, -9.8]).reshape(1,2)\n",
" # solver.mesh.particles[0].velocity = mesh.particles[0].velocity.at[:].set(0.0)\n",
" result = solver.solve_jit(external_loading_local)\n",
" stress = result[\"stress\"]\n",
" loss = jnp.linalg.norm(stress - target_stress)\n",
" return -loss"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"from bayes_opt import BayesianOptimization\n",
"\n",
"pbounds = {\"E\": (800, 1500)}\n",
"optimizer = BayesianOptimization(f=fun, pbounds=pbounds, random_state=1, verbose=2)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| iter | target | E |\n",
"-------------------------------------\n",
"| \u001b[0m1 \u001b[0m | \u001b[0m-1.017 \u001b[0m | \u001b[0m1.092e+03\u001b[0m |\n",
"| \u001b[0m2 \u001b[0m | \u001b[0m-3.452 \u001b[0m | \u001b[0m1.304e+03\u001b[0m |\n",
"| \u001b[0m3 \u001b[0m | \u001b[0m-1.895 \u001b[0m | \u001b[0m800.1 \u001b[0m |\n",
"| \u001b[0m4 \u001b[0m | \u001b[0m-1.031 \u001b[0m | \u001b[0m1.093e+03\u001b[0m |\n",
"| \u001b[95m5 \u001b[0m | \u001b[95m-0.4102 \u001b[0m | \u001b[95m961.1 \u001b[0m |\n",
"| \u001b[95m6 \u001b[0m | \u001b[95m-0.05521 \u001b[0m | \u001b[95m1.005e+03\u001b[0m |\n",
"| \u001b[0m7 \u001b[0m | \u001b[0m-5.561 \u001b[0m | \u001b[0m1.5e+03 \u001b[0m |\n",
"| \u001b[0m8 \u001b[0m | \u001b[0m-0.2243 \u001b[0m | \u001b[0m1.021e+03\u001b[0m |\n",
"=====================================\n"
]
}
],
"source": [
"optimizer.maximize(init_points=3,n_iter=5)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'target': -0.0552130751311779, 'params': {'E': 1005.1373320945569}}\n"
]
}
],
"source": [
"print(optimizer.max)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "optaximpo",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}