diff --git a/keystone_scim/__init__.py b/keystone_scim/__init__.py index 17f34b7..3084307 100644 --- a/keystone_scim/__init__.py +++ b/keystone_scim/__init__.py @@ -5,7 +5,7 @@ _red = "\033[0;31m" _nc = "\033[0m" -VERSION = "0.2.0-rc.0" +VERSION = "0.2.1" LOGO = """ .............. .--------------. : diff --git a/keystone_scim/store/mysql_store.py b/keystone_scim/store/mysql_store.py index a8bf9cc..4e8bee0 100644 --- a/keystone_scim/store/mysql_store.py +++ b/keystone_scim/store/mysql_store.py @@ -2,14 +2,14 @@ import logging import uuid from datetime import datetime -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import aiomysql import pymysql.cursors from aiomysql.sa import create_engine from aiomysql.sa.result import RowProxy from scim2_filter_parser.queries import SQLQuery -from sqlalchemy import delete, insert, select, text, update, and_ +from sqlalchemy import delete, insert, select, text, update, and_, or_ from sqlalchemy.sql.base import ImmutableColumnCollection from sqlalchemy.sql.elements import TextClause @@ -189,18 +189,21 @@ async def get_by_id(self, resource_id: str): if self.entity_type == "groups": return await self._get_group_by_id(resource_id) - async def _get_where_clause_from_filter(self, _filter: str, attr_map: Dict) -> Optional[TextClause]: + async def _get_where_clause_from_filter(self, _filter: str, attr_map: Dict)\ + -> Tuple[Optional[TextClause], Dict]: if not _filter: - return None + return None, {} parsed_q = SQLQuery(_filter, self.entity_type, attr_map) where = parsed_q.where_sql parsed_params = parsed_q.params_dict + sqla_params = {} for k in parsed_params.keys(): - where = where.replace(f"{{{k}}}", f"'{parsed_params[k]}'") - return text(where) + sqla_params[f"param_{k}"] = parsed_params[k] + where = where.replace(f"{{{k}}}", f":param_{k}") + return text(where), sqla_params async def _search_users(self, _filter: str, start_index: int = 1, count: int = 100) -> tuple[list[Dict], int]: - where_clause = await self._get_where_clause_from_filter(_filter, self.user_attr_map) + where_clause, sqla_params = await self._get_where_clause_from_filter(_filter, self.user_attr_map) em_agg = text(""" JSON_ARRAYAGG(JSON_OBJECT( 'value', `user_emails`.`value`, @@ -223,14 +226,14 @@ async def _search_users(self, _filter: str, start_index: int = 1, count: int = 1 q = q.group_by(text("1,2,3,4,5,6,7,8,9")).offset(start_index - 1).limit(count) users = [] total = 0 - async for row in conn.execute(q): + async for row in conn.execute(q, **sqla_params): users.append(await _transform_user(row)) total = row.total return users, total async def _search_groups(self, _filter: str, start_index: int = 1, count: int = 100) -> tuple[list[Dict], int]: - where_clause = await self._get_where_clause_from_filter(_filter, self.group_attr_map) + where_clause, sqla_params = await self._get_where_clause_from_filter(_filter, self.group_attr_map) ct = text("count(*) OVER() as `total`") q = select([tbl.groups, text("CAST('[]' AS JSON) as members"), ct]). \ join(tbl.users_groups, tbl.groups.c.id == tbl.users_groups.c.groupId, isouter=True). \ @@ -243,7 +246,7 @@ async def _search_groups(self, _filter: str, start_index: int = 1, count: int = async with engine.acquire() as conn: groups = [] total = 0 - async for row in conn.execute(q): + async for row in conn.execute(q, **sqla_params): groups.append(await _transform_group(row)) total = row.total return groups, total @@ -423,9 +426,14 @@ async def _delete_group(self, group_id: str): return {} async def remove_users_from_group(self, user_ids: List[str], group_id: str): - user_ids_s = ",".join([f"'{uid}'" for uid in user_ids]) + user_id_conditions = [] + for user_id in user_ids: + user_id_conditions.append(tbl.users_groups.c.userId == user_id) q = delete(tbl.users_groups).where( - text(f"`users_groups`.`groupId` = '{group_id}' AND `users_groups`.`userId` IN ({user_ids_s})") + and_( + tbl.users_groups.c.groupId == group_id, + or_(*user_id_conditions) + ) ) engine = await self.get_engine() async with engine.acquire() as conn: @@ -465,14 +473,16 @@ async def search_members(self, _filter: str, group_id: str): parsed_q = SQLQuery(_filter, "users_groups", self.user_attr_map) where = parsed_q.where_sql parsed_params = parsed_q.params_dict + sqla_params = {} for k in parsed_params.keys(): - where = where.replace(f"{{{k}}}", f"'{parsed_params[k]}'") + sqla_params[f"param_{k}"] = parsed_params[k] + where = where.replace(f"{{{k}}}", f":param_{k}") q = select([tbl.users_groups]).join(tbl.users, tbl.users.c.id == tbl.users_groups.c.userId). \ where(and_(tbl.users_groups.c.groupId == group_id, text(where))) engine = await self.get_engine() async with engine.acquire() as conn: res = [] - async for row in await conn.execute(q): + async for row in await conn.execute(q, **sqla_params): res.append({"value": row.userId}) return res diff --git a/keystone_scim/store/postgresql_store.py b/keystone_scim/store/postgresql_store.py index 504c1ec..05e8004 100644 --- a/keystone_scim/store/postgresql_store.py +++ b/keystone_scim/store/postgresql_store.py @@ -10,7 +10,7 @@ from aiopg.sa import create_engine from aiopg.sa.result import RowProxy from scim2_filter_parser.queries import SQLQuery -from sqlalchemy import delete, insert, select, text, update, and_ +from sqlalchemy import delete, insert, select, text, update, and_, or_ from sqlalchemy.sql.base import ImmutableColumnCollection from sqlalchemy.sql.elements import TextClause @@ -402,9 +402,14 @@ async def _delete_group(self, group_id: str): return {} async def remove_users_from_group(self, user_ids: List[str], group_id: str): - user_ids_s = ",".join([f"'{uid}'" for uid in user_ids]) + user_id_conditions = [] + for user_id in user_ids: + user_id_conditions.append(tbl.users_groups.c.userId == user_id) q = delete(tbl.users_groups).where( - text(f"users_groups.\"groupId\" = '{group_id}' AND users_groups.\"userId\" IN ({user_ids_s})") + and_( + tbl.users_groups.c.groupId == group_id, + or_(*user_id_conditions) + ) ) engine = await self.get_engine() async with engine.acquire() as conn: @@ -438,14 +443,16 @@ async def search_members(self, _filter: str, group_id: str): parsed_q = SQLQuery(_filter, "users_groups", self.user_attr_map) where = parsed_q.where_sql parsed_params = parsed_q.params_dict + sqla_params = {} for k in parsed_params.keys(): - where = where.replace(f"{{{k}}}", f"'{parsed_params[k]}'") + sqla_params[f"param_{k}"] = parsed_params[k] + where = where.replace(f"{{{k}}}", f":param_{k}") q = select([tbl.users_groups]).join(tbl.users, tbl.users.c.id == tbl.users_groups.c.userId). \ where(and_(tbl.users_groups.c.groupId == group_id, text(where))) engine = await self.get_engine() async with engine.acquire() as conn: res = [] - async for row in await conn.execute(q): + async for row in await conn.execute(q, **sqla_params): res.append({"value": row.userId}) return res diff --git a/pyproject.toml b/pyproject.toml index d6719b3..5c6eb03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "keystone-scim" -version = "0.2.0-rc.0" +version = "0.2.1" description = "A SCIM 2.0 Provisioning API" authors = ["Yuval Herziger "] license = "MIT"