Skip to content

Commit

Permalink
initial commit for ingesting pytorch models
Browse files Browse the repository at this point in the history
  • Loading branch information
Philip Colangelo committed Dec 17, 2024
1 parent 496ef19 commit 7865caf
Show file tree
Hide file tree
Showing 16 changed files with 1,762 additions and 154 deletions.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="digestai",
version="1.1.0",
version="1.2.0",
description="Model analysis toolkit",
author="Philip Colangelo, Daniel Holanda",
packages=find_packages(where="src"),
Expand All @@ -25,6 +25,8 @@
"platformdirs>=4.2.2",
"pyyaml>=6.0.1",
"psutil>=6.0.0",
"torch",
"transformers",
],
classifiers=[],
entry_points={"console_scripts": ["digest = digest.main:main"]},
Expand Down
137 changes: 87 additions & 50 deletions src/digest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@

from digest.dialog import StatusDialog, InfoDialog, WarnDialog, ProgressDialog
from digest.thread import StatsThread, SimilarityThread, post_process
from digest.popup_window import PopupWindow
from digest.popup_window import PopupWindow, PopupDialog
from digest.huggingface_page import HuggingfacePage
from digest.pytorch_ingest import PyTorchIngest
from digest.multi_model_selection_page import MultiModelSelectionPage
from digest.ui.mainwindow_ui import Ui_MainWindow
from digest.modelsummary import modelSummary
Expand All @@ -49,6 +50,7 @@
from digest.model_class.digest_model import DigestModel
from digest.model_class.digest_onnx_model import DigestOnnxModel
from digest.model_class.digest_report_model import DigestReportModel
from digest.model_class.digest_pytorch_model import DigestPyTorchModel
from utils import onnx_utils

GUI_CONFIG = os.path.join(os.path.dirname(__file__), "gui_config.yaml")
Expand Down Expand Up @@ -166,7 +168,11 @@ def __init__(self, model_file: Optional[str] = None):
self.status_dialog = None
self.err_open_dialog = None
self.temp_dir = tempfile.TemporaryDirectory()
self.digest_models: Dict[str, Union[DigestOnnxModel, DigestReportModel]] = {}
self.digest_models: Dict[
str, Union[DigestOnnxModel, DigestReportModel, DigestPyTorchModel]
] = {}

self.pytorch_ingest_window: Optional[PopupDialog] = None

# QThread containers
self.model_nodes_stats_thread: Dict[str, StatsThread] = {}
Expand Down Expand Up @@ -225,6 +231,9 @@ def __init__(self, model_file: Optional[str] = None):
)
self.multimodelselection_page.model_signal.connect(self.load_model)

# Set up the PyTorch ingest page
self.pytorch_ingest: Optional[PyTorchIngest] = None

# Load model file if given as input to the executable
if model_file:
exists = os.path.exists(model_file)
Expand Down Expand Up @@ -287,7 +296,10 @@ def closeTab(self, index):

def openFile(self):
file_name, _ = QFileDialog.getOpenFileName(
self, "Open File", "", "ONNX and Report Files (*.onnx *.yaml)"
self,
"Open File",
"",
"ONNX, PyTorch, and Report Files (*.onnx *.pt *.yaml)",
)

if not file_name:
Expand Down Expand Up @@ -364,7 +376,7 @@ def update_similarity_widget(
completed_successfully: bool,
model_id: str,
most_similar: str,
png_filepath: Optional[str] = None,
png_file_path: Optional[str] = None,
df_sorted: Optional[pd.DataFrame] = None,
):
widget = None
Expand All @@ -388,20 +400,20 @@ def update_similarity_widget(
completed_successfully
and isinstance(widget, modelSummary)
and digest_model
and png_filepath
and png_file_path
):

if df_sorted is not None:
post_process(
digest_model.model_name, most_similar_list, df_sorted, png_filepath
digest_model.model_name, most_similar_list, df_sorted, png_file_path
)

