Skip to content

Commit

Permalink
fix: scope state key handling (#3070)
Browse files Browse the repository at this point in the history
* Fix scope "state" key handling
* Don't overwrite initial state
* Simplify
  • Loading branch information
provinzkraut authored Feb 4, 2024
1 parent f4c74f6 commit 5b7d2dc
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 8 deletions.
2 changes: 1 addition & 1 deletion litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ async def __call__(
return

scope["app"] = self
scope["state"] = {}
scope.setdefault("state", {})
await self.asgi_handler(scope, receive, self._wrap_send(send=send, scope=scope)) # type: ignore[arg-type]

async def _call_lifespan_hook(self, hook: LifespanHook) -> None:
Expand Down
2 changes: 1 addition & 1 deletion litestar/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def state(self) -> StateT:
Returns:
A State instance constructed from the scope["state"] value.
"""
return cast("StateT", State(self.scope["state"]))
return cast("StateT", State(self.scope.get("state")))

@property
def url(self) -> URL:
Expand Down
7 changes: 3 additions & 4 deletions litestar/utils/scope/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,9 @@ def from_scope(cls, scope: Scope) -> Self:
Returns:
A `ConnectionState` object.
"""
if state := scope["state"].get(CONNECTION_STATE_KEY):
return state # type: ignore[no-any-return]
state = scope["state"][CONNECTION_STATE_KEY] = cls()
scope["state"][CONNECTION_STATE_KEY] = state
base = scope["state"] if "state" in scope else scope
if (state := base.get(CONNECTION_STATE_KEY)) is None:
state = base[CONNECTION_STATE_KEY] = cls()
return state


Expand Down
18 changes: 17 additions & 1 deletion tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import fields
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, Callable, List, Tuple
from unittest.mock import MagicMock, Mock, PropertyMock

import pytest
Expand Down Expand Up @@ -221,6 +221,22 @@ def modify_state_in_hook(app_config: AppConfig) -> AppConfig:
assert app.state._state == {"a": "b", "c": "D", "e": "f"}


async def test_dont_override_initial_state(create_scope: Callable[..., Scope]) -> None:
app = Litestar()

scope = create_scope(headers=[], state={"foo": "bar"})

async def send(message: Message) -> None:
pass

async def receive() -> None:
pass

await app(scope, receive, send) # type: ignore[arg-type]

assert scope["state"].get("foo") == "bar"


def test_app_from_config(app_config_object: AppConfig) -> None:
Litestar.from_config(app_config_object)

Expand Down
9 changes: 8 additions & 1 deletion tests/unit/test_utils/test_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
get_litestar_scope_state,
set_litestar_scope_state,
)
from litestar.utils.scope.state import ScopeState
from litestar.utils.scope.state import CONNECTION_STATE_KEY, ScopeState

if TYPE_CHECKING:
from litestar.types.asgi_types import Scope
Expand All @@ -21,6 +21,13 @@ def scope(create_scope: Callable[..., Scope]) -> Scope:
return create_scope()


def test_from_scope_without_state() -> None:
scope = {} # type: ignore[var-annotated]
state = ScopeState.from_scope(scope) # type: ignore[arg-type]
assert "state" not in scope
assert scope[CONNECTION_STATE_KEY] is state


@pytest.mark.parametrize(("pop",), [(True,), (False,)])
def test_get_litestar_scope_state_arbitrary_value(pop: bool, scope: Scope) -> None:
key = "test"
Expand Down

0 comments on commit 5b7d2dc

Please sign in to comment.