Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework:skip) Test Context and Simulation plugin #4437

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9095"
EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
SIMULATIONIO_API_DEFAULT_ADDRESS = "0.0.0.0:9096"

# Constants for ping
PING_DEFAULT_INTERVAL = 30
Expand Down
247 changes: 140 additions & 107 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ISOLATION_MODE_SUBPROCESS,
MISSING_EXTRA_REST,
SERVERAPPIO_API_DEFAULT_ADDRESS,
SIMULATIONIO_API_DEFAULT_ADDRESS,
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
Expand All @@ -63,6 +64,7 @@
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
from flwr.superexec.app import load_executor
from flwr.superexec.exec_grpc import run_exec_api_grpc
from flwr.superexec.simulation import SimulationEngine

from .client_manager import ClientManager
from .history import History
Expand All @@ -79,6 +81,7 @@
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
from .superlink.linkstate import LinkStateFactory
from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc

DATABASE = ":flwr-in-memory-state:"
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
Expand Down Expand Up @@ -215,6 +218,7 @@ def run_superlink() -> None:
# Parse IP addresses
serverappio_address, _, _ = _format_address(args.serverappio_api_address)
exec_address, _, _ = _format_address(args.exec_api_address)
simulationio_address, _, _ = _format_address(args.simulationio_api_address)

# Obtain certificates
certificates = _try_obtain_certificates(args)
Expand All @@ -225,128 +229,148 @@ def run_superlink() -> None:
# Initialize FfsFactory
ffs_factory = FfsFactory(args.storage_dir)