widget.load_gif.stop()
widget.ui.similarityImg.clear()
# We give the image a 10% haircut to fit it more aesthetically
widget_width = widget.ui.similarityImg.width()

pixmap = QPixmap(png_filepath)
pixmap = QPixmap(png_file_path)
aspect_ratio = pixmap.width() / pixmap.height()
target_height = int(widget_width / aspect_ratio)
pixmap_scaled = pixmap.scaled(
Expand Down Expand Up @@ -436,12 +448,12 @@ def update_similarity_widget(
# Create option to click to enlarge image
widget.ui.similarityImg.mousePressEvent = (
lambda event: self.open_similarity_report(
model_id, png_filepath, most_similar_list
model_id, png_file_path, most_similar_list
)
)
# Store the similarity analysis report for this model
self.model_similarity_report[model_id] = SimilarityAnalysisReport(
png_filepath, most_similar_list
png_file_path, most_similar_list
)

widget.ui.similarityCorrelation.setText(text)
Expand All @@ -463,12 +475,12 @@ def update_similarity_widget(
):
self.ui.saveBtn.setEnabled(True)

def load_onnx(self, filepath: str):
def load_onnx(self, file_path: str):

# Ensure the filepath follows a standard formatting:
filepath = os.path.normpath(filepath)
# Ensure the file_path follows a standard formatting:
file_path = os.path.normpath(file_path)

if not os.path.exists(filepath):
if not os.path.exists(file_path):
return

# Every time an onnx is loaded we should emulate a model summary button click
Expand All @@ -477,7 +489,7 @@ def load_onnx(self, filepath: str):
# Before opening the file, check to see if it is already opened.
for index in range(self.ui.tabWidget.count()):
widget = self.ui.tabWidget.widget(index)
if isinstance(widget, modelSummary) and filepath == widget.file:
if isinstance(widget, modelSummary) and file_path == widget.file:
self.ui.tabWidget.setCurrentIndex(index)
return

Expand All @@ -486,11 +498,11 @@ def load_onnx(self, filepath: str):
progress = ProgressDialog("Loading & Optimizing ONNX Model...", 8, self)
QApplication.processEvents() # Process pending events

model = onnx_utils.load_onnx(filepath, load_external_data=False)
model = onnx_utils.load_onnx(file_path, load_external_data=False)
opt_model, opt_passed = onnx_utils.optimize_onnx_model(model)
progress.step()

basename = os.path.splitext(os.path.basename(filepath))
basename = os.path.splitext(os.path.basename(file_path))
model_name = basename[0]

# Save the model proto so we can use the Freeze Inputs feature
Expand Down Expand Up @@ -534,14 +546,14 @@ def load_onnx(self, filepath: str):
model_summary.ui.similarityCorrelation.hide()
model_summary.ui.similarityCorrelationStatic.hide()

model_summary.file = filepath
model_summary.file = file_path
model_summary.setObjectName(model_name)
model_summary.ui.modelName.setText(model_name)
model_summary.ui.modelFilename.setText(filepath)
model_summary.ui.modelFilename.setText(file_path)
model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))

