From 6df5efd85be9a3284215d29958f46d235c4f3ab5 Mon Sep 17 00:00:00 2001
From: Dan Plamadeala
Date: Tue, 27 Feb 2024 17:23:08 +0100
Subject: [PATCH] extended the utility script

---
Review notes:
- The patch had been flattened onto a few lines; its line structure is
  restored here. Indentation of context lines is inferred and must be
  checked against the target tree before applying.
- analyze_artifacts: the known-vulnerable percentage no longer divides by zero.
- update_library_vulnerability_status: returns early when no DB pool exists.
- check_if_exists_in_maven_central_index: catches requests.RequestException
  (requests never raises subprocess.CalledProcessError) and sets a timeout.
- filter_maven_central_artifacts: tolerates output_file=None, blank input
  lines and empty input; writes only artifacts confirmed by the index.
- download_vulnerable_artifacts: dropped the leftover sys.exit(0); non-200
  responses are no longer saved as jars.
- main: download mode no longer runs after reporting missing arguments.

 util/check_for_unknown_vulns.py | 234 +++++++++++++++++++++++++++++++-
 1 file changed, 228 insertions(+), 6 deletions(-)

diff --git a/util/check_for_unknown_vulns.py b/util/check_for_unknown_vulns.py
index 60a0206..069f30e 100644
--- a/util/check_for_unknown_vulns.py
+++ b/util/check_for_unknown_vulns.py
@@ -1,6 +1,34 @@
 import argparse
 import json
 import os
+import subprocess
+
+import requests
+from jproperties import Properties
+from mysql.connector import pooling
+from pymongo import MongoClient
+from tqdm import tqdm
+
+
+def parse_database_url(db_url):
+    # db_url is in the format "jdbc:postgresql://localhost:5432/maven"
+    try:
+        url_parts = db_url.split("//")[1].split("/")
+        host_port = url_parts[0]
+        database = url_parts[1]
+
+        host = host_port.split(":")[0]
+
+        return host, database
+    except IndexError:
+        raise ValueError("Invalid database URL format")
+
+
+properties = Properties()
+with open("../config.properties", "rb") as properties_file:
+    properties.load(properties_file, "utf-8")
+
+db_host, db_name = parse_database_url(properties.get("database.url").data)
 
 
 def load_json_file(file_path):
@@ -35,6 +63,7 @@ def analyze_artifacts(maven_artifacts_file, artifacts_directory):
     usable_shaded_jar_count = 0
     detected_unknown_vulnerable_versions = 0
     detected_known_vulnerable_versions = 0
+    total_known_vulnerable_versions = 0
     maven_artifacts = load_json_file(maven_artifacts_file)
     for artifact in maven_artifacts:
         group_id = artifact["groupId"]
@@ -51,8 +80,9 @@ def analyze_artifacts(maven_artifacts_file, artifacts_directory):
                 detected_vulnerabilities = find_vulnerabilities_in_inferred_artifact(
                     inferred_artifact_file_path
                 )
+                is_vulnerable_version = version_key == "mostUsedVulnerableVersion"
+                total_known_vulnerable_versions += 1 if is_vulnerable_version else 0
                 if detected_vulnerabilities:
-                    is_vulnerable_version = version_key == "mostUsedVulnerableVersion"
                     print_artifact_info(
                         group_id,
                         artifact_id,
@@ -70,28 +100,220 @@ def analyze_artifacts(maven_artifacts_file, artifacts_directory):
         f"Total detected vulnerable versions: {detected_unknown_vulnerable_versions} ({(detected_unknown_vulnerable_versions / usable_shaded_jar_count) * 100:.2f}%)"
     )
     print(
-        f"Total detected known vulnerable versions: {detected_known_vulnerable_versions} ({(detected_known_vulnerable_versions / usable_shaded_jar_count) * 100:.2f}%)"
+        f"Total detected known vulnerable versions: {detected_known_vulnerable_versions} ({(detected_known_vulnerable_versions / total_known_vulnerable_versions * 100 if total_known_vulnerable_versions else 0.0):.2f}%)"
+    )
+
+
+def connect_to_db():
+    try:
+        connection_pool = pooling.MySQLConnectionPool(
+            pool_name="pom_resolution_pool",
+            pool_size=5,
+            host=db_host,
+            database=db_name,
+            user=properties.get("database.username").data,
+            password=properties.get("database.password").data,
+        )
+        return connection_pool
+    except Exception as e:
+        print(f"Error connecting to the database: {e}")
+        return None
+
+
+def get_all_libraries(cursor):
+    cursor.execute("SELECT id, group_id, artifact_id, version FROM libraries")
+    return cursor.fetchall()
+
+
+def connect_to_mongodb():
+    client = MongoClient("mongodb://localhost:27072/")
+    db = client.osv_db
+    return db
+
+
+def check_vulnerability_in_mongodb(db, group_id, artifact_id, version):
+    query = {
+        "affected.package.name": f"{group_id}:{artifact_id}",
+        "affected.package.ecosystem": "Maven",
+        "affected.versions": {"$in": [version]},
+    }
+    count = db.data.count_documents(query)
+    return count > 0
+
+
+def get_vulnerable_libraries_from_mongodb(db):
+    query = {
+        "affected.package.ecosystem": "Maven",
+    }
+    return db.data.find(query)
+
+
+def update_library_vulnerability_status(vulnerable_libraries, output_file_path):
+    not_found = 0
+    total_vulnerable = 0
+    pool = connect_to_db()
+    if pool is None:
+        # connect_to_db() has already reported the failure; without a
+        # connection there is no cursor to run the corpus lookups on.
+        return
+    cnx = pool.get_connection()
+    cursor = cnx.cursor(buffered=True)
+
+    with open(output_file_path, "w") as file:
+        for vuln in tqdm(vulnerable_libraries):
+            maven_affected = [
+                a for a in vuln["affected"] if a["package"]["ecosystem"] == "Maven"
+            ]
+            if not maven_affected:
+                continue
+
+            affected_package = maven_affected[0]
+            if "versions" in affected_package:
+                for version in affected_package["versions"]:
+                    total_vulnerable += 1
+                    group_id, artifact_id = affected_package["package"][
+                        "name"
+                    ].split(":")
+
+                    file.write(f"{group_id}:{artifact_id}:{version}\n")
+                    print(
+                        f"Updating {group_id}:{artifact_id} version {version} to vulnerable"
+                    )
+                    query = "SELECT id FROM libraries WHERE group_id = %s AND artifact_id = %s AND version = %s"
+                    cursor.execute(query, (group_id, artifact_id, version))
+                    library_id = cursor.fetchone()
+                    if not library_id:
+                        print(
+                            f"Library {group_id}:{artifact_id} version {version} not found in corpus"
+                        )
+                        not_found += 1
+                    # query = "UPDATE libraries SET vulnerable = 1 WHERE group_id = %s AND artifact_id = %s AND version = %s"
+                    # cursor.execute(query, (group_id, artifact_id, version))
+    cnx.commit()
+    cursor.close()
+    cnx.close()
+    print(f"Total not found: {not_found}")
+    print(f"Total vulnerable: {total_vulnerable}")
+
+
+def fill_vulnerabilities(output_file_path="vulnerable_versions.txt"):
+    mongo_db = connect_to_mongodb()
+    vulnerable_libraries = get_vulnerable_libraries_from_mongodb(mongo_db)
+    update_library_vulnerability_status(vulnerable_libraries, output_file_path)
+
+
+def check_if_exists_in_maven_central_index(group_id, artifact_id, version):
+    try:
+        response = requests.get(
+            "http://localhost:8032/lookup",
+            params={
+                "groupId": group_id,
+                "artifactId": artifact_id,
+                "version": version,
+            },
+            timeout=10,
+        )
+        return response.status_code == 200
+    except requests.RequestException as e:
+        print(f"An error occurred: {e}")
+        return None
+
+
+def filter_maven_central_artifacts(input_file, output_file="filtered_vulnerable_versions.txt"):
+    # main() forwards args.output_file, which is None when --output_file is omitted.
+    output_path = output_file or "filtered_vulnerable_versions.txt"
+    with open(input_file, "r") as file:
+        vulnerable_artifacts = file.readlines()
+
+    vulnerable_artifacts = [a.strip() for a in vulnerable_artifacts if a.strip()]
+    count_in_maven_index = 0
+
+    with open(output_path, "w") as filtered_file:
+        for artifact in tqdm(vulnerable_artifacts):
+            group_id, artifact_id, version = artifact.split(":")
+            exists = check_if_exists_in_maven_central_index(group_id, artifact_id, version)
+            # Keep only artifacts the index confirmed; None means the lookup failed.
+            if exists:
+                count_in_maven_index += 1
+                filtered_file.write(f"{group_id}:{artifact_id}:{version}\n")
+
+    total = len(vulnerable_artifacts)
+    # An empty input file would otherwise divide by zero.
+    percentage = (count_in_maven_index / total) * 100 if total else 0.0
+    print(f"Total in Maven Central index: {count_in_maven_index}")
+    print(
+        f"Percentage in Maven Central index: {percentage:.2f}%"
     )
 
 
+def download_vulnerable_artifacts(input_file, download_output_path):
+    with open(input_file, "r") as file:
+        vulnerable_artifacts = file.readlines()
+
+    vulnerable_artifacts = [a.strip() for a in vulnerable_artifacts if a.strip()]
+    for artifact in tqdm(vulnerable_artifacts):
+        group_id, artifact_id, version = artifact.split(":")
+        group_id_path = group_id.replace(".", "/")
+        download_path = os.path.join(download_output_path, group_id_path, artifact_id, version, f"{artifact_id}-{version}.jar")
+        # https://repo1.maven.org/maven2/com/daml/participant-state_2.13/2.3.13/participant-state_2.13-2.3.13.jar
+        if not os.path.exists(download_path):
+            url = f"https://repo1.maven.org/maven2/{group_id_path}/{artifact_id}/{version}/{artifact_id}-{version}.jar"
+            response = requests.get(url, timeout=60)
+            if response.status_code != 200:
+                # Do not save an error page under a .jar name.
+                print(f"Failed to download {url}: HTTP {response.status_code}")
+                continue
+            # check if the entire path exists
+            os.makedirs(os.path.dirname(download_path), exist_ok=True)
+            with open(download_path, "wb") as download_file:
+                download_file.write(response.content)
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="Analyze Maven artifacts for vulnerabilities"
     )
+    parser.add_argument("--mode", required=True, help="Mode of operation")
+    parser.add_argument(
+        "--output_file", help="Path to the vulnerable artifacts GAV output file"
+    )
+    parser.add_argument(
+        "--input_file", help="Path to the vulnerable artifacts GAV input file"
+    )
     parser.add_argument(
         "--maven_artifacts_file",
-        required=True,
         help="Path to the JSON file containing Maven artifacts information",
     )
     parser.add_argument(
         "--artifacts_directory",
-        required=True,
         help="Path to the directory containing the inferred artifacts metadata",
     )
+    parser.add_argument("--download_output_path", help="Path to the output directory")
 
     args = parser.parse_args()
-    analyze_artifacts(args.maven_artifacts_file, args.artifacts_directory)
-
+    if args.mode == "analyze_artifacts":
+        if args.maven_artifacts_file and args.artifacts_directory:
+            analyze_artifacts(args.maven_artifacts_file, args.artifacts_directory)
+        else:
+            print(
+                "Error: Both --maven_artifacts_file and --artifacts_directory are required for this mode"
+            )
+    elif args.mode == "fill_vulnerabilities":
+        fill_vulnerabilities()
+    elif args.mode == "filter_maven_central_artifacts":
+        if not args.input_file:
+            print("Error: --input_file is required for this mode")
+        else:
+            filter_maven_central_artifacts(args.input_file, args.output_file)
+    elif args.mode == "download_vulnerable_artifacts":
+        if not args.input_file or not args.download_output_path:
+            print(
+                "Error: Both --input_file and --download_output_path are required for this mode"
+            )
+        else:
+            download_vulnerable_artifacts(args.input_file, args.download_output_path)
+    else:
+        print(f"Error: Unknown mode {args.mode}")
 
 if __name__ == "__main__":
     main()