# Start ServerAppIo API
serverappio_server: grpc.Server = run_serverappio_api_grpc(
address=serverappio_address,
# Start Exec API
executor = load_executor(args)
exec_server: grpc.Server = run_exec_api_grpc(
address=exec_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
executor=executor,
certificates=certificates,
config=parse_config_args(
[args.executor_config] if args.executor_config else args.executor_config
),
)
grpc_servers = [serverappio_server]
grpc_servers = [exec_server]

# Start Fleet API
bckg_threads = []
if not args.fleet_api_address:
if args.fleet_api_type in [
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_GRPC_ADAPTER,
]:
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS

fleet_address, host, port = _format_address(args.fleet_api_address)

num_workers = args.fleet_api_num_workers
if num_workers != 1:
log(
WARN,
"The Fleet API currently supports only 1 worker. "
"You have specified %d workers. "
"Support for multiple workers will be added in future releases. "
"Proceeding with a single worker.",
args.fleet_api_num_workers,
)
num_workers = 1
# Determine Exec plugin
# If simulation is used, don't start ServerAppIo and Fleet APIs
sim_exec = isinstance(executor, SimulationEngine)

if args.fleet_api_type == TRANSPORT_TYPE_REST:
if (
importlib.util.find_spec("requests")
and importlib.util.find_spec("starlette")
and importlib.util.find_spec("uvicorn")
) is None:
sys.exit(MISSING_EXTRA_REST)

_, ssl_certfile, ssl_keyfile = (
certificates if certificates is not None else (None, None, None)
)

fleet_thread = threading.Thread(
target=_run_fleet_api_rest,
args=(
host,
port,
ssl_keyfile,
ssl_certfile,
state_factory,
ffs_factory,
num_workers,
),
)
fleet_thread.start()
bckg_threads.append(fleet_thread)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
maybe_keys = _try_setup_node_authentication(args, certificates)
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
if maybe_keys is not None:
(
node_public_keys,
server_private_key,
server_public_key,
) = maybe_keys
state = state_factory.state()
state.store_node_public_keys(node_public_keys)
state.store_server_private_public_key(
private_key_to_bytes(server_private_key),
public_key_to_bytes(server_public_key),
)
log(
INFO,
"Node authentication enabled with %d known public keys",
len(node_public_keys),
)
interceptors = [AuthenticateServerInterceptor(state)]
bckg_threads = []

fleet_server = _run_fleet_api_grpc_rere(
address=fleet_address,
if sim_exec:
simulationio_server: grpc.Server = run_simulationio_api_grpc(
address=simulationio_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
interceptors=interceptors,
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
fleet_server = _run_fleet_api_grpc_adapter(
address=fleet_address,
grpc_servers.append(simulationio_server)

else:
# Start ServerAppIo API
serverappio_server: grpc.Server = run_serverappio_api_grpc(
address=serverappio_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
)
grpc_servers.append(fleet_server)
else:
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")

# Start Exec API
exec_server: grpc.Server = run_exec_api_grpc(
address=exec_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
executor=load_executor(args),
certificates=certificates,
config=parse_config_args(
[args.executor_config] if args.executor_config else args.executor_config
),
)
grpc_servers.append(exec_server)
grpc_servers.append(serverappio_server)

# Start Fleet API
if not args.fleet_api_address:
if args.fleet_api_type in [
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_GRPC_ADAPTER,
]:
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS

fleet_address, host, port = _format_address(args.fleet_api_address)

num_workers = args.fleet_api_num_workers
if num_workers != 1:
log(
WARN,
"The Fleet API currently supports only 1 worker. "
"You have specified %d workers. "
"Support for multiple workers will be added in future releases. "
"Proceeding with a single worker.",
args.fleet_api_num_workers,
)
num_workers = 1

if args.fleet_api_type == TRANSPORT_TYPE_REST:
if (
importlib.util.find_spec("requests")
and importlib.util.find_spec("starlette")
and importlib.util.find_spec("uvicorn")
) is None:
sys.exit(MISSING_EXTRA_REST)

_, ssl_certfile, ssl_keyfile = (
certificates if certificates is not None else (None, None, None)
)

if args.isolation == ISOLATION_MODE_SUBPROCESS:
# Scheduler thread
scheduler_th = threading.Thread(
target=_flwr_serverapp_scheduler,
args=(state_factory, args.serverappio_api_address, args.ssl_ca_certfile),
)
scheduler_th.start()
bckg_threads.append(scheduler_th)
fleet_thread = threading.Thread(
target=_run_fleet_api_rest,
args=(
host,
port,
ssl_keyfile,
ssl_certfile,
state_factory,
ffs_factory,
num_workers,
),
)
fleet_thread.start()
bckg_threads.append(fleet_thread)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
maybe_keys = _try_setup_node_authentication(args, certificates)
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
if maybe_keys is not None:
(
node_public_keys,
server_private_key,
server_public_key,
) = maybe_keys
state = state_factory.state()
state.store_node_public_keys(node_public_keys)
state.store_server_private_public_key(
private_key_to_bytes(server_private_key),
public_key_to_bytes(server_public_key),
)
log(
INFO,
"Node authentication enabled with %d known public keys",
len(node_public_keys),
)
interceptors = [AuthenticateServerInterceptor(state)]

fleet_server = _run_fleet_api_grpc_rere(
address=fleet_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
interceptors=interceptors,
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
fleet_server = _run_fleet_api_grpc_adapter(
address=fleet_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
)
grpc_servers.append(fleet_server)
else:
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")

if args.isolation == ISOLATION_MODE_SUBPROCESS:
# Scheduler thread
scheduler_th = threading.Thread(
target=_flwr_serverapp_scheduler,
args=(
state_factory,
args.serverappio_api_address,
args.ssl_ca_certfile,
),
)
scheduler_th.start()
bckg_threads.append(scheduler_th)

# Graceful shutdown
register_exit_handlers(
Expand All @@ -361,7 +385,7 @@ def run_superlink() -> None:
for thread in bckg_threads:
if not thread.is_alive():
sys.exit(1)
serverappio_server.wait_for_termination(timeout=1)
exec_server.wait_for_termination(timeout=1)


def _flwr_serverapp_scheduler(
Expand Down Expand Up @@ -657,6 +681,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
_add_args_serverappio_api(parser=parser)
_add_args_fleet_api(parser=parser)
_add_args_exec_api(parser=parser)
_add_args_simulationio_api(parser=parser)

return parser

Expand Down Expand Up @@ -790,3 +815,11 @@ def _add_args_exec_api(parser: argparse.ArgumentParser) -> None:
"For example:\n\n`--executor-config 'verbose=true "
'root-certificates="certificates/superlink-ca.crt"\'`',
)


def _add_args_simulationio_api(parser: argparse.ArgumentParser) -> None:
    """Add the SimulationIo API command line argument to ``parser``.

    Registers the ``--simulationio-api-address`` option, which falls back to
    ``SIMULATIONIO_API_DEFAULT_ADDRESS`` when it is not given on the command
    line.
    """
    help_text = (
        "SimulationIo API (gRPC) server address (IPv4, IPv6, or a domain name)."
    )
    parser.add_argument(
        "--simulationio-api-address",
        default=SIMULATIONIO_API_DEFAULT_ADDRESS,
        help=help_text,
    )
13 changes: 1 addition & 12 deletions src/py/flwr/server/serverapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,7 @@

from flwr.cli.config_utils import get_fab_metadata
from flwr.cli.install import install_from_fab
from flwr.common.config import (
get_flwr_dir,
get_fused_config_from_dir,
get_project_config,
get_project_dir,
)
from flwr.common.config import get_flwr_dir, get_project_config, get_project_dir
from flwr.common.constant import Status, SubStatus
from flwr.common.logger import (
log,
Expand Down Expand Up @@ -209,12 +204,6 @@ def run_serverapp( # pylint: disable=R0914, disable=W0212

# Obtain server app reference and the run config
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
server_app_run_config = get_fused_config_from_dir(
Path(app_path), run.override_config
)

# Update run_config in context
context.run_config = server_app_run_config

log(
DEBUG,
Expand Down
15 changes: 15 additions & 0 deletions src/py/flwr/server/superlink/simulation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower SimulationIo service."""
Loading
Loading