Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions scripts/mongodbintegrationmvp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Generated by Honegumi (https://arxiv.org/abs/2502.06815)
# pip install ax-platform==0.4.3 numpy pymongo
import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties
from pymongo import MongoClient # Added import for MongoDB


obj1_name = "branin"


def branin(x1, x2):
y = float(
(x2 - 5.1 / (4 * np.pi**2) * x1**2 + 5.0 / np.pi * x1 - 6.0) ** 2
+ 10 * (1 - 1.0 / (8 * np.pi)) * np.cos(x1)
+ 10
)

return y


# Connect to MongoDB
tmongo_client = MongoClient("mongodb://localhost:27017/")
db = tmongo_client["ax_db"]
experiments_col = db["experiments"]

# Experiment configuration
parameters = [
{"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "x2", "type": "range", "bounds": [0.0, 10.0]},
]
objectives = {obj1_name: ObjectiveProperties(minimize=True)}

# Load existing experiment state or initialize new
record = experiments_col.find_one({"experiment_name": obj1_name})
if record:
ax_client = AxClient()
ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives)
saved_trials = record.get("trials", [])
# Replay saved trials
for t in saved_trials:
ax_client.complete_trial(trial_index=t["trial_index"], raw_data=t["raw_data"])
start_i = len(saved_trials)
else:
ax_client = AxClient()
ax_client.create_experiment(name=obj1_name, parameters=parameters, objectives=objectives)
start_i = 0
experiments_col.insert_one({"experiment_name": obj1_name, "trials": []})

for i in range(start_i, 19):

parameterization, trial_index = ax_client.get_next_trial()

# extract parameters
x1 = parameterization["x1"]
x2 = parameterization["x2"]

results = branin(x1, x2)
ax_client.complete_trial(trial_index=trial_index, raw_data=results)
# Save trial results to MongoDB
experiments_col.update_one(
{"experiment_name": obj1_name},
{"$push": {"trials": {"trial_index": trial_index, "raw_data": results}}},
)

best_parameters, metrics = ax_client.get_best_parameters()
Loading