diff --git a/src/fees/compute_fees.py b/src/fees/compute_fees.py index aa28105..be35e18 100644 --- a/src/fees/compute_fees.py +++ b/src/fees/compute_fees.py @@ -48,11 +48,11 @@ def __init__( self.buy_token_clearing_price = buy_token_clearing_price self.fee_policies = fee_policies self.partner_fee_recipient = partner_fee_recipient # if there is no partner, then its value is set to the null address - self.network_fee: int = -1 - self.total_protocol_fee: int = -1 - self.partner_fee: int = -1 - self.compute_all_fees() + total_protocol_fee, partner_fee, network_fee = self.compute_all_fees() + self.total_protocol_fee = total_protocol_fee + self.partner_fee = partner_fee + self.network_fee = network_fee return def volume(self) -> int: @@ -84,30 +84,31 @@ def surplus(self) -> int: return current_limit_sell_amount - self.sell_amount raise ValueError(f"Order kind {self.kind} is invalid.") - def compute_all_fees(self) -> None: + def compute_all_fees(self) -> tuple[int, int, int]: raw_trade = deepcopy(self) - self.total_protocol_fee = 0 - self.partner_fee = 0 + total_protocol_fee = 0 + partner_fee = 0 + network_fee = 0 if self.fee_policies: for i, fee_policy in enumerate(reversed(self.fee_policies)): raw_trade = fee_policy.reverse_protocol_fee(raw_trade) ## we assume that partner fee is the last to be applied if i == 0 and self.partner_fee_recipient is not NULL_ADDRESS: - self.partner_fee = raw_trade.surplus() - self.surplus() - self.total_protocol_fee = raw_trade.surplus() - self.surplus() + partner_fee = raw_trade.surplus() - self.surplus() + total_protocol_fee = raw_trade.surplus() - self.surplus() surplus_fee = self.compute_surplus_fee() # in the surplus token - network_fee_temp = surplus_fee - self.total_protocol_fee + network_fee_in_surplus_token = surplus_fee - self.total_protocol_fee if self.kind == "sell": - self.network_fee = int( - network_fee_temp + network_fee = int( + network_fee_in_surplus_token * Fraction( self.buy_token_clearing_price, self.sell_token_clearing_price ) ) else: - self.network_fee = network_fee_temp - return + network_fee = network_fee_in_surplus_token + return total_protocol_fee, partner_fee, network_fee def surplus_token(self) -> HexBytes: """Returns the surplus token""" @@ -480,12 +481,17 @@ def parse_fee_policies( # computing fees def compute_fee_imbalances( settlement_data: SettlementData, -) -> tuple[dict[str, tuple[str, int]], dict[str, tuple[str, int]]]: +) -> tuple[ + dict[str, tuple[str, int]], + dict[str, tuple[str, int]], + dict[str, tuple[str, int, str]], +]: protocol_fees: dict[str, tuple[str, int]] = {} network_fees: dict[str, tuple[str, int]] = {} + partner_fees: dict[str, tuple[str, int, str]] = {} for trade in settlement_data.trades: # protocol fees - protocol_fee_amount = trade.total_protocol_fee + protocol_fee_amount = trade.total_protocol_fee - trade.partner_fee protocol_fee_token = trade.surplus_token() protocol_fees[trade.order_uid.to_0x_hex()] = ( protocol_fee_token.to_0x_hex(), @@ -495,8 +501,13 @@ def compute_fee_imbalances( trade.sell_token.to_0x_hex(), trade.network_fee, ) + partner_fees[trade.order_uid.to_0x_hex()] = ( + protocol_fee_token.to_0x_hex(), + trade.partner_fee, + trade.partner_fee_recipient, + ) - return protocol_fees, network_fees + return protocol_fees, network_fees, partner_fees # combined function