Skip to content
156 changes: 156 additions & 0 deletions epymorph/adrio/cdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,162 @@ def validate_result(self, context: Context, result: NDArray) -> None:
on_date_values(validate_values_in_range(0, None)),
)

class _HealthdataG62hSyehMixin(FetchADRIO[DateValueType, np.int64]):
    """
    Shared `FetchADRIO` plumbing for ADRIOs backed by healthdata.gov dataset
    g62h-syeh, i.e.
    "COVID-19 Reported Patient Impact and Hospital Capacity by State Timeseries(RAW)".

    The mixin supplies the result format plus context and result validation;
    concrete ADRIOs are expected to provide the fetching and processing steps.

    https://healthdata.gov/Hospital/COVID-19-Reported-Patient-Impact-and-Hospital-Capa/g62h-syeh/about_data
    """

    _RESOURCE = q.SocrataResource(domain="healthdata.gov", id="g62h-syeh")
    """The Socrata API endpoint."""

    _TIME_RANGE = DateRange(iso8601("2020-01-01"), iso8601("2024-04-27"), step=1)
    """The time range over which values are available."""

    @property
    @override
    def result_format(self) -> ResultFormat:
        # Results are date/value pairs whose value component is int64.
        value_dtype = date_value_dtype(np.int64)
        return ResultFormat(shape=Shapes.AxN, dtype=value_dtype)

    @override
    def validate_context(self, context: Context):
        # This dataset is only usable with a state-level geo scope.
        scope_is_state = isinstance(context.scope, StateScope)
        if not scope_is_state:
            raise ADRIOContextError(self, context, "US State geo scope required.")
        validate_time_frame(self, context, self._TIME_RANGE)

    @override
    def validate_result(self, context: Context, result: NDArray) -> None:
        # Expected shape: one row per available date overlapping the
        # requested time frame, one column per scope node.
        overlap = self._TIME_RANGE.overlap_or_raise(context.time_frame)
        expected_shape = (len(overlap), context.scope.nodes)
        checks = (
            validate_numpy(),
            validate_shape(expected_shape),
            validate_dtype(self.result_format.dtype),
            # Reported counts must not be negative.
            on_date_values(validate_values_in_range(0, None)),
        )
        adrio_validate_pipe(self, context, result, *checks)


@adrio_cache
class InfluenzaStateHospitalizationDaily(
    _HealthdataG62hSyehMixin, FetchADRIO[DateValueType, np.int64]
):
    """
    Loads influenza hospitalization data from HealthData.gov's
    "COVID-19 Reported Patient Impact and Hospital Capacity by State Timeseries(RAW)"
    dataset. The data were reported by healthcare facilities on a daily basis and
    aggregated to the state level by the CDC. This ADRIO is restricted to the date
    range 2020-01-01 through 2024-04-27, the latter being the last day of reporting
    for this dataset. Note that reporting was highly inconsistent before 2020-10-21.

    This ADRIO supports geo scopes at US State granularity.
    The data loaded will be matched to the simulation time frame. The result is a
    2D matrix where the first axis represents reporting days during the time frame
    and the second axis is geo scope nodes. Values are tuples of date and the
    integer number of reported data.

    Parameters
    ----------
    column :
        Which column to fetch data from.
        Supported columns are 'previous_day_admission_influenza_confirmed' and
        'total_patients_hospitalized_confirmed_influenza'.
        To select these columns set this parameter to 'admissions' and
        'hospitalizations' respectively.

    fix_missing :
        The method to use to fix missing values.

    Raises
    ------
    ValueError
        If `column` is not one of the supported values, or if `fix_missing`
        cannot be converted to a fill method.

    See Also
    --------
    [The dataset documentation](https://healthdata.gov/Hospital/COVID-19-Reported-Patient-Impact-and-Hospital-Capa/g62h-syeh/about_data).
    """  # noqa: E501

    # Dataset column names selectable through the `column` parameter.
    _ADMISSIONS = "previous_day_admission_influenza_confirmed"
    _HOSPITALIZATIONS = "total_patients_hospitalized_confirmed_influenza"

    # Instance state, assigned in `__init__`.
    _fix_missing: Fill[np.int64]
    _column_name: Literal["admissions", "hospitalizations"]

    def __init__(
        self,
        *,
        column: Literal["admissions", "hospitalizations"],
        fix_missing: FillLikeInt = False,
    ):
        if column not in ("admissions", "hospitalizations"):
            raise ValueError(
                "Invalid value for column. Supported values are "
                "admissions and hospitalizations."
            )
        self._column_name = column
        try:
            self._fix_missing = Fill.of_int64(fix_missing)
        except ValueError as e:
            # Chain explicitly so the underlying reason stays attached as the
            # cause of the user-facing error.
            raise ValueError("Invalid value for `fix_missing`") from e

    @override
    def _fetch(self, context: Context) -> pd.DataFrame:
        """
        Query the dataset for the selected column, limited to the context's
        time frame and scope labels.

        Raises
        ------
        ValueError
            If the stored column selection is not a supported value.
        ADRIOCommunicationError
            If the query fails for any reason.
        """
        # Map the user-facing selector to the dataset's column name.
        source_columns = {
            "admissions": self._ADMISSIONS,
            "hospitalizations": self._HOSPITALIZATIONS,
        }
        source_column = source_columns.get(self._column_name)
        if source_column is None:
            raise ValueError(f"Unsupported `column_name`: {self._column_name}")

        query = q.Query(
            select=(
                q.Select("date", "date", as_name="date"),
                q.Select("state", "str", as_name="geoid"),
                q.Select(source_column, "nullable_int", as_name="value"),
            ),
            where=q.And(
                q.DateBetween(
                    "date",
                    context.time_frame.start_date,
                    context.time_frame.end_date,
                ),
                q.In("state", context.scope.labels),
            ),
            order_by=(
                q.Ascending("date"),
                q.Ascending("state"),
                # Including a sort by row ID is important: with SODA APIs,
                # assembling complete results can require multiple requests
                # (pagination; there's a limit to how many rows can be returned
                # in one request), so the answer to "what's on page X?" must be
                # consistent between requests. An unambiguous sort order
                # guarantees that. date/state may be sufficient for this
                # dataset, but adding the ID makes it certain at little cost.
                q.Ascending(":id"),
            ),
        )

        try:
            return q.query_csv(
                resource=self._RESOURCE,
                query=query,
                api_token=healthdata_api_key(),
            )
        except Exception as e:
            raise ADRIOCommunicationError(self, context) from e

    @override
    def _process(
        self,
        context: Context,
        data_df: pd.DataFrame,
    ) -> PipelineResult[DateValueType]:
        """
        Pivot the fetched rows into a date-by-node result, applying the
        configured missing-value fix, and attach the dates to the values.
        """
        # Only dates inside both the dataset's range and the requested
        # time frame form the first axis.
        time_series = self._TIME_RANGE.overlap_or_raise(context.time_frame).to_numpy()
        pipeline = DataPipeline(
            axes=(
                PivotAxis("date", time_series),
                PivotAxis("geoid", context.scope.labels),
            ),
            ndims=2,
            # The scalar type of the "value" field of the date/value dtype.
            dtype=self.result_format.dtype["value"].type,
            rng=context,
        ).finalize(self._fix_missing)
        return pipeline(data_df).to_date_value(time_series)


@adrio_cache
class COVIDFacilityHospitalization(
Expand Down
Loading