Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding iter0_post_solver_creation callback to extensions object #429

Merged
merged 3 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mpisppy/extensions/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,18 @@ def post_solve_loop(self):
pass

def pre_iter0(self):
''' Method called at the end of PH_Prep().
When this method is called, all scenarios have been created, and
''' When this method is called, all scenarios have been created, and
the dual/prox terms have been attached to the objective, but the
solvers have not yet been created.
'''
pass

def iter0_post_solver_creation(self):
''' When this method is called, PH iteration 0 has been initiated and
all solver objects have been created.
'''
pass

def post_iter0(self):
''' Method called after the first PH iteration.
When this method is called, one call to solve_loop() has been
Expand Down
9 changes: 6 additions & 3 deletions mpisppy/phbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,7 @@ def PH_Prep(

Note:
This function constructs an Extension object if one was specified
at the time the PH object was created. It also calls the
`pre_iter0` method of the Extension object.
at the time the PH object was created.
"""

self.attach_Ws_and_prox()
Expand Down Expand Up @@ -832,9 +831,10 @@ def Iter0(self):
stochastic program with the nonanticipativity constraints
removed.
"""

if (self.extensions is not None):
self.extobject.pre_iter0()

verbose = self.options["verbose"]
dprogress = self.options["display_progress"]
dtiming = self.options["display_timing"]
Expand All @@ -853,6 +853,9 @@ def _vb(msg):
global_toc("Creating solvers")
self._create_solvers()

if (self.extensions is not None):
self.extobject.iter0_post_solver_creation()

teeme = ("tee-rank0-solves" in self.options
and self.options['tee-rank0-solves']
and self.cylinder_rank == 0
Expand Down
Loading