diff --git a/nvflare/private/fed/simulator/simulator_server.py b/nvflare/private/fed/simulator/simulator_server.py index ff0c3de389..69d1f16df1 100644 --- a/nvflare/private/fed/simulator/simulator_server.py +++ b/nvflare/private/fed/simulator/simulator_server.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from typing import Dict, List, Optional from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ServerCommandKey +from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ServerCommandKey, SiteType from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.workspace import Workspace from nvflare.fuel.f3.message import Message from nvflare.private.fed.server.run_manager import RunManager from nvflare.private.fed.server.server_state import HotState @@ -144,6 +145,20 @@ def _create_server_engine(self, args, snapshot_persistor): def deploy(self, args, grpc_args=None, secure_train=False): super(FederatedServer, self).deploy(args, grpc_args, secure_train) + os.makedirs(os.path.join(args.workspace, "local"), exist_ok=True) + os.makedirs(os.path.join(args.workspace, "startup"), exist_ok=True) + workspace = Workspace(args.workspace, "server", args.config_folder) + run_manager = RunManager( + server_name=SiteType.SERVER, + engine=self.engine, + job_id="", + workspace=workspace, + components={}, + handlers=[], + ) + self.engine.set_run_manager(run_manager) + self.engine.initialize_comm(self.cell) + self._register_cellnet_cbs() def stop_training(self): diff --git a/tests/unit_test/fuel/f3/streaming/streaming_test.py b/tests/unit_test/fuel/f3/streaming/streaming_test.py index aa8273d39a..890ce4f47a 100644 --- a/tests/unit_test/fuel/f3/streaming/streaming_test.py +++ b/tests/unit_test/fuel/f3/streaming/streaming_test.py @@ -32,15 +32,15 @@ def __init__(self): class TestStreamCell: - @pytest.fixture(scope="module") + @pytest.fixture(scope="session") def port(self): return get_open_ports(1)[0] - @pytest.fixture(scope="module") + @pytest.fixture(scope="session") def state(self): return State() - @pytest.fixture(scope="module") + @pytest.fixture(scope="session") def server_cell(self, port, state): listening_url = f"tcp://localhost:{port}" cell = CoreCell(RX_CELL, listening_url, secure=False, credentials={}) @@ -51,7 +51,7 @@ def server_cell(self, port, state): yield stream_cell cell.stop() - @pytest.fixture(scope="module") + @pytest.fixture(scope="session") def client_cell(self, port, state): connect_url = f"tcp://localhost:{port}" cell = CoreCell(TX_CELL, connect_url, secure=False, credentials={}) diff --git a/tests/unit_test/private/fed/app/deployer/simulator_deployer_test.py b/tests/unit_test/private/fed/app/deployer/simulator_deployer_test.py index 1fb7e6d204..55240a4695 100644 --- a/tests/unit_test/private/fed/app/deployer/simulator_deployer_test.py +++ b/tests/unit_test/private/fed/app/deployer/simulator_deployer_test.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse +import os import shutil import tempfile import unittest @@ -26,6 +27,8 @@ from nvflare.private.fed.app.deployer.simulator_deployer import SimulatorDeployer from nvflare.private.fed.app.simulator.simulator import define_simulator_parser from nvflare.private.fed.client.fed_client import FederatedClient +from nvflare.private.fed.server.run_manager import RunManager +from nvflare.private.fed.simulator.simulator_server import SimulatorServer # from nvflare.private.fed.simulator.simulator_server import SimulatorServer from nvflare.security.security import EmptyAuthorizer @@ -49,17 +52,6 @@ def _create_parser(self): return parser - # Disable this test temporarily since it conflicts with other tests. - # def test_create_server(self): - # with patch("nvflare.private.fed.app.utils.FedAdminServer") as mock_admin: - # workspace = tempfile.mkdtemp() - # parser = self._create_parser() - # args = parser.parse_args(["job_folder", "-w" + workspace, "-n 2", "-t 1"]) - # _, server = self.deployer.create_fl_server(args) - # assert isinstance(server, SimulatorServer) - # server.cell.stop() - # shutil.rmtree(workspace) - @patch("nvflare.private.fed.client.fed_client.FederatedClient.register") # @patch("nvflare.private.fed.app.deployer.simulator_deployer.FederatedClient.start_heartbeat") # @patch("nvflare.private.fed.app.deployer.simulator_deployer.FedAdminAgent") @@ -71,3 +63,22 @@ def test_create_client(self, mock_register): assert isinstance(client, FederatedClient) client.cell.stop() shutil.rmtree(workspace) + + @patch("nvflare.private.fed.server.admin.FedAdminServer.start") + @patch("nvflare.private.fed.simulator.simulator_server.SimulatorServer._register_cellnet_cbs") + @patch("nvflare.private.fed.server.fed_server.Cell") + def test_create_server(self, mock_admin, mock_simulator_server, mock_cell): + workspace = tempfile.mkdtemp() + os.mkdir(os.path.join(workspace, "local")) + os.mkdir(os.path.join(workspace, "startup")) + parser = self._create_parser() + args = parser.parse_args(["job_folder", "-w" + workspace, "-n 2", "-t 1"]) + args.config_folder = "config" + _, server = self.deployer.create_fl_server(args) + + assert isinstance(server, SimulatorServer) + assert isinstance(server.engine.run_manager, RunManager) + + server.cell.stop() + server.close() + shutil.rmtree(workspace)