Skip to content

Commit

Permalink
More secure sqla queries
Browse files Browse the repository at this point in the history
  • Loading branch information
yuvalherziger committed Aug 25, 2022
1 parent 01e11cb commit 877f1e0
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 21 deletions.
2 changes: 1 addition & 1 deletion keystone_scim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
_red = "\033[0;31m"
_nc = "\033[0m"

VERSION = "0.2.0-rc.0"
VERSION = "0.2.1"
LOGO = """
..............
.--------------. :
Expand Down
38 changes: 24 additions & 14 deletions keystone_scim/store/mysql_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import logging
import uuid
from datetime import datetime
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import aiomysql
import pymysql.cursors
from aiomysql.sa import create_engine
from aiomysql.sa.result import RowProxy
from scim2_filter_parser.queries import SQLQuery
from sqlalchemy import delete, insert, select, text, update, and_
from sqlalchemy import delete, insert, select, text, update, and_, or_
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.sql.elements import TextClause

Expand Down Expand Up @@ -189,18 +189,21 @@ async def get_by_id(self, resource_id: str):
if self.entity_type == "groups":
return await self._get_group_by_id(resource_id)

async def _get_where_clause_from_filter(self, _filter: str, attr_map: Dict) -> Optional[TextClause]:
async def _get_where_clause_from_filter(self, _filter: str, attr_map: Dict)\
-> Tuple[Optional[TextClause], Dict]:
if not _filter:
return None
return None, {}
parsed_q = SQLQuery(_filter, self.entity_type, attr_map)
where = parsed_q.where_sql
parsed_params = parsed_q.params_dict
sqla_params = {}
for k in parsed_params.keys():
where = where.replace(f"{{{k}}}", f"'{parsed_params[k]}'")
return text(where)
sqla_params[f"param_{k}"] = parsed_params[k]
where = where.replace(f"{{{k}}}", f":param_{k}")
return text(where), sqla_params

async def _search_users(self, _filter: str, start_index: int = 1, count: int = 100) -> tuple[list[Dict], int]:
where_clause = await self._get_where_clause_from_filter(_filter, self.user_attr_map)
where_clause, sqla_params = await self._get_where_clause_from_filter(_filter, self.user_attr_map)
em_agg = text("""
JSON_ARRAYAGG(JSON_OBJECT(
'value', `user_emails`.`value`,
Expand All @@ -223,14 +226,14 @@ async def _search_users(self, _filter: str, start_index: int = 1, count: int = 1
q = q.group_by(text("1,2,3,4,5,6,7,8,9")).offset(start_index - 1).limit(count)
users = []
total = 0
async for row in conn.execute(q):
async for row in conn.execute(q, **sqla_params):
users.append(await _transform_user(row))
total = row.total

return users, total

async def _search_groups(self, _filter: str, start_index: int = 1, count: int = 100) -> tuple[list[Dict], int]:
where_clause = await self._get_where_clause_from_filter(_filter, self.group_attr_map)
where_clause, sqla_params = await self._get_where_clause_from_filter(_filter, self.group_attr_map)
ct = text("count(*) OVER() as `total`")
q = select([tbl.groups, text("CAST('[]' AS JSON) as members"), ct]). \
join(tbl.users_groups, tbl.groups.c.id == tbl.users_groups.c.groupId, isouter=True). \
Expand All @@ -243,7 +246,7 @@ async def _search_groups(self, _filter: str, start_index: int = 1, count: int =
async with engine.acquire() as conn:
groups = []
total = 0
async for row in conn.execute(q):
async for row in conn.execute(q, **sqla_params):
groups.append(await _transform_group(row))
total = row.total
return groups, total
Expand Down Expand Up @@ -423,9 +426,14 @@ async def _delete_group(self, group_id: str):
return {}

async def remove_users_from_group(self, user_ids: List[str], group_id: str):
user_ids_s = ",".join([f"'{uid}'" for uid in user_ids])
user_id_conditions = []
for user_id in user_ids:
user_id_conditions.append(tbl.users_groups.c.userId == user_id)
q = delete(tbl.users_groups).where(
text(f"`users_groups`.`groupId` = '{group_id}' AND `users_groups`.`userId` IN ({user_ids_s})")
and_(
tbl.users_groups.c.groupId == group_id,
or_(*user_id_conditions)
)
)
engine = await self.get_engine()
async with engine.acquire() as conn:
Expand Down Expand Up @@ -465,14 +473,16 @@ async def search_members(self, _filter: str, group_id: str):
parsed_q = SQLQuery(_filter, "users_groups", self.user_attr_map)
where = parsed_q.where_sql
parsed_params = parsed_q.params_dict
sqla_params = {}
for k in parsed_params.keys():
where = where.replace(f"{{{k}}}", f"'{parsed_params[k]}'")
sqla_params[f"param_{k}"] = parsed_params[k]
where = where.replace(f"{{{k}}}", f":param_{k}")
q = select([tbl.users_groups]).join(tbl.users, tbl.users.c.id == tbl.users_groups.c.userId). \
where(and_(tbl.users_groups.c.groupId == group_id, text(where)))
engine = await self.get_engine()
async with engine.acquire() as conn:
res = []
async for row in await conn.execute(q):
async for row in await conn.execute(q, **sqla_params):
res.append({"value": row.userId})
return res

Expand Down
17 changes: 12 additions & 5 deletions keystone_scim/store/postgresql_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aiopg.sa import create_engine
from aiopg.sa.result import RowProxy
from scim2_filter_parser.queries import SQLQuery
from sqlalchemy import delete, insert, select, text, update, and_
from sqlalchemy import delete, insert, select, text, update, and_, or_
from sqlalchemy.sql.base import ImmutableColumnCollection
from sqlalchemy.sql.elements import TextClause

Expand Down Expand Up @@ -402,9 +402,14 @@ async def _delete_group(self, group_id: str):
return {}

async def remove_users_from_group(self, user_ids: List[str], group_id: str):
user_ids_s = ",".join([f"'{uid}'" for uid in user_ids])
user_id_conditions = []
for user_id in user_ids:
user_id_conditions.append(tbl.users_groups.c.userId == user_id)
q = delete(tbl.users_groups).where(
text(f"users_groups.\"groupId\" = '{group_id}' AND users_groups.\"userId\" IN ({user_ids_s})")
and_(
tbl.users_groups.c.groupId == group_id,
or_(*user_id_conditions)
)
)
engine = await self.get_engine()
async with engine.acquire() as conn:
Expand Down Expand Up @@ -438,14 +443,16 @@ async def search_members(self, _filter: str, group_id: str):
parsed_q = SQLQuery(_filter, "users_groups", self.user_attr_map)
where = parsed_q.where_sql
parsed_params = parsed_q.params_dict
sqla_params = {}
for k in parsed_params.keys():
where = where.replace(f"{{{k}}}", f"'{parsed_params[k]}'")
sqla_params[f"param_{k}"] = parsed_params[k]
where = where.replace(f"{{{k}}}", f":param_{k}")
q = select([tbl.users_groups]).join(tbl.users, tbl.users.c.id == tbl.users_groups.c.userId). \
where(and_(tbl.users_groups.c.groupId == group_id, text(where)))
engine = await self.get_engine()
async with engine.acquire() as conn:
res = []
async for row in await conn.execute(q):
async for row in await conn.execute(q, **sqla_params):
res.append({"value": row.userId})
return res

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "keystone-scim"
version = "0.2.0-rc.0"
version = "0.2.1"
description = "A SCIM 2.0 Provisioning API"
authors = ["Yuval Herziger <[email protected]>"]
license = "MIT"
Expand Down

0 comments on commit 877f1e0

Please sign in to comment.