fixed lahman.py; added test_lahman.py #449

Open · wants to merge 2 commits into master
6 changes: 3 additions & 3 deletions docs/lahman.md
@@ -1,10 +1,10 @@
# Lahman Data Acquisition Functions

Pull data from [Sean Lahman's database](http://www.seanlahman.com/baseball-archive/statistics/), also hosted by [Chadwick Bureau on GitHub](https://github.com/chadwickbureau/baseballdatabank) -- our new source -- using the following functions:
Pulls data from [Sean Lahman's database](http://seanlahman.com/), now hosted on Dropbox, using the following functions:

```python
from pybaseball.lahman import *
download_lahman() #download the entire lahman database to your current working directory
download_lahman()

# a table of all player biographical info and ids
people = people()
@@ -81,7 +81,7 @@ schools = schools()
series_post = series_post()

# data on teams by year: record, division, stadium, attendance, etc
teams = teams()
teams = teams_core()

# current and historical franchises, whether they're still active, and their ids
teams_franchises = teams_franchises()
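Taken together, the documented flow on this branch amounts to the sketch below: fetch (or reuse the cached copy of) the archive once, then load individual tables as DataFrames. This is an illustrative smoke test against the branch, not code from the PR itself:

```python
from pybaseball.lahman import download_lahman, people, teams_core

# First call scrapes the Dropbox link from seanlahman.com, downloads the
# 7z archive, and extracts it into the cache directory; subsequent calls
# return False and reuse the cached files.
download_lahman()

bios = people()       # player biographical info and ids
teams = teams_core()  # per-season team records (renamed from teams())
print(len(bios), len(teams))
```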
1 change: 0 additions & 1 deletion pybaseball/__init__.py
@@ -79,7 +79,6 @@
from .lahman import schools
from .lahman import series_post
from .lahman import teams_core
from .lahman import teams_upstream
from .lahman import teams_franchises
from .lahman import teams_half
from .lahman import download_lahman
157 changes: 91 additions & 66 deletions pybaseball/lahman.py
@@ -1,136 +1,161 @@
from datetime import timedelta
from io import BytesIO
from os import makedirs
from os import path
from typing import Optional
from zipfile import ZipFile

from bs4 import BeautifulSoup
import pandas as pd
from pathlib import Path
from py7zr import SevenZipFile
import requests
from requests_cache import CachedSession

from . import cache

url = "https://github.com/chadwickbureau/baseballdatabank/archive/master.zip"
base_string = "baseballdatabank-master"

_handle = None

def get_lahman_zip() -> Optional[ZipFile]:
# Retrieve the Lahman database zip file, returns None if file already exists in cwd.
# If we already have the zip file, keep re-using that.
# Making this a function since everything else will be re-using these lines
global _handle
if path.exists(path.join(cache.config.cache_directory, base_string)):
_handle = None
elif not _handle:
s = requests.get(url, stream=True)
_handle = ZipFile(BytesIO(s.content))
return _handle

def download_lahman():
# download entire lahman db to present working directory
z = get_lahman_zip()
if z is not None:
z.extractall(cache.config.cache_directory)
z = get_lahman_zip()
# this way we'll now start using the extracted zip directory
# instead of the session ZipFile object

def _get_file(tablename: str, quotechar: str = "'") -> pd.DataFrame:
z = get_lahman_zip()
f = f'{base_string}/{tablename}'
# NB: response will be cached for 30 days unless force is True
def _get_response(force: bool = False) -> requests.Response:
session = _get_session()
response = session.get("http://seanlahman.com", refresh=force)
return response

# For example, "https://www.dropbox.com/scl/fi/hy0sxw6gaai7ghemrshi8/lahman_1871-2023_csv.7z?rlkey=edw1u63zzxg48gvpcmr3qpnhz&dl=1"
def _get_download_url(force: bool = False) -> str:
response = _get_response(force)
soup = BeautifulSoup(response.content, "html.parser")

anchor = soup.find("a", string="Comma-delimited version")
url = anchor["href"].replace("dl=0", "dl=1")

return url

def _get_cache_dir() -> str:
return f"{cache.config.cache_directory}/lahman"

def _get_session() -> CachedSession:
return CachedSession(_get_cache_dir(), expire_after=timedelta(days=30))

def _get_base_string() -> str:
url = _get_download_url()
path = Path(url)

return path.stem

def _get_file_path(filename: str = "") -> str:
base_string = _get_base_string()
return path.join(_get_cache_dir(), base_string, filename)

def _get_table(filename: str,
quotechar: str = "'",
encoding=None,
on_bad_lines="error") -> pd.DataFrame:
filepath = _get_file_path(filename)
data = pd.read_csv(
f"{path.join(cache.config.cache_directory, f)}" if z is None else z.open(f),
filepath,
header=0,
sep=',',
quotechar=quotechar
sep=",",
quotechar=quotechar,
encoding=encoding,
on_bad_lines=on_bad_lines,
)
return data

# Return whether a download happened (True) or the cache was used (False)
def download_lahman(force: bool = False) -> bool:
if force or not path.exists(_get_file_path()):
cache_dir = _get_cache_dir()
base_string = _get_base_string()
makedirs(f"{cache_dir}/{base_string}", exist_ok=True)

# do this for every table in the lahman db so they can exist as separate functions
def parks() -> pd.DataFrame:
return _get_file('core/Parks.csv')
url = _get_download_url(force)
stream = requests.get(url, stream=True)
with SevenZipFile(BytesIO(stream.content)) as zip:
zip.extractall(cache_dir)
return True
return False

# do this for every table in the lahman db so they can exist as separate functions
def all_star_full() -> pd.DataFrame:
return _get_file("core/AllstarFull.csv")
return _get_table("AllstarFull.csv")

def appearances() -> pd.DataFrame:
return _get_file("core/Appearances.csv")
return _get_table("Appearances.csv")

def awards_managers() -> pd.DataFrame:
return _get_file("contrib/AwardsManagers.csv")
return _get_table("AwardsManagers.csv")

def awards_players() -> pd.DataFrame:
return _get_file("contrib/AwardsPlayers.csv")
return _get_table("AwardsPlayers.csv")

def awards_share_managers() -> pd.DataFrame:
return _get_file("contrib/AwardsShareManagers.csv")
return _get_table("AwardsShareManagers.csv")

def awards_share_players() -> pd.DataFrame:
return _get_file("contrib/AwardsSharePlayers.csv")
return _get_table("AwardsSharePlayers.csv")

def batting() -> pd.DataFrame:
return _get_file("core/Batting.csv")
return _get_table("Batting.csv")

def batting_post() -> pd.DataFrame:
return _get_file("core/BattingPost.csv")
return _get_table("BattingPost.csv")

def college_playing() -> pd.DataFrame:
return _get_file("contrib/CollegePlaying.csv")
return _get_table("CollegePlaying.csv")

def fielding() -> pd.DataFrame:
return _get_file("core/Fielding.csv")
return _get_table("Fielding.csv")

def fielding_of() -> pd.DataFrame:
return _get_file("core/FieldingOF.csv")
return _get_table("FieldingOF.csv")

def fielding_of_split() -> pd.DataFrame:
return _get_file("core/FieldingOFsplit.csv")
return _get_table("FieldingOFsplit.csv")

def fielding_post() -> pd.DataFrame:
return _get_file("core/FieldingPost.csv")
return _get_table("FieldingPost.csv")

def hall_of_fame() -> pd.DataFrame:
return _get_file("contrib/HallOfFame.csv")
return _get_table("HallOfFame.csv")

def home_games() -> pd.DataFrame:
return _get_file("core/HomeGames.csv")
return _get_table("HomeGames.csv")

def managers() -> pd.DataFrame:
return _get_file("core/Managers.csv")
return _get_table("Managers.csv")

def managers_half() -> pd.DataFrame:
return _get_file("core/ManagersHalf.csv")
return _get_table("ManagersHalf.csv")

def master() -> pd.DataFrame:
# Alias for people -- the new name for master
return people()

def parks() -> pd.DataFrame:
return _get_table("Parks.csv", encoding="unicode_escape")

def people() -> pd.DataFrame:
return _get_file("core/People.csv")
return _get_table("People.csv", encoding="unicode_escape")

def pitching() -> pd.DataFrame:
return _get_file("core/Pitching.csv")
return _get_table("Pitching.csv")

def pitching_post() -> pd.DataFrame:
return _get_file("core/PitchingPost.csv")
return _get_table("PitchingPost.csv")

def salaries() -> pd.DataFrame:
return _get_file("contrib/Salaries.csv")
return _get_table("Salaries.csv")

def schools() -> pd.DataFrame:
return _get_file("contrib/Schools.csv", quotechar='"') # different here bc of doublequotes used in some school names
# NB: one line is bad; "brklyncuny" should use double quotes, but doesn't
return _get_table("Schools.csv", quotechar='"', on_bad_lines="skip")

def series_post() -> pd.DataFrame:
return _get_file("core/SeriesPost.csv")
return _get_table("SeriesPost.csv")

def teams_core() -> pd.DataFrame:
return _get_file("core/Teams.csv")

def teams_upstream() -> pd.DataFrame:
return _get_file("upstream/Teams.csv") # manually maintained file
return _get_table("Teams.csv")

def teams_franchises() -> pd.DataFrame:
return _get_file("core/TeamsFranchises.csv")
return _get_table("TeamsFranchises.csv")

def teams_half() -> pd.DataFrame:
return _get_file("core/TeamsHalf.csv")
return _get_table("TeamsHalf.csv")
2 changes: 2 additions & 0 deletions setup.py
@@ -92,6 +92,8 @@
'matplotlib>=2.0.0',
'tqdm>=4.50.0',
'attrs>=20.3.0',
'py7zr>=0.22.0',
'requests_cache>=1.2.1',
],

# List additional groups of dependencies here (e.g. development
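Of the two new dependencies, py7zr handles the archive format shown above, while requests_cache backs the 30-day HTTP cache that `_get_session` builds. A minimal sketch of that caching behavior (the cache name here is an arbitrary placeholder):

```python
from datetime import timedelta

from requests_cache import CachedSession

session = CachedSession("lahman_demo_cache", expire_after=timedelta(days=30))

first = session.get("http://seanlahman.com")   # network request, stored
second = session.get("http://seanlahman.com")  # served from the cache
print(first.from_cache, second.from_cache)     # False, True

# refresh=True is what force=True maps to upstream: revalidate with the
# server instead of trusting the stored response.
fresh = session.get("http://seanlahman.com", refresh=True)
```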
75 changes: 31 additions & 44 deletions tests/pybaseball/conftest.py
@@ -131,83 +131,70 @@ def get_contents(filename: str) -> str:

return get_contents


@pytest.fixture()
def get_data_file_dataframe(data_dir: str) -> GetDataFrameCallable:
def get_data_file_bytes(data_dir: str) -> Callable[[str], bytes]:
"""
Returns a function that will allow getting a dataframe from a csv file in the tests data directory easily
Returns a function that will allow getting the contents of a file in the tests data directory easily
"""
def get_dataframe(filename: str, parse_dates: _ParseDates = False) -> pd.DataFrame:
def get_bytes(filename: str) -> bytes:
"""
Get the DataFrame representation of the contents of a csv file in the tests data directory
Get the byte contents of a file in the tests data directory


ARGUMENTS:
filename : str : the name of the file within the tests data directory to load into a DataFrame
filename : str : the name of the file within the tests data directory to get the contents of
"""
return pd.read_csv(os.path.join(data_dir, filename), index_col=0, parse_dates=parse_dates).reset_index(drop=True).convert_dtypes(convert_string=False)

return get_dataframe
with open(os.path.join(data_dir, filename), 'rb') as _file:
data = _file.read()
return data

return get_bytes

@pytest.fixture()
def response_get_monkeypatch(monkeypatch: MonkeyPatch) -> Callable:
def get_data_file_dataframe(data_dir: str) -> GetDataFrameCallable:
"""
Returns a function that will monkeypatch the requests.get function call to return expected data
Returns a function that will allow getting a dataframe from a csv file in the tests data directory easily
"""
def setup(result: Union[str, bytes], expected_url: Optional[str] = None) -> None:
def get_dataframe(filename: str, parse_dates: _ParseDates = False) -> pd.DataFrame:
"""
Get the DataFrame representation of the contents of a csv file in the tests data directory


ARGUMENTS:
result : str : the payload to return in the contents of the request.get call
expected_url : str (optional) : an expected_url to test the get call against
to ensure the correct endpoint is hit
filename : str : the name of the file within the tests data directory to load into a DataFrame
"""
def _monkeypatch(url: str, params: Optional[Dict] = None, timeout: Optional[int] = None) -> object:
final_url = url

if params:
query_params = urllib.parse.urlencode(params, safe=',')
final_url = f"{final_url}?{query_params}"

if expected_url is not None:
# These prints are desired as the URLs are long and get cut off in the test output.
# These will only render on failed tests, so only when you would want to see them anyway.
print("expected", expected_url)
print("received", final_url)
assert final_url.endswith(expected_url)

class DummyResponse:
def __init__(self, content: Union[str, bytes]):
self.content = content
self.text = content
self.status_code = 200
self.url = final_url
return pd.read_csv(os.path.join(data_dir, filename), index_col=0, parse_dates=parse_dates).reset_index(drop=True).convert_dtypes(convert_string=False)

return DummyResponse(result)
return get_dataframe

monkeypatch.setattr(requests, 'get', _monkeypatch)

return setup
@pytest.fixture()
def response_get_monkeypatch(monkeypatch: MonkeyPatch) -> Callable:
return _get_monkeypatch(monkeypatch, requests)

@pytest.fixture()
def bref_get_monkeypatch(monkeypatch: MonkeyPatch) -> Callable:
return _get_monkeypatch(monkeypatch, BRefSession())

@pytest.fixture()
def target_get_monkeypatch(monkeypatch: MonkeyPatch, target: str | object) -> Callable:
return _get_monkeypatch(monkeypatch, target)

def _get_monkeypatch(monkeypatch: MonkeyPatch, target: str | object) -> Callable:
"""
Returns a function that will monkeypatch the BRefSession.get function call to return expected data
Returns a function that will monkeypatch the input target's get() function call to return the supplied result.
"""
def setup(result: Union[str, bytes], expected_url: Optional[str] = None) -> None:
"""
Get the DataFrame representation of the contents of a csv file in the tests data directory
Get the result when calling the get() function


ARGUMENTS:
result : str : the payload to return in the contents of the request.get call
result : str | bytes : the payload to return in the contents of the request.get call
expected_url : str (optional) : an expected_url to test the get call against
to ensure the correct endpoint is hit
"""
def _monkeypatch(url: str, params: Optional[Dict] = None, timeout: Optional[int] = None) -> object:
def _monkeypatch(url: str, params: Optional[Dict] = None, stream: bool = False, timeout: Optional[int] = None) -> object:
final_url = url

if params:
@@ -230,6 +217,6 @@ def __init__(self, content: Union[str, bytes]):

return DummyResponse(result)

monkeypatch.setattr(BRefSession(), 'get', _monkeypatch)
monkeypatch.setattr(target, 'get', _monkeypatch)

return setup
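With the patching logic centralized in `_get_monkeypatch`, a test in the new test_lahman.py might stub the homepage fetch like this (the payload and assertions are illustrative, not the PR's actual tests):

```python
import requests


def test_homepage_scrape(response_get_monkeypatch):
    # Canned stand-in for the seanlahman.com homepage.
    fake_page = b'<a href="https://www.dropbox.com/scl/fi/abc/lahman_csv.7z?rlkey=xyz&dl=0">Comma-delimited version</a>'

    # Patch requests.get to return the payload and verify the endpoint hit.
    response_get_monkeypatch(fake_page, expected_url="seanlahman.com")

    response = requests.get("http://seanlahman.com")
    assert b"Comma-delimited version" in response.content
```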