Skip to content

Commit

Permalink
refactor: make init functions static methods
Browse files Browse the repository at this point in the history
This allows for better testability while still keeping functions grouped where they'd be expected (in the StateChangeComponents class)

Also improved use of MagicMock in testing
  • Loading branch information
chriskelly committed Oct 29, 2023
1 parent ac26cf8 commit 745773d
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 102 deletions.
126 changes: 68 additions & 58 deletions app/models/financial/state_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
Classes:
StateChangeComponents: Collection of components needed to calculate transition to next state
"""
from __future__ import annotations
from dataclasses import dataclass

import numpy as np
from app import util
from app.models.financial.taxes import Taxes, calc_taxes
from app.data.constants import INTERVALS_PER_YEAR
from app.models.config import Kids, Spending
from app.models.financial.state import State
from app.models.controllers import Controllers

# pylint: disable=redefined-builtin


class Income(util.FloatRepr):
"""Income in a given interval
Expand All @@ -24,7 +25,10 @@ class Income(util.FloatRepr):
pension (float): Pension income
"""

def __init__(self, state: State, controllers: Controllers):
def __init__(self, components: StateChangeComponents):
controllers = components.controllers
state = components.state

self.job_income = controllers.job_income.get_total_income(state.interval_idx)
(
self.social_security_user,
Expand Down Expand Up @@ -63,34 +67,6 @@ def __float__(self):
)


def _calc_spending(state: State, config: Spending, is_working: bool) -> float:
base_amount = -config.yearly_amount / INTERVALS_PER_YEAR * state.inflation
if not is_working:
base_amount *= 1 + config.retirement_change
return base_amount


def _calc_cost_of_kids(current_date: float, spending: float, config: Kids) -> float:
"""Calculate the cost of children
Args:
current_date (float): date of state
spending (float): base spending in current state
config (Kids)
Returns:
float: cost of children for this interval
"""
if config is None:
return 0
current_kids = [
year
for year in config.birth_years
if current_date - config.years_of_support < year <= current_date
]
return len(current_kids) * spending * config.fraction_of_spending


@dataclass
class _NetTransactions(util.FloatRepr):
income: Income
Expand All @@ -116,53 +92,87 @@ class StateChangeComponents:
"""

