diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index f7db338e..5f379f27 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -190,6 +190,9 @@ def _is_str_collection(x: Any) -> bool: agent_data_dict: dict[str, pl.Series] = {} + for set_name, aset in self._model.sets.items(): + agent_data_dict[f"unique_id_{set_name}"] = aset["unique_id"] + for col_name, reporter in self._agent_reporters.items(): # 1) String or collection[str]: shorthand to fetch columns if isinstance(reporter, str) or _is_str_collection(reporter): @@ -406,10 +409,18 @@ def _write_postgres(self, uri: str, frames_to_flush: list): for kind, step, batch, lf in frames_to_flush: df = lf.collect() table = f"{kind}_data" - cols = df.columns + + if kind == "model": + ordered_cols = ["step", "seed", "batch"] + [ + c for c in df.columns if c not in {"step", "seed", "batch"} + ] + else: + ordered_cols = df.columns + + df = df.select(ordered_cols) values = [tuple(row) for row in df.rows()] - placeholders = ", ".join(["%s"] * len(cols)) - columns = ", ".join(cols) + placeholders = ", ".join(["%s"] * len(ordered_cols)) + columns = ", ".join(ordered_cols) cur.executemany( f"INSERT INTO {self._schema}.{table} ({columns}) VALUES ({placeholders})", values, @@ -449,18 +460,20 @@ def _validate_inputs(self): - Ensures a `storage_uri` is provided if needed. - For PostgreSQL, validates that required tables and columns exist. """ - if self._storage != "memory" and self._storage_uri == None: + if self._storage != "memory" and self._storage_uri is None: raise ValueError( "Please define a storage_uri to if to be stored not in memory" ) if self._storage == "postgresql": - conn = self._get_db_connection(self._storage_uri) + conn = None try: + conn = self._get_db_connection(self._storage_uri) self._validate_postgress_table_exists(conn) self._validate_postgress_columns_exists(conn) finally: - conn.close() + if conn: + conn.close() def _validate_postgress_table_exists(self, conn: connection): """ @@ -556,6 +569,11 @@ def _is_str_collection(x: Any) -> bool: return False expected_columns: set[str] = set() + + if table_name == "agent_data": + for set_name, _ in self._model.sets.items(): + expected_columns.add(f"unique_id_{set_name}".lower()) + for col_name, req in reporter.items(): # Strings → one column per set with suffix if isinstance(req, str): diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index b2ac3279..e84acbfc 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -1,3 +1,4 @@ +import psycopg2 from mesa_frames.concrete.datacollector import DataCollector from mesa_frames import Model, AgentSet, AgentSetRegistry import pytest @@ -15,6 +16,7 @@ def custom_trigger(model): class ExampleAgentSet1(AgentSet): def __init__(self, model: Model): super().__init__(model) + self["unique_id"] = pl.Series("unique_id", [101, 102, 103, 104], dtype=pl.Int64) self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) self["age"] = pl.Series("age", [10, 20, 30, 40]) @@ -28,6 +30,7 @@ def step(self) -> None: class ExampleAgentSet2(AgentSet): def __init__(self, model: Model): super().__init__(model) + self["unique_id"] = pl.Series("unique_id", [201, 202, 203, 204], dtype=pl.Int64) self["wealth"] = pl.Series("wealth", [10, 20, 30, 40]) self["age"] = pl.Series("age", [11, 22, 33, 44]) @@ -41,6 +44,7 @@ def step(self) -> None: class ExampleAgentSet3(AgentSet): def __init__(self, model: Model): super().__init__(model) + self["unique_id"] = pl.Series("unique_id", [301, 302, 303, 304], dtype=pl.Int64) self["age"] = pl.Series("age", [1, 2, 3, 4]) self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) @@ -147,11 +151,18 @@ def test__init__(self, fix1_model, postgres_uri): ): model.test_dc = DataCollector(model=model, storage="S3-csv") + try: + psycopg2.connect(postgres_uri) + except psycopg2.OperationalError: + pass + with pytest.raises( ValueError, match="Please define a storage_uri to if to be stored not in memory", ): - model.test_dc = DataCollector(model=model, storage="postgresql") + model.test_dc = DataCollector( + model=model, storage="postgresql", storage_uri=None + ) def test_collect(self, fix1_model): model = fix1_model @@ -185,12 +196,15 @@ def test_collect(self, fix1_model): with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["model"]["max_wealth"] - assert collected_data["agent"].shape == (4, 7) + assert collected_data["agent"].shape == (4, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -242,12 +256,15 @@ def test_collect_step(self, fix1_model): assert collected_data["model"]["step"].to_list() == [5] assert collected_data["model"]["total_agents"].to_list() == [12] - assert collected_data["agent"].shape == (4, 7) + assert collected_data["agent"].shape == (4, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -297,25 +314,20 @@ def test_conditional_collect(self, fix1_model): assert collected_data["model"]["step"].to_list() == [2, 4] assert collected_data["model"]["total_agents"].to_list() == [12, 12] - assert collected_data["agent"].shape == (8, 7) - assert set(collected_data["agent"].columns) == { - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } + assert collected_data["agent"].shape == (8, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", } + assert collected_data["agent"]["wealth"].to_list() == [3, 4, 5, 6, 5, 6, 7, 8] assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ 10, @@ -394,15 +406,25 @@ def test_flush_local_csv(self, fix1_model): assert model_df["step"].to_list() == [2] assert model_df["total_agents"].to_list() == [12] + agent_overrides = { + "seed": pl.Utf8, + "unique_id_ExampleAgentSet1": pl.Utf8, + "unique_id_ExampleAgentSet2": pl.Utf8, + "unique_id_ExampleAgentSet3": pl.Utf8, + } + agent_df = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df.columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -420,7 +442,7 @@ def test_flush_local_csv(self, fix1_model): agent_df = pl.read_csv( os.path.join(tmpdir, "agent_step4_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert agent_df["step"].to_list() == [4, 4, 4, 4] assert agent_df["wealth"].to_list() == [5, 6, 7, 8] @@ -474,10 +496,13 @@ def test_flush_local_parquet(self, fix1_model): reason="PostgreSQL tests are skipped on Windows runners", ) def test_postgress(self, fix1_model, postgres_uri): - model = fix1_model + try: + conn = psycopg2.connect(postgres_uri) + conn.close() + except psycopg2.OperationalError: + pytest.skip("PostgreSQL not available") - # Connect directly and validate data - import psycopg2 + model = fix1_model conn = psycopg2.connect(postgres_uri) cur = conn.cursor() @@ -496,6 +521,9 @@ def test_postgress(self, fix1_model, postgres_uri): step INTEGER, seed VARCHAR, batch INTEGER, + unique_id_exampleagentset1 INTEGER, + unique_id_exampleagentset2 INTEGER, + unique_id_exampleagentset3 INTEGER, age_ExampleAgentSet1 INTEGER, age_ExampleAgentSet2 INTEGER, age_ExampleAgentSet3 INTEGER, @@ -580,12 +608,15 @@ def test_batch_memory(self, fix2_model): assert collected_data["model"]["batch"].to_list() == [0, 1, 0, 1] assert collected_data["model"]["total_agents"].to_list() == [12, 12, 12, 12] - assert collected_data["agent"].shape == (16, 7) + assert collected_data["agent"].shape == (16, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -596,6 +627,9 @@ def test_batch_memory(self, fix2_model): "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -773,16 +807,26 @@ def test_batch_save(self, fix2_model): assert model_df_step4_batch0["step"].to_list() == [4] assert model_df_step4_batch0["total_agents"].to_list() == [12] + agent_overrides = { + "seed": pl.Utf8, + "unique_id_ExampleAgentSet1": pl.Utf8, + "unique_id_ExampleAgentSet2": pl.Utf8, + "unique_id_ExampleAgentSet3": pl.Utf8, + } + # test agent batch reset agent_df_step2_batch0 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df_step2_batch0.columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -810,13 +854,16 @@ def test_batch_save(self, fix2_model): agent_df_step2_batch1 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch1.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df_step2_batch1.columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -844,13 +891,16 @@ def test_batch_save(self, fix2_model): agent_df_step4_batch0 = pl.read_csv( os.path.join(tmpdir, "agent_step4_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df_step4_batch0.columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch",