diff --git a/alws/crud/errata.py b/alws/crud/errata.py index e6f81b4e..0f661bb8 100644 --- a/alws/crud/errata.py +++ b/alws/crud/errata.py @@ -112,8 +112,13 @@ def simplify(self) -> bool: and len(self.criteria["criterion"]) == 1 and len(self.criteria["criteria"]) == 0 ): - self.parent.criteria["criterion"].append(self.criteria["criterion"].pop()) - if len(self.criteria["criteria"]) == 0 and len(self.criteria["criterion"]) == 0: + self.parent.criteria["criterion"].append( + self.criteria["criterion"].pop() + ) + if ( + len(self.criteria["criteria"]) == 0 + and len(self.criteria["criterion"]) == 0 + ): return True return False @@ -143,7 +148,8 @@ async def get_oval_xml( if only_released: query = query.filter( - models.NewErrataRecord.release_status == ErrataReleaseStatus.RELEASED + models.NewErrataRecord.release_status + == ErrataReleaseStatus.RELEASED ) records = (await db.execute(query)).scalars().all() @@ -180,7 +186,9 @@ def errata_records_to_oval( ] for albs_pkg in albs_pkgs: rhel_evra = f"{pkg.epoch}:{pkg.version}-{pkg.release}" - albs_evra = f"{albs_pkg.epoch}:{albs_pkg.version}-{albs_pkg.release}" + albs_evra = ( + f"{albs_pkg.epoch}:{albs_pkg.version}-{albs_pkg.release}" + ) arch = albs_pkg.arch if arch == "noarch": arch = pkg.arch @@ -229,7 +237,9 @@ def errata_records_to_oval( ): criterion_list = [] criteria["criterion"] = criterion_list - links_tracking.update(criterion["ref"] for criterion in criterion_list) + links_tracking.update( + criterion["ref"] for criterion in criterion_list + ) criteria_list = new_criteria_list for criteria in record.original_criteria: criteria_node = CriteriaNode(criteria, None) @@ -323,7 +333,9 @@ def errata_records_to_oval( if test.get("state_ref"): test["state_ref"] = debrand_id(test["state_ref"]) links_tracking.update([test["object_ref"], test["state_ref"]]) - oval.append_object(get_test_cls_by_tag(test["type"]).from_dict(test)) + oval.append_object( + get_test_cls_by_tag(test["type"]).from_dict(test) + ) for obj in record.original_objects: obj["id"] = debrand_id(obj["id"]) if obj["id"] in objects: @@ -337,7 +349,9 @@ def errata_records_to_oval( if get_object_cls_by_tag(obj["type"]) == RpmverifyfileObject: if obj["filepath"] == "/etc/redhat-release": obj["filepath"] = "/etc/almalinux-release" - oval.append_object(get_object_cls_by_tag(obj["type"]).from_dict(obj)) + oval.append_object( + get_object_cls_by_tag(obj["type"]).from_dict(obj) + ) for state in record.original_states: state["id"] = debrand_id(state["id"]) if state["id"] in objects: @@ -347,7 +361,9 @@ def errata_records_to_oval( if state.get("evr"): if state["evr"] in rhel_evra_mapping: if state["arch"]: - state["arch"] = "|".join(rhel_evra_mapping[state["evr"]].keys()) + state["arch"] = "|".join( + rhel_evra_mapping[state["evr"]].keys() + ) state["evr"] = rhel_evra_mapping[state["evr"]][ next(iter(rhel_evra_mapping[state["evr"]].keys())) ] @@ -368,9 +384,14 @@ def errata_records_to_oval( if var["id"] not in links_tracking: continue objects.add(var["id"]) - oval.append_object(get_variable_cls_by_tag(var["type"]).from_dict(var)) + oval.append_object( + get_variable_cls_by_tag(var["type"]).from_dict(var) + ) for obj in record.original_objects: - if obj["id"] != var["arithmetic"]["object_component"]["object_ref"]: + if ( + obj["id"] + != var["arithmetic"]["object_component"]["object_ref"] + ): continue if obj["id"] in objects: continue @@ -471,7 +492,9 @@ async def update_errata_record( else: record.title = update_record.title if record.title: - record.oval_title = get_oval_title(record.title, record.id, record.severity) + record.oval_title = get_oval_title( + record.title, record.id, record.severity + ) if update_record.description is not None: if update_record.description == record.original_description: record.description = None @@ -598,7 +621,10 @@ async def get_matching_albs_packages( items_to_insert.append(mapping) errata_package.albs_packages.append(mapping) errata_record_ids.add(errata_package.errata_record_id) - package_type = {"type": ErrataPackagesType.BUILD, "build_ids": list(build_ids)} + package_type = { + "type": ErrataPackagesType.BUILD, + "build_ids": list(build_ids), + } return items_to_insert, package_type @@ -640,7 +666,9 @@ async def create_errata_record(db: AsyncSession, errata: BaseErrataRecord): description=None, original_description=errata.description, title=None, - oval_title=get_oval_title(errata.title, alma_errata_id, errata.severity), + oval_title=get_oval_title( + errata.title, alma_errata_id, errata.severity + ), original_title=get_verbose_errata_title(errata.title, errata.severity), contact_mail=platform.contact_mail, status=errata.status, @@ -672,7 +700,9 @@ async def create_errata_record(db: AsyncSession, errata: BaseErrataRecord): db_cve = None if ref.cve: db_cve = await db.execute( - select(models.ErrataCVE).where(models.ErrataCVE.id == ref.cve.id) + select(models.ErrataCVE).where( + models.ErrataCVE.id == ref.cve.id + ) ) db_cve = db_cve.scalars().first() if db_cve is None: @@ -812,7 +842,11 @@ async def list_errata_records( options = [] if compact: options.append( - load_only(models.NewErrataRecord.id, models.NewErrataRecord.updated_date) + load_only( + models.NewErrataRecord.id, + models.NewErrataRecord.updated_date, + models.NewErrataRecord.platform_id, + ) ) else: options.extend([ @@ -831,7 +865,9 @@ def generate_query(count=False): query = select(models.NewErrataRecord).options(*options) query = query.order_by(models.NewErrataRecord.id.desc()) if errata_id: - query = query.filter(models.NewErrataRecord.id.like(f"%{errata_id}%")) + query = query.filter( + models.NewErrataRecord.id.like(f"%{errata_id}%") + ) if errata_ids: query = query.filter(models.NewErrataRecord.id.in_(errata_ids)) if title: @@ -844,15 +880,21 @@ def generate_query(count=False): if platform: query = query.filter(models.NewErrataRecord.platform_id == platform) if cve_id: - query = query.filter(models.NewErrataRecord.cves.like(f"%{cve_id}%")) + query = query.filter( + models.NewErrataRecord.cves.like(f"%{cve_id}%") + ) if status: - query = query.filter(models.NewErrataRecord.release_status == status) + query = query.filter( + models.NewErrataRecord.release_status == status + ) if page and not count: query = query.slice(10 * page - 10, 10 * page) return query return { - "total_records": (await db.execute(generate_query(count=True))).scalar(), + "total_records": ( + await db.execute(generate_query(count=True)) + ).scalar(), "records": (await db.execute(generate_query())).scalars().all(), "current_page": page, } @@ -980,10 +1022,14 @@ async def release_errata_packages( f"{platform.name.lower()}-for-{arch}-{repo_stage}-" f"rpms__{platform_version}_default" ) - default_summary = clean_errata_title(record.get_title(), severity=record.severity) + default_summary = clean_errata_title( + record.get_title(), severity=record.severity + ) pulp_record = { "id": record.id, - "updated_date": datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S"), + "updated_date": datetime.datetime.utcnow().strftime( + "%Y-%m-%d %H:%M:%S" + ), "issued_date": record.issued_date.strftime("%Y-%m-%d %H:%M:%S"), "description": record.get_description(), "fromstr": record.contact_mail, @@ -1043,9 +1089,9 @@ async def prepare_updateinfo_mapping( models.BuildTaskArtifact.href == pkg_href, ) .options( - selectinload(models.BuildTaskArtifact.build_task).selectinload( - models.BuildTask.rpm_modules - ) + selectinload( + models.BuildTaskArtifact.build_task + ).selectinload(models.BuildTask.rpm_modules) ) ) ) @@ -1057,7 +1103,8 @@ async def prepare_updateinfo_mapping( select(models.NewErrataToALBSPackage) .where( or_( - models.NewErrataToALBSPackage.albs_artifact_id == db_pkg.id, + models.NewErrataToALBSPackage.albs_artifact_id + == db_pkg.id, models.NewErrataToALBSPackage.pulp_href == pkg_href, ) ) @@ -1094,7 +1141,9 @@ def append_update_packages_in_update_records( errata_records: List[Dict[str, Any]], updateinfo_mapping: DefaultDict[ str, - List[Tuple[models.BuildTaskArtifact, dict, models.NewErrataToALBSPackage]], + List[ + Tuple[models.BuildTaskArtifact, dict, models.NewErrataToALBSPackage] + ], ], ): for record in errata_records: @@ -1345,7 +1394,9 @@ async def release_errata_record(record_id: str, platform_id: int, force: bool): query = generate_query_for_release([record_id]) query = query.filter(models.NewErrataRecord.platform_id == platform_id) db_record = await session.execute(query) - db_record: Optional[models.NewErrataRecord] = db_record.scalars().first() + db_record: Optional[models.NewErrataRecord] = ( + db_record.scalars().first() + ) if not db_record: logging.info("Record with %s id doesn't exists", record_id) return @@ -1432,7 +1483,8 @@ async def bulk_errata_records_release(records_ids: List[str]): return logging.info( - "Starting bulk errata release, the following records are" " locked: %s", + "Starting bulk errata release, the following records are" + " locked: %s", [rec.id for rec in db_records], ) for db_record in db_records: diff --git a/alws/routers/errata.py b/alws/routers/errata.py index 22ba9838..741e3a2d 100644 --- a/alws/routers/errata.py +++ b/alws/routers/errata.py @@ -130,12 +130,16 @@ async def update_errata_record( @router.get("/all/", response_model=List[errata_schema.CompactErrataRecord]) async def list_all_errata_records( db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())), + platform_id: Optional[int] = None, ): - records = await errata_crud.list_errata_records(db, compact=True) + records = await errata_crud.list_errata_records( + db, compact=True, platform=platform_id + ) return [ { "id": record.id, "updated_date": record.updated_date, + "platform_id": record.platform_id, } for record in records["records"] ] diff --git a/alws/schemas/errata_schema.py b/alws/schemas/errata_schema.py index cdcbd8b0..bf4361a8 100644 --- a/alws/schemas/errata_schema.py +++ b/alws/schemas/errata_schema.py @@ -124,6 +124,7 @@ class ErrataListResponse(BaseModel): class CompactErrataRecord(BaseModel): id: str updated_date: datetime.datetime + platform_id: int class CreateErrataResponse(BaseModel): diff --git a/tests/test_api/test_errata.py b/tests/test_api/test_errata.py index f360d62b..9797f4ab 100644 --- a/tests/test_api/test_errata.py +++ b/tests/test_api/test_errata.py @@ -1,3 +1,5 @@ +from datetime import datetime + import pytest from tests.mock_classes import BaseAsyncTestCase @@ -29,3 +31,36 @@ async def test_get_updateinfo_xml( response.status_code == self.status_codes.HTTP_200_OK and "xml version" in response.text ), f"Cannot get updateinfo.xml:\n{response.text}" + + async def test_list_errata_all_records( + self, + errata_create_payload, + ): + response = await self.make_request("get", "/api/v1/errata/all/") + errata = response.json() + assert ( + response.status_code == self.status_codes.HTTP_200_OK and errata + ), f"Cannot get errata records:\n{response.text}" + assert errata[0]['id'] == errata_create_payload["id"] + assert errata[0]['platform_id'] == errata_create_payload["platform_id"] + + async def test_list_errata_all_records_by_platform( + self, + errata_create_payload, + ): + platform_id = errata_create_payload['platform_id'] + response = await self.make_request( + "get", f"/api/v1/errata/all/?platform_id={platform_id}" + ) + assert ( + response.status_code == self.status_codes.HTTP_200_OK + and response.json() + ), f"Cannot get errata records by platform id:\n{response.text}" + + response = await self.make_request( + "get", f"/api/v1/errata/all/?platform_id={platform_id + 1}" + ) + assert ( + response.status_code == self.status_codes.HTTP_200_OK + and not response.json() + ), f"Cannot get errata records by platform id:\n{response.text}"