digest_model.model_name = model_name
digest_model.filepath = filepath
digest_model.file_path = file_path
digest_model.model_inputs = onnx_utils.get_model_input_shapes_types(
opt_model
)
Expand Down Expand Up @@ -694,8 +706,8 @@ def load_onnx(self, filepath: str):
self.model_similarity_thread[model_id].completed_successfully.connect(
self.update_similarity_widget
)
self.model_similarity_thread[model_id].model_filepath = filepath
self.model_similarity_thread[model_id].png_filepath = os.path.join(
self.model_similarity_thread[model_id].model_file_path = file_path
self.model_similarity_thread[model_id].png_file_path = os.path.join(
png_tmp_path, f"heatmap_{model_name}.png"
)
self.model_similarity_thread[model_id].model_id = model_id
Expand All @@ -706,12 +718,12 @@ def load_onnx(self, filepath: str):
except FileNotFoundError as e:
print(f"File not found: {e.filename}")

def load_report(self, filepath: str):
def load_report(self, file_path: str):

# Ensure the filepath follows a standard formatting:
filepath = os.path.normpath(filepath)
# Ensure the file_path follows a standard formatting:
file_path = os.path.normpath(file_path)

if not os.path.exists(filepath):
if not os.path.exists(file_path):
return

# Every time a report is loaded we should emulate a model summary button click
Expand All @@ -720,7 +732,7 @@ def load_report(self, filepath: str):
# Before opening the file, check to see if it is already opened.
for index in range(self.ui.tabWidget.count()):
widget = self.ui.tabWidget.widget(index)
if isinstance(widget, modelSummary) and filepath == widget.file:
if isinstance(widget, modelSummary) and file_path == widget.file:
self.ui.tabWidget.setCurrentIndex(index)
return

Expand All @@ -729,13 +741,13 @@ def load_report(self, filepath: str):
progress = ProgressDialog("Loading Digest Report File...", 2, self)
QApplication.processEvents() # Process pending events

digest_model = DigestReportModel(filepath)
digest_model = DigestReportModel(file_path)

if not digest_model.is_valid:
progress.close()
invalid_yaml_dialog = StatusDialog(
title="Warning",
status_message=f"YAML file {filepath} is not a valid digest report",
status_message=f"YAML file {file_path} is not a valid digest report",
)
invalid_yaml_dialog.show()

Expand All @@ -758,10 +770,10 @@ def load_report(self, filepath: str):
model_summary.ui.similarityCorrelation.hide()
model_summary.ui.similarityCorrelationStatic.hide()

model_summary.file = filepath
model_summary.file = file_path
model_summary.setObjectName(digest_model.model_name)
model_summary.ui.modelName.setText(digest_model.model_name)
model_summary.ui.modelFilename.setText(filepath)
model_summary.ui.modelFilename.setText(file_path)
model_summary.ui.generatedDate.setText(datetime.now().strftime("%B %d, %Y"))

model_summary.ui.parameters.setText(format(digest_model.parameters, ","))
Expand Down Expand Up @@ -888,17 +900,38 @@ def load_report(self, filepath: str):
completed_successfully=bool(digest_model.similarity_heatmap_path),
model_id=digest_model.unique_id,
most_similar="",
png_filepath=digest_model.similarity_heatmap_path,
png_file_path=digest_model.similarity_heatmap_path,
)

progress.close()

except FileNotFoundError as e:
print(f"File not found: {e.filename}")

def load_pytorch(self, file_path: str):
    """Show the PyTorch ingest dialog for ``file_path`` and load its ONNX export.

    The path is normalized and the call returns without doing anything
    when the file does not exist. Otherwise a ``PyTorchIngest`` form is
    created, wrapped in a ``PopupDialog`` titled "PyTorch Ingest", and
    opened. If the form's ``digest_pytorch_model.onnx_file_path`` is set
    afterwards, that path is handed to ``load_onnx``.

    Args:
        file_path: Path of the PyTorch model file to ingest.
    """
    # Normalize first so the path follows a standard formatting.
    file_path = os.path.normpath(file_path)
    if not os.path.exists(file_path):
        return

    # The model name is the file name with its extension stripped.
    model_name, _extension = os.path.splitext(os.path.basename(file_path))

    ingest_form = PyTorchIngest(file_path, model_name)
    self.pytorch_ingest = ingest_form

    ingest_dialog = PopupDialog(ingest_form, "PyTorch Ingest", self)
    self.pytorch_ingest_window = ingest_dialog

    # NOTE(review): open() is expected to block until the user has completed
    # the ingest form; the form exits after a successful export, at which
    # point the ONNX path is set -- confirm against PopupDialog.open().
    ingest_dialog.open()

    exported_model = self.pytorch_ingest.digest_pytorch_model
    if exported_model.onnx_file_path:
        self.load_onnx(exported_model.onnx_file_path)

