Skip to content

Commit

Permalink
Moved code reference class variables to RXTask class
Browse files Browse the repository at this point in the history
  • Loading branch information
nvidianz committed Sep 25, 2024
1 parent 4e13a91 commit 6fd2678
Showing 1 changed file with 30 additions and 22 deletions.
52 changes: 30 additions & 22 deletions nvflare/fuel/f3/streaming/byte_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,29 @@ def __init__(self, sid: int, origin: str, cell: CoreCell):
def __str__(self):
return f"Rx[SID:{self.sid} from {self.origin} for {self.channel}/{self.topic} Size: {self.size}]"

@classmethod
def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTask"]:

sid = message.get_header(StreamHeaderKey.STREAM_ID)
origin = message.get_header(MessageHeaderKey.ORIGIN)
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)

with cls.map_lock:
task = cls.rx_task_map.get(sid, None)
if not task:
if error:
log.warning(f"Received error for non-existing stream: SID {sid} from {origin}")
return None

task = RxTask(sid, origin, cell)
cls.rx_task_map[sid] = task
else:
if error:
task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
return None

return task

def read(self, size: int) -> BytesAlike:

count = 0
Expand All @@ -107,11 +130,12 @@ def read(self, size: int) -> BytesAlike:

count += 1

def process_chunk(self, seq: int, message: Message) -> bool:
def process_chunk(self, message: Message) -> bool:
"""Returns True if a new stream is created"""

new_stream = False
with self.lock:
seq = message.get_header(StreamHeaderKey.SEQUENCE)
if seq == 0:
if self.stream_future:
log.warning(f"{self} Received duplicate chunk 0, ignored")
Expand All @@ -120,7 +144,7 @@ def process_chunk(self, seq: int, message: Message) -> bool:
self._handle_new_stream(message)
new_stream = True

self._handle_incoming_data(message)
self._handle_incoming_data(seq, message)
return new_stream

def _handle_new_stream(self, message: Message):
Expand All @@ -132,10 +156,9 @@ def _handle_new_stream(self, message: Message):
self.stream_future = StreamFuture(self.sid, self.headers)
self.stream_future.set_size(self.size)

def _handle_incoming_data(self, message: Message):
def _handle_incoming_data(self, seq: int, message: Message):

data_type = message.get_header(StreamHeaderKey.DATA_TYPE)
seq = message.get_header(StreamHeaderKey.SEQUENCE)

last_chunk = data_type == StreamDataType.FINAL
if last_chunk:
Expand Down Expand Up @@ -310,26 +333,11 @@ def register_callback(self, channel: str, topic: str, stream_cb: Callable, *args

def _data_handler(self, message: Message):

sid = message.get_header(StreamHeaderKey.STREAM_ID)
origin = message.get_header(MessageHeaderKey.ORIGIN)
seq = message.get_header(StreamHeaderKey.SEQUENCE)
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)

with RxTask.map_lock:
task = RxTask.rx_task_map.get(sid, None)
if not task:
if error:
log.debug(f"Received error for non-existing stream: SID {sid} from {origin}")
return

task = RxTask(sid, origin, self.cell)
RxTask.rx_task_map[sid] = task

if error:
task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
task = RxTask.find_or_create_task(message, self.cell)
if not task:
return

new_stream = task.process_chunk(seq, message)
new_stream = task.process_chunk(message)
if new_stream:
# Invoke callback
callback = self.registry.find(task.channel, task.topic)
Expand Down

0 comments on commit 6fd2678

Please sign in to comment.