def __init__(self, state: State, controllers: Controllers):
self._state = state
self._controllers = controllers
self._allocation = controllers.allocation.gen_allocation(state)
self._economic_data = controllers.economic_data.get_economic_state_data(
self.state = state
self.controllers = controllers
self.allocation = controllers.allocation.gen_allocation(state)
self.economic_data = controllers.economic_data.get_economic_state_data(
state.interval_idx
)
self.net_transactions = self._gen_net_transactions()

def _gen_net_transactions(self) -> _NetTransactions:
income = Income(self._state, self._controllers)
portfolio_return = self._state.net_worth * np.dot(
self._economic_data.asset_rates, self._allocation
self.net_transactions = StateChangeComponents._gen_net_transactions(
components=self
)

@staticmethod
def _gen_net_transactions(components: StateChangeComponents) -> _NetTransactions:
income = Income(components)
portfolio_return = components.state.net_worth * np.dot(
components.economic_data.asset_rates, components.allocation
)
costs = StateChangeComponents._gen_costs(
components=components, income=income, portfolio_return=portfolio_return
)
costs = self._gen_costs(income, portfolio_return)

return _NetTransactions(
income=income,
portfolio_return=portfolio_return,
costs=costs,
annuity=self._controllers.annuity.make_annuity_transaction(
state=self._state,
is_working=self._controllers.job_income.is_working(
self._state.interval_idx
annuity=components.controllers.annuity.make_annuity_transaction(
state=components.state,
is_working=components.controllers.job_income.is_working(
components.state.interval_idx
),
initial_net_transaction=income.job_income + costs,
),
)

def _gen_costs(self, income: Income, portfolio_return: float) -> _Costs:
spending = _calc_spending(
state=self._state,
config=self._state.user.spending,
is_working=(
self._controllers.job_income.is_working(self._state.interval_idx)
),
)
@staticmethod
def _gen_costs(
components: StateChangeComponents, income: Income, portfolio_return: float
) -> _Costs:
spending = StateChangeComponents._calc_spending(components)
return _Costs(
spending=spending,
kids=_calc_cost_of_kids(
current_date=self._state.date,
kids=StateChangeComponents._calc_cost_of_kids(
components=components,
spending=spending,
config=self._state.user.kids,
),
taxes=calc_taxes(
total_income=income,
job_income_controller=self._controllers.job_income,
state=self._state,
job_income_controller=components.controllers.job_income,
state=components.state,
portfolio_return=portfolio_return,
),
)

@staticmethod
def _calc_spending(components: StateChangeComponents) -> float:
config = components.state.user.spending
inflation = components.state.inflation
is_working = components.controllers.job_income.is_working(
components.state.interval_idx
)

base_amount = -config.yearly_amount / INTERVALS_PER_YEAR * inflation
if not is_working:
base_amount *= 1 + config.retirement_change
return base_amount

@staticmethod
def _calc_cost_of_kids(components: StateChangeComponents, spending: float) -> float:
"""Calculate the cost of children
Returns:
float: cost of children for this interval
"""
current_date = components.state.date
config = components.state.user.kids

if config is None:
return 0
current_kids = [
year
for year in config.birth_years
if current_date - config.years_of_support < year <= current_date
]
return len(current_kids) * spending * config.fraction_of_spending
110 changes: 66 additions & 44 deletions tests/models/financial/test_state_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,72 @@
import pytest
from app.data.constants import INTERVALS_PER_YEAR
from app.models.config import Kids, Spending
from app.models.controllers import Controllers
from app.models.financial.state import State
from app.models.financial.state_change import Income, _calc_cost_of_kids, _calc_spending
from app.models.financial.state_change import Income, StateChangeComponents


def test_income(mocker, first_state):
@pytest.fixture
def controllers_mock(mocker):
"""Fixture for an empty Controllers"""
return mocker.MagicMock(spec=Controllers)


@pytest.fixture
def components_mock(mocker):
"""Fixture for an empty StateChangeComponents"""
return mocker.MagicMock(spec=StateChangeComponents)


def test_income(
controllers_mock: Controllers, first_state, components_mock: StateChangeComponents
):
"""Test that income is summed up correctly"""
fake_values = [1, 2, 3, 4]
controllers_mock = mocker.MagicMock()
controllers_mock.job_income.get_total_income = lambda *arg: fake_values[0]
controllers_mock.social_security.calc_payment = lambda *arg: (
controllers_mock.job_income.get_total_income = lambda *_: fake_values[0]
controllers_mock.social_security.calc_payment = lambda *_: (
fake_values[1],
fake_values[2],
)
controllers_mock.pension.calc_payment = lambda *arg: fake_values[3]
income = Income(state=first_state, controllers=controllers_mock)
controllers_mock.pension.calc_payment = lambda *_: fake_values[3]
components_mock.controllers = controllers_mock
components_mock.state = first_state
income = Income(components_mock)
assert float(income) == pytest.approx(sum(fake_values))


class TestCalcSpending:
yearly_amount = 50
retirement_change = -0.1
inflation = 2
config = None
state = None

@pytest.fixture(autouse=True)
def init(self, first_state: State):
"""Initialize the config and state"""
self.config = Spending(
@pytest.fixture()
def components_mock(
self,
first_state: State,
components_mock: StateChangeComponents,
controllers_mock: Controllers,
):
"""Initialize the mock components"""
components_mock.state = first_state
components_mock.state.user.spending = Spending(
yearly_amount=self.yearly_amount, retirement_change=self.retirement_change
)
self.state = first_state
self.state.inflation = self.inflation
components_mock.state.inflation = self.inflation
components_mock.controllers = controllers_mock
return components_mock

def test_while_working(self, first_state: State):
def test_while_working(self, components_mock: StateChangeComponents):
"""Spending should be unadjusted while working"""
is_working = True
assert _calc_spending(
state=first_state, config=self.config, is_working=is_working
) == pytest.approx(-self.yearly_amount / INTERVALS_PER_YEAR * self.inflation)
components_mock.controllers.job_income.is_working = lambda *_: True
assert StateChangeComponents._calc_spending(components_mock) == pytest.approx(
-self.yearly_amount / INTERVALS_PER_YEAR * self.inflation
)

def test_after_working(self, first_state: State):
def test_after_working(self, components_mock: StateChangeComponents):
"""Spending should be adjusted by the retirement change after working"""
is_working = False
assert _calc_spending(
state=first_state, config=self.config, is_working=is_working
) == pytest.approx(
components_mock.controllers.job_income.is_working = lambda *_: False
assert StateChangeComponents._calc_spending(components_mock) == pytest.approx(
-self.yearly_amount
/ INTERVALS_PER_YEAR
* self.inflation
Expand All @@ -64,44 +83,47 @@ class TestCalcCostOfKids:
cost_of_each_kid = -20
current_date = 2020
years_of_support = 18
birth_years = None
config = None
components_mock: StateChangeComponents

@pytest.fixture(autouse=True)
def init_components_mock(
self, first_state: State, components_mock: StateChangeComponents
):
"""Initialize the mock components"""
self.components_mock = components_mock
self.components_mock.state = first_state
self.components_mock.state.date = self.current_date

def calc_cost(self):
def calc_cost_from_birth_years(self, birth_years: list[float]):
"""Helper function to calculate the cost of kids"""
config = Kids(
self.components_mock.state.user.kids = Kids(
fraction_of_spending=self.cost_of_each_kid / self.spending,
birth_years=self.birth_years,
birth_years=birth_years,
years_of_support=self.years_of_support,
)
return _calc_cost_of_kids(
current_date=self.current_date,
return StateChangeComponents._calc_cost_of_kids(
components=self.components_mock,
spending=self.spending,
config=config,
)

def test_one_kid(self):
"""Test that the cost of one kid is calculated correctly"""
self.birth_years = [2018]
cost_of_kid = self.calc_cost()
cost_of_kid = self.calc_cost_from_birth_years([2018])
assert cost_of_kid == pytest.approx(self.cost_of_each_kid)

def test_multiple_kids(self):
"""Test that the cost of multiple kids is calculated correctly"""
self.birth_years = [2018, 2019]
cost_of_kids = self.calc_cost()
assert cost_of_kids == pytest.approx(
len(self.birth_years) * self.cost_of_each_kid
)
birth_years = [2018, 2019]
cost_of_kids = self.calc_cost_from_birth_years(birth_years)
assert cost_of_kids == pytest.approx(len(birth_years) * self.cost_of_each_kid)

def test_kid_not_born_yet(self):
"""Test that the cost of a kid not born yet is zero"""
self.birth_years = [self.current_date + 1]
cost_of_kids = self.calc_cost()
cost_of_kids = self.calc_cost_from_birth_years([self.current_date + 1])
assert cost_of_kids == pytest.approx(0)

def test_kid_too_old(self):
"""Test that the cost of a kid that is older than `years_of_support` is zero"""
self.birth_years = [self.current_date - (self.years_of_support + 1)]
cost_of_kids = self.calc_cost()
birth_years = [self.current_date - (self.years_of_support + 1)]
cost_of_kids = self.calc_cost_from_birth_years(birth_years)
assert cost_of_kids == pytest.approx(0)

0 comments on commit 745773d

Please sign in to comment.