Skip to content

Commit

Permalink
Add fix for double encoding, source: fullonic/brotli-asgi#24
Browse files Browse the repository at this point in the history
  • Loading branch information
tuffnatty committed Dec 16, 2023
1 parent d10b35d commit 238c7ac
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
37 changes: 35 additions & 2 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Main test for zstd middleware.
"""Main tests for zstd middleware.
This tests are the same as the ones from starlette.tests.middleware.test_gzip
Some of these tests are the same as the ones from starlette.tests.middleware.test_gzip
but using zstd instead.
"""
import functools
import gzip
import io

import pytest

from starlette.applications import Starlette
from starlette.responses import (
JSONResponse,
PlainTextResponse,
Response,
StreamingResponse,
)
from starlette.testclient import TestClient
Expand Down Expand Up @@ -153,3 +156,33 @@ def homepage(request):
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert int(response.headers["Content-Length"]) == 4000


def test_zstd_avoids_double_encoding():
# See https://github.com/encode/starlette/pull/1901

app = Starlette()

app.add_middleware(ZstdMiddleware, minimum_size=1)

@app.route("/")
def homepage(request):
gzip_buffer = io.BytesIO()
gzip_file = gzip.GzipFile(mode="wb", fileobj=gzip_buffer)
gzip_file.write(b"hello world" * 200)
gzip_file.close()
body = gzip_buffer.getvalue()
return Response(
body,
headers={
"content-encoding": "gzip",
"x-gzipped-content-length": str(len(body))
}
)

client = TestClient(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
assert response.text == "hello world" * 200
assert response.headers["Content-Encoding"] == "gzip"
assert response.headers["Content-Length"] == response.headers["x-gzipped-content-length"]
8 changes: 8 additions & 0 deletions zstd_asgi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
self.send = unattached_send # type: Send
self.initial_message = {} # type: Message
self.started = False
self.content_encoding_set = False
self.zstd_buffer = io.BytesIO()
self.zstd_file = zstandard.ZstdCompressor(
level=level,
Expand All @@ -90,6 +91,13 @@ async def send_with_zstd(self, message: Message) -> None:
# Don't send the initial message until we've determined how to
# modify the outgoing headers correctly.
self.initial_message = message
headers = Headers(raw=self.initial_message["headers"])
self.content_encoding_set = "content-encoding" in headers
elif message_type == "http.response.body" and self.content_encoding_set:
if not self.started:
self.started = True
await self.send(self.initial_message)
await self.send(message)
elif message_type == "http.response.body" and not self.started:
self.started = True
body = message.get("body", b"")
Expand Down

0 comments on commit 238c7ac

Please sign in to comment.