Skip to content

Commit

Permalink
Merge pull request #909 from hailiangzhang/session-cookie-reuse
Browse files Browse the repository at this point in the history
Thread-local Session Management and Cookie Reuse to Address EDL DSE issue
  • Loading branch information
betolink authored Jan 18, 2025
2 parents 4d0721d + b628c71 commit d36baad
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 10 deletions.
45 changes: 36 additions & 9 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import logging
import threading
import traceback
from functools import lru_cache
from itertools import chain
Expand All @@ -17,7 +18,7 @@

import earthaccess

from .auth import Auth
from .auth import Auth, SessionWithHeaderRedirection
from .daac import DAAC_TEST_URLS, find_provider
from .results import DataGranule
from .search import DataCollections
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
Parameters:
auth: Auth instance to download and access data.
"""
self.thread_locals = threading.local()
if auth.authenticated is True:
self.auth = auth
self._s3_credentials: Dict[
Expand All @@ -126,7 +128,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
oauth_profile = f"https://{auth.system.edl_hostname}/profile"
# sets the initial URS cookie
self._requests_cookies: Dict[str, Any] = {}
self.set_requests_session(oauth_profile)
self.set_requests_session(oauth_profile, bearer_token=True)
if pre_authorize:
# collect cookies from other DAACs
for url in DAAC_TEST_URLS:
Expand Down Expand Up @@ -182,7 +184,7 @@ def _running_in_us_west_2(self) -> bool:
return False

def set_requests_session(
self, url: str, method: str = "get", bearer_token: bool = False
self, url: str, method: str = "get", bearer_token: bool = True
) -> None:
"""Sets up a `requests` session with bearer tokens that are used by CMR.
Expand Down Expand Up @@ -323,19 +325,19 @@ def get_fsspec_session(self) -> fsspec.AbstractFileSystem:
session = fsspec.filesystem("https", client_kwargs=client_kwargs)
return session

def get_requests_session(self, bearer_token: bool = True) -> requests.Session:
def get_requests_session(self) -> SessionWithHeaderRedirection:
"""Returns a requests HTTPS session with bearer tokens that are used by CMR.
This HTTPS session can be used to download granules if we want to use a direct,
lower level API.
Parameters:
bearer_token: if true, will be used for authenticated queries on CMR
Returns:
requests Session
"""
return self.auth.get_session()
if hasattr(self, "_http_session"):
return self._http_session
else:
raise AttributeError("The requests session hasn't been set up yet.")

def open(
self,
Expand Down Expand Up @@ -651,6 +653,27 @@ def _get_granules(
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

def _clone_session_in_local_thread(
self, original_session: SessionWithHeaderRedirection
) -> None:
"""Clone the original session and store it in the local thread context.
This method creates a new session that replicates the headers, cookies, and authentication settings
from the provided original session. The new session is stored in a thread-local storage.
Parameters:
original_session (SessionWithHeaderRedirection): The session to be cloned.
Returns:
None
"""
if not hasattr(self.thread_locals, "local_thread_session"):
local_thread_session = SessionWithHeaderRedirection()
local_thread_session.headers.update(original_session.headers)
local_thread_session.cookies.update(original_session.cookies)
local_thread_session.auth = original_session.auth
self.thread_locals.local_thread_session = local_thread_session

def _download_file(self, url: str, directory: Path) -> str:
"""Download a single file from an on-prem location, a DAAC data center.
Expand All @@ -668,7 +691,11 @@ def _download_file(self, url: str, directory: Path) -> str:
path = directory / Path(local_filename)
if not path.exists():
try:
session = self.auth.get_session()
original_session = self.get_requests_session()
# This reuses the auth cookie, we make sure we only authenticate N threads instead
# of one per file, see #913
self._clone_session_in_local_thread(original_session)
session = self.thread_locals.local_thread_session
with session.get(url, stream=True, allow_redirects=True) as r:
r.raise_for_status()
with open(path, "wb") as f:
Expand Down
82 changes: 81 additions & 1 deletion tests/unit/test_store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# package imports
import os
import threading
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import fsspec
import pytest
import responses
import s3fs
from earthaccess import Auth, Store
from earthaccess.auth import SessionWithHeaderRedirection
from earthaccess.store import EarthAccessFile
from pqdm.threads import pqdm


class TestStoreSessions(unittest.TestCase):
Expand Down Expand Up @@ -128,14 +133,89 @@ def test_store_can_create_s3_fsspec_session(self):

return None

@responses.activate
def test_session_reuses_token_download(self):
mock_creds = {
"accessKeyId": "sure",
"secretAccessKey": "correct",
"sessionToken": "whynot",
}
test_cases = [
(2, 500), # 2 threads, 500 files
(4, 400), # 4 threads, 400 files
(8, 5000), # 8 threads, 5k files
]
for n_threads, n_files in test_cases:
with self.subTest(n_threads=n_threads, n_files=n_files):
urls = [f"https://example.com/file{i}" for i in range(1, n_files + 1)]
for i, url in enumerate(urls):
responses.add(
responses.GET, url, body=f"Content of file {i + 1}", status=200
)

mock_auth = MagicMock()
mock_auth.authenticated = True
mock_auth.system.edl_hostname = "urs.earthdata.nasa.gov"
responses.add(
responses.GET,
"https://urs.earthdata.nasa.gov/profile",
json=mock_creds,
status=200,
)

original_session = SessionWithHeaderRedirection()
original_session.cookies.set("sessionid", "mocked-session-cookie")
mock_auth.get_session.return_value = original_session

store = Store(auth=mock_auth)
store.thread_locals = threading.local() # Use real thread-local storage

# Track cloned sessions
cloned_sessions = set()

def mock_clone_session_in_local_thread(original_session):
"""Mock session cloning to track cloned sessions."""
if not hasattr(store.thread_locals, "local_thread_session"):
session = SessionWithHeaderRedirection()
session.cookies.update(original_session.cookies)
cloned_sessions.add(id(session))
store.thread_locals.local_thread_session = session

with patch.object(
store,
"_clone_session_in_local_thread",
side_effect=mock_clone_session_in_local_thread,
):
mock_directory = Path("/mock/directory")
downloaded_files = []

def mock_download_file(url):
"""Mock file download to track downloaded files."""
# Ensure session cloning happens before downloading
store._clone_session_in_local_thread(original_session)
downloaded_files.append(url)
return mock_directory / f"{url.split('/')[-1]}"

with patch.object(
store, "_download_file", side_effect=mock_download_file
):
# Test multi-threaded download
pqdm(urls, store._download_file, n_jobs=n_threads) # type: ignore

# We make sure we reuse the token up to N threads
self.assertTrue(len(cloned_sessions) <= n_threads)

self.assertEqual(len(downloaded_files), n_files) # 10 files downloaded
self.assertCountEqual(downloaded_files, urls) # All files accounted for


@pytest.mark.xfail(
reason="Expected failure: Reproduces a bug (#610) that has not yet been fixed."
)
def test_earthaccess_file_getattr():
fs = fsspec.filesystem("memory")
with fs.open("/foo", "wb") as f:
earthaccess_file = EarthAccessFile(f, granule="foo")
earthaccess_file = EarthAccessFile(f, granule="foo") # type: ignore
assert f.tell() == earthaccess_file.tell()
# cleanup
fs.store.clear()

0 comments on commit d36baad

Please sign in to comment.