def load_model(self, file_path: str):

# Ensure the filepath follows a standard formatting:
# Ensure the file_path follows a standard formatting:
file_path = os.path.normpath(file_path)

if not os.path.exists(file_path):
Expand All @@ -910,6 +943,8 @@ def load_model(self, file_path: str):
self.load_onnx(file_path)
elif file_ext == ".yaml":
self.load_report(file_path)
elif file_ext == ".pt" or file_ext == ".pth":
self.load_pytorch(file_path)
else:
bad_ext_dialog = StatusDialog(
f"Digest does not support files with the extension {file_ext}",
Expand Down Expand Up @@ -992,30 +1027,32 @@ def save_reports(self):
)

# Save csv of node type counts
node_type_filepath = os.path.join(
node_type_file_path = os.path.join(
save_directory, f"{model_name}_node_type_counts.csv"
)
digest_model.save_node_type_counts_csv_report(node_type_filepath)
digest_model.save_node_type_counts_csv_report(node_type_file_path)

# Save (copy) the similarity image
png_file_path = self.model_similarity_thread[
digest_model.unique_id
].png_filepath
].png_file_path
png_save_path = os.path.join(save_directory, f"{model_name}_heatmap.png")
if png_file_path and os.path.exists(png_file_path):
shutil.copy(png_file_path, png_save_path)

# Save the text report
txt_report_filepath = os.path.join(save_directory, f"{model_name}_report.txt")
digest_model.save_text_report(txt_report_filepath)
txt_report_file_path = os.path.join(save_directory, f"{model_name}_report.txt")
digest_model.save_text_report(txt_report_file_path)

# Save the yaml report
yaml_report_filepath = os.path.join(save_directory, f"{model_name}_report.yaml")
digest_model.save_yaml_report(yaml_report_filepath)
yaml_report_file_path = os.path.join(
save_directory, f"{model_name}_report.yaml"
)
digest_model.save_yaml_report(yaml_report_file_path)

# Save the node list
nodes_report_filepath = os.path.join(save_directory, f"{model_name}_nodes.csv")
self.save_nodes_csv(nodes_report_filepath, False)
nodes_report_file_path = os.path.join(save_directory, f"{model_name}_nodes.csv")
self.save_nodes_csv(nodes_report_file_path, False)

self.status_dialog = StatusDialog(
f"Saved reports to: \n{os.path.abspath(save_directory)}",
Expand Down Expand Up @@ -1051,20 +1088,20 @@ def save_file_dialog(
)
return path, filter_type

def save_parameters_csv(self, filepath: str, open_dialog: bool = True):
self.save_nodes_csv(filepath, open_dialog)
def save_parameters_csv(self, file_path: str, open_dialog: bool = True):
self.save_nodes_csv(file_path, open_dialog)

def save_flops_csv(self, filepath: str, open_dialog: bool = True):
self.save_nodes_csv(filepath, open_dialog)
def save_flops_csv(self, file_path: str, open_dialog: bool = True):
self.save_nodes_csv(file_path, open_dialog)

def save_nodes_csv(self, csv_filepath: Optional[str], open_dialog: bool = True):
def save_nodes_csv(self, csv_file_path: Optional[str], open_dialog: bool = True):
if open_dialog:
csv_filepath, _ = self.save_file_dialog()
if not csv_filepath:
raise ValueError("A filepath must be given.")
csv_file_path, _ = self.save_file_dialog()
if not csv_file_path:
raise ValueError("A file_path must be given.")
current_tab = self.ui.tabWidget.currentWidget()
if isinstance(current_tab, modelSummary):
current_tab.digest_model.save_nodes_csv_report(csv_filepath)
current_tab.digest_model.save_nodes_csv_report(csv_file_path)

def save_chart(self, chart_view):
path, _ = self.save_file_dialog("Save PNG", "PNG(*.png)")
Expand Down
Loading

0 comments on commit 7865caf

Please sign in to comment.