From f5be02abb716a5f3cd6cba2999ef80874a494fc2 Mon Sep 17 00:00:00 2001 From: Agrim Gupta Date: Sun, 1 Feb 2026 15:38:00 +0530 Subject: [PATCH 1/4] feat: allow collecting agents reporters by unique id --- mesa_frames/concrete/datacollector.py | 16 ++++- tests/test_datacollector.py | 92 ++++++++++++++++++++------- 2 files changed, 83 insertions(+), 25 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index f7db338e..ffad199e 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -190,6 +190,9 @@ def _is_str_collection(x: Any) -> bool: agent_data_dict: dict[str, pl.Series] = {} + for set_name, aset in self._model.sets.items(): + agent_data_dict[f"unique_id_{set_name}"] = aset["unique_id"] + for col_name, reporter in self._agent_reporters.items(): # 1) String or collection[str]: shorthand to fetch columns if isinstance(reporter, str) or _is_str_collection(reporter): @@ -449,18 +452,20 @@ def _validate_inputs(self): - Ensures a `storage_uri` is provided if needed. - For PostgreSQL, validates that required tables and columns exist. 
""" - if self._storage != "memory" and self._storage_uri == None: + if self._storage != "memory" and self._storage_uri is None: raise ValueError( "Please define a storage_uri to if to be stored not in memory" ) if self._storage == "postgresql": - conn = self._get_db_connection(self._storage_uri) + conn = None try: + conn = self._get_db_connection(self._storage_uri) self._validate_postgress_table_exists(conn) self._validate_postgress_columns_exists(conn) finally: - conn.close() + if conn: + conn.close() def _validate_postgress_table_exists(self, conn: connection): """ @@ -556,6 +561,11 @@ def _is_str_collection(x: Any) -> bool: return False expected_columns: set[str] = set() + + if table_name == "agent_data": + for set_name, _ in self._model.sets.items(): + expected_columns.add(f"unique_id_{set_name}".lower()) + for col_name, req in reporter.items(): # Strings → one column per set with suffix if isinstance(req, str): diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index b2ac3279..3e9e454e 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -1,3 +1,4 @@ +import psycopg2 from mesa_frames.concrete.datacollector import DataCollector from mesa_frames import Model, AgentSet, AgentSetRegistry import pytest @@ -15,6 +16,7 @@ def custom_trigger(model): class ExampleAgentSet1(AgentSet): def __init__(self, model: Model): super().__init__(model) + self["unique_id"] = pl.Series("unique_id", [101, 102, 103, 104], dtype=pl.Int64) self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) self["age"] = pl.Series("age", [10, 20, 30, 40]) @@ -28,6 +30,7 @@ def step(self) -> None: class ExampleAgentSet2(AgentSet): def __init__(self, model: Model): super().__init__(model) + self["unique_id"] = pl.Series("unique_id", [201, 202, 203, 204], dtype=pl.Int64) self["wealth"] = pl.Series("wealth", [10, 20, 30, 40]) self["age"] = pl.Series("age", [11, 22, 33, 44]) @@ -41,6 +44,7 @@ def step(self) -> None: class ExampleAgentSet3(AgentSet): def 
__init__(self, model: Model): super().__init__(model) + self["unique_id"] = pl.Series("unique_id", [301, 302, 303, 304], dtype=pl.Int64) self["age"] = pl.Series("age", [1, 2, 3, 4]) self["wealth"] = pl.Series("wealth", [1, 2, 3, 4]) @@ -147,11 +151,16 @@ def test__init__(self, fix1_model, postgres_uri): ): model.test_dc = DataCollector(model=model, storage="S3-csv") + try: + psycopg2.connect(postgres_uri) + except psycopg2.OperationalError: + pass + with pytest.raises( ValueError, match="Please define a storage_uri to if to be stored not in memory", ): - model.test_dc = DataCollector(model=model, storage="postgresql") + model.test_dc = DataCollector(model=model, storage="postgresql", storage_uri=None) def test_collect(self, fix1_model): model = fix1_model @@ -185,12 +194,15 @@ def test_collect(self, fix1_model): with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"): collected_data["model"]["max_wealth"] - assert collected_data["agent"].shape == (4, 7) + assert collected_data["agent"].shape == (4, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -242,12 +254,15 @@ def test_collect_step(self, fix1_model): assert collected_data["model"]["step"].to_list() == [5] assert collected_data["model"]["total_agents"].to_list() == [12] - assert collected_data["agent"].shape == (4, 7) + assert collected_data["agent"].shape == (4, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -297,25 +312,20 @@ def test_conditional_collect(self, fix1_model): assert collected_data["model"]["step"].to_list() == [2, 4] assert 
collected_data["model"]["total_agents"].to_list() == [12, 12] - assert collected_data["agent"].shape == (8, 7) - assert set(collected_data["agent"].columns) == { - "wealth", - "age_ExampleAgentSet1", - "age_ExampleAgentSet2", - "age_ExampleAgentSet3", - "step", - "seed", - "batch", - } + assert collected_data["agent"].shape == (8, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", } + assert collected_data["agent"]["wealth"].to_list() == [3, 4, 5, 6, 5, 6, 7, 8] assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ 10, @@ -394,15 +404,25 @@ def test_flush_local_csv(self, fix1_model): assert model_df["step"].to_list() == [2] assert model_df["total_agents"].to_list() == [12] + agent_overrides = { + "seed": pl.Utf8, + "unique_id_ExampleAgentSet1": pl.Utf8, + "unique_id_ExampleAgentSet2": pl.Utf8, + "unique_id_ExampleAgentSet3": pl.Utf8, + } + agent_df = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df.columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -420,7 +440,7 @@ def test_flush_local_csv(self, fix1_model): agent_df = pl.read_csv( os.path.join(tmpdir, "agent_step4_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert agent_df["step"].to_list() == [4, 4, 4, 4] assert agent_df["wealth"].to_list() == [5, 6, 7, 8] @@ -474,10 +494,13 @@ def test_flush_local_parquet(self, fix1_model): reason="PostgreSQL tests are skipped on Windows runners", ) def test_postgress(self, fix1_model, postgres_uri): - model = fix1_model + try: + conn = 
psycopg2.connect(postgres_uri) + conn.close() + except psycopg2.OperationalError: + pytest.skip("PostgreSQL not available") - # Connect directly and validate data - import psycopg2 + model = fix1_model conn = psycopg2.connect(postgres_uri) cur = conn.cursor() @@ -496,6 +519,9 @@ def test_postgress(self, fix1_model, postgres_uri): step INTEGER, seed VARCHAR, batch INTEGER, + "unique_id_ExampleAgentSet1" INTEGER, + "unique_id_ExampleAgentSet2" INTEGER, + "unique_id_ExampleAgentSet3" INTEGER, age_ExampleAgentSet1 INTEGER, age_ExampleAgentSet2 INTEGER, age_ExampleAgentSet3 INTEGER, @@ -580,12 +606,15 @@ def test_batch_memory(self, fix2_model): assert collected_data["model"]["batch"].to_list() == [0, 1, 0, 1] assert collected_data["model"]["total_agents"].to_list() == [12, 12, 12, 12] - assert collected_data["agent"].shape == (16, 7) + assert collected_data["agent"].shape == (16, 10) assert set(collected_data["agent"].columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -596,6 +625,9 @@ def test_batch_memory(self, fix2_model): "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -773,16 +805,26 @@ def test_batch_save(self, fix2_model): assert model_df_step4_batch0["step"].to_list() == [4] assert model_df_step4_batch0["total_agents"].to_list() == [12] + agent_overrides = { + "seed": pl.Utf8, + "unique_id_ExampleAgentSet1": pl.Utf8, + "unique_id_ExampleAgentSet2": pl.Utf8, + "unique_id_ExampleAgentSet3": pl.Utf8, + } + # test agent batch reset agent_df_step2_batch0 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df_step2_batch0.columns) == { "wealth", 
"age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -810,13 +852,16 @@ def test_batch_save(self, fix2_model): agent_df_step2_batch1 = pl.read_csv( os.path.join(tmpdir, "agent_step2_batch1.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df_step2_batch1.columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", @@ -844,13 +889,16 @@ def test_batch_save(self, fix2_model): agent_df_step4_batch0 = pl.read_csv( os.path.join(tmpdir, "agent_step4_batch0.csv"), - schema_overrides={"seed": pl.Utf8}, + schema_overrides=agent_overrides, ) assert set(agent_df_step4_batch0.columns) == { "wealth", "age_ExampleAgentSet1", "age_ExampleAgentSet2", "age_ExampleAgentSet3", + "unique_id_ExampleAgentSet1", + "unique_id_ExampleAgentSet2", + "unique_id_ExampleAgentSet3", "step", "seed", "batch", From 2baf04649e7b7a98903e93467bc7a4057e306fbd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Feb 2026 10:17:59 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa_frames/concrete/datacollector.py | 2 +- tests/test_datacollector.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index ffad199e..102e26e9 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -190,7 +190,7 @@ def _is_str_collection(x: Any) -> bool: agent_data_dict: dict[str, pl.Series] = {} - for set_name, aset in self._model.sets.items(): + for set_name, aset in 
self._model.sets.items(): agent_data_dict[f"unique_id_{set_name}"] = aset["unique_id"] for col_name, reporter in self._agent_reporters.items(): diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 3e9e454e..d1b1170d 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -154,13 +154,15 @@ def test__init__(self, fix1_model, postgres_uri): try: psycopg2.connect(postgres_uri) except psycopg2.OperationalError: - pass + pass with pytest.raises( ValueError, match="Please define a storage_uri to if to be stored not in memory", ): - model.test_dc = DataCollector(model=model, storage="postgresql", storage_uri=None) + model.test_dc = DataCollector( + model=model, storage="postgresql", storage_uri=None + ) def test_collect(self, fix1_model): model = fix1_model @@ -325,7 +327,7 @@ def test_conditional_collect(self, fix1_model): "seed", "batch", } - + assert collected_data["agent"]["wealth"].to_list() == [3, 4, 5, 6, 5, 6, 7, 8] assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [ 10, From be64717c6e6dca7a7afbc5196bfc880d089aa20e Mon Sep 17 00:00:00 2001 From: Agrim Gupta Date: Sun, 1 Feb 2026 15:59:18 +0530 Subject: [PATCH 3/4] fix: quotes in postgres --- tests/test_datacollector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 3e9e454e..d3c90257 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -519,9 +519,9 @@ def test_postgress(self, fix1_model, postgres_uri): step INTEGER, seed VARCHAR, batch INTEGER, - "unique_id_ExampleAgentSet1" INTEGER, - "unique_id_ExampleAgentSet2" INTEGER, - "unique_id_ExampleAgentSet3" INTEGER, + unique_id_ExampleAgentSet1 INTEGER, + unique_id_ExampleAgentSet2 INTEGER, + unique_id_ExampleAgentSet3 INTEGER, age_ExampleAgentSet1 INTEGER, age_ExampleAgentSet2 INTEGER, age_ExampleAgentSet3 INTEGER, From 276f928913e6da527b56274ccf7c4643d1124d75 Mon Sep 17 00:00:00 2001
From: Agrim Gupta Date: Sun, 1 Feb 2026 16:29:59 +0530 Subject: [PATCH 4/4] fix: postgres test failure --- mesa_frames/concrete/datacollector.py | 14 +++++++++++--- tests/test_datacollector.py | 6 +++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/mesa_frames/concrete/datacollector.py b/mesa_frames/concrete/datacollector.py index 102e26e9..5f379f27 100644 --- a/mesa_frames/concrete/datacollector.py +++ b/mesa_frames/concrete/datacollector.py @@ -409,10 +409,18 @@ def _write_postgres(self, uri: str, frames_to_flush: list): for kind, step, batch, lf in frames_to_flush: df = lf.collect() table = f"{kind}_data" - cols = df.columns + + if kind == "model": + ordered_cols = ["step", "seed", "batch"] + [ + c for c in df.columns if c not in {"step", "seed", "batch"} + ] + else: + ordered_cols = df.columns + + df = df.select(ordered_cols) values = [tuple(row) for row in df.rows()] - placeholders = ", ".join(["%s"] * len(cols)) - columns = ", ".join(cols) + placeholders = ", ".join(["%s"] * len(ordered_cols)) + columns = ", ".join(ordered_cols) cur.executemany( f"INSERT INTO {self._schema}.{table} ({columns}) VALUES ({placeholders})", values, diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 0604715c..e84acbfc 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -521,9 +521,9 @@ def test_postgress(self, fix1_model, postgres_uri): step INTEGER, seed VARCHAR, batch INTEGER, - unique_id_ExampleAgentSet1 INTEGER, - unique_id_ExampleAgentSet2 INTEGER, - unique_id_ExampleAgentSet3 INTEGER, + unique_id_exampleagentset1 INTEGER, + unique_id_exampleagentset2 INTEGER, + unique_id_exampleagentset3 INTEGER, age_ExampleAgentSet1 INTEGER, age_ExampleAgentSet2 INTEGER, age_ExampleAgentSet3 INTEGER,