diff --git a/modules/ui/AdditionalEmbeddingsTab.py b/modules/ui.legacy/AdditionalEmbeddingsTab.py
similarity index 100%
rename from modules/ui/AdditionalEmbeddingsTab.py
rename to modules/ui.legacy/AdditionalEmbeddingsTab.py
diff --git a/modules/ui/CaptionUI.py b/modules/ui.legacy/CaptionUI.py
similarity index 100%
rename from modules/ui/CaptionUI.py
rename to modules/ui.legacy/CaptionUI.py
diff --git a/modules/ui/CloudTab.py b/modules/ui.legacy/CloudTab.py
similarity index 100%
rename from modules/ui/CloudTab.py
rename to modules/ui.legacy/CloudTab.py
diff --git a/modules/ui/ConceptTab.py b/modules/ui.legacy/ConceptTab.py
similarity index 100%
rename from modules/ui/ConceptTab.py
rename to modules/ui.legacy/ConceptTab.py
diff --git a/modules/ui/ConceptWindow.py b/modules/ui.legacy/ConceptWindow.py
similarity index 100%
rename from modules/ui/ConceptWindow.py
rename to modules/ui.legacy/ConceptWindow.py
diff --git a/modules/ui/ConfigList.py b/modules/ui.legacy/ConfigList.py
similarity index 100%
rename from modules/ui/ConfigList.py
rename to modules/ui.legacy/ConfigList.py
diff --git a/modules/ui/ConvertModelUI.py b/modules/ui.legacy/ConvertModelUI.py
similarity index 100%
rename from modules/ui/ConvertModelUI.py
rename to modules/ui.legacy/ConvertModelUI.py
diff --git a/modules/ui/GenerateCaptionsWindow.py b/modules/ui.legacy/GenerateCaptionsWindow.py
similarity index 100%
rename from modules/ui/GenerateCaptionsWindow.py
rename to modules/ui.legacy/GenerateCaptionsWindow.py
diff --git a/modules/ui/GenerateMasksWindow.py b/modules/ui.legacy/GenerateMasksWindow.py
similarity index 100%
rename from modules/ui/GenerateMasksWindow.py
rename to modules/ui.legacy/GenerateMasksWindow.py
diff --git a/modules/ui/LoraTab.py b/modules/ui.legacy/LoraTab.py
similarity index 100%
rename from modules/ui/LoraTab.py
rename to modules/ui.legacy/LoraTab.py
diff --git a/modules/ui/ModelTab.py b/modules/ui.legacy/ModelTab.py
similarity index 100%
rename from modules/ui/ModelTab.py
rename to modules/ui.legacy/ModelTab.py
diff --git a/modules/ui/OffloadingWindow.py b/modules/ui.legacy/OffloadingWindow.py
similarity index 100%
rename from modules/ui/OffloadingWindow.py
rename to modules/ui.legacy/OffloadingWindow.py
diff --git a/modules/ui/OptimizerParamsWindow.py b/modules/ui.legacy/OptimizerParamsWindow.py
similarity index 100%
rename from modules/ui/OptimizerParamsWindow.py
rename to modules/ui.legacy/OptimizerParamsWindow.py
diff --git a/modules/ui/ProfilingWindow.py b/modules/ui.legacy/ProfilingWindow.py
similarity index 100%
rename from modules/ui/ProfilingWindow.py
rename to modules/ui.legacy/ProfilingWindow.py
diff --git a/modules/ui/SampleFrame.py b/modules/ui.legacy/SampleFrame.py
similarity index 100%
rename from modules/ui/SampleFrame.py
rename to modules/ui.legacy/SampleFrame.py
diff --git a/modules/ui/SampleParamsWindow.py b/modules/ui.legacy/SampleParamsWindow.py
similarity index 100%
rename from modules/ui/SampleParamsWindow.py
rename to modules/ui.legacy/SampleParamsWindow.py
diff --git a/modules/ui/SampleWindow.py b/modules/ui.legacy/SampleWindow.py
similarity index 100%
rename from modules/ui/SampleWindow.py
rename to modules/ui.legacy/SampleWindow.py
diff --git a/modules/ui/SamplingTab.py b/modules/ui.legacy/SamplingTab.py
similarity index 100%
rename from modules/ui/SamplingTab.py
rename to modules/ui.legacy/SamplingTab.py
diff --git a/modules/ui/SchedulerParamsWindow.py b/modules/ui.legacy/SchedulerParamsWindow.py
similarity index 100%
rename from modules/ui/SchedulerParamsWindow.py
rename to modules/ui.legacy/SchedulerParamsWindow.py
diff --git a/modules/ui/TimestepDistributionWindow.py b/modules/ui.legacy/TimestepDistributionWindow.py
similarity index 100%
rename from modules/ui/TimestepDistributionWindow.py
rename to modules/ui.legacy/TimestepDistributionWindow.py
diff --git a/modules/ui/TopBar.py b/modules/ui.legacy/TopBar.py
similarity index 100%
rename from modules/ui/TopBar.py
rename to modules/ui.legacy/TopBar.py
diff --git a/modules/ui/TrainUI.py b/modules/ui.legacy/TrainUI.py
similarity index 100%
rename from modules/ui/TrainUI.py
rename to modules/ui.legacy/TrainUI.py
diff --git a/modules/ui/TrainingTab.py b/modules/ui.legacy/TrainingTab.py
similarity index 100%
rename from modules/ui/TrainingTab.py
rename to modules/ui.legacy/TrainingTab.py
diff --git a/modules/ui/VideoToolUI.py b/modules/ui.legacy/VideoToolUI.py
similarity index 100%
rename from modules/ui/VideoToolUI.py
rename to modules/ui.legacy/VideoToolUI.py
diff --git a/modules/ui/README.md b/modules/ui/README.md
new file mode 100644
index 000000000..8f582bf3b
--- /dev/null
+++ b/modules/ui/README.md
@@ -0,0 +1,132 @@
+QT6 GUI Overview
+=================
+
+
+
+## Overall Architecture
+
+The GUI has been completely re-implemented with a Model-View-Controller (MVC) architecture, for better future-proofing.
+The folder structure is the following:
+- `modules/ui/models`: OneTrainer functionalities, abstracted from GUI implementation
+- `modules/ui/controllers`: Linker classes, managing how models should be invoked, validating (complex) user inputs and orchestrating GUI behavior
+- `modules/ui/views`: `*.ui` files drawing each component, in a way which is as data-agnostic as possible
+- `modules/ui/utils`: auxiliary classes.
+
+### Models
+Model classes collect original OneTrainer functionalities, abstracting from the specific user interface.
+As models can potentially be invoked from different processes/threads/event loops, each operation modifying internal states must be thread-safe.
+
+Models subclassing `SingletonConfigModel` wrap `modules.util.config` classes, exposing a singleton interface and a thread-safe dot-notation-based read/write mechanism.
+
+Other models provide auxiliary utilities (e.g., open the browser, load files, etc.) and are mostly grouped conceptually (i.e., all file operations are handled by the same class).
+
+Thread-safe access to model objects is mediated by a global QSimpleMutex, shared by every subclass of `SingletonConfigModel`. Multiple levels of synchronization are possible (a usage sketch follows this list):
+- Each model has a `self.config` attribute which can be accessed safely with `Whatever.instance().get_state(var)` and `Whatever.instance().set_state(var, value)` (or unsafely with `Whatever.instance().config.var`)
+- Multiple variables can be read/written atomically with the `self.bulk_read()` and `self.bulk_write()` methods. These should be used to ensure that a user editing UI controls while multiple variables are read consecutively cannot leave the state inconsistent.
+- There are four context managers wrapping blocks of code in critical regions:
+ 1. `with self.critical_region_read()` and `with self.critical_region_write()` mediate access to a shared resource with an *instance-specific* reentrant read-write lock. Most, if not all, synchronizations should use these two context managers.
+ 2. `with self.critical_region()` uses a generic reentrant lock which is *instance-specific*
+ 3. `with self.critical_region_global()` uses a generic reentrant lock which is *shared across every subclass of `SingletonConfigModel`*.
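+
+A minimal usage sketch (variable names are illustrative; I assume `bulk_read()` takes a list of names and `bulk_write()` a dict, matching the description above):
+
+```python
+from modules.ui.models.StateModel import StateModel
+
+model = StateModel.instance()
+
+# Thread-safe single-variable access; dot notation resolves nested config fields.
+lr = model.get_state("learning_rate")
+model.set_state("learning_rate", 1e-4)
+
+# Atomic multi-variable access, so a concurrent UI edit cannot interleave.
+epochs, batch = model.bulk_read(["epochs", "batch_size"])
+model.bulk_write({"epochs": 100, "batch_size": 4})
+
+# Explicit critical region for a longer sequence of writes.
+with model.critical_region_write():
+    model.config.epochs = 100
+    model.config.batch_size = 4
+```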
+
+
+### Controllers
+Controller classes are finite-state machines that initialize themselves with a specific sequence of events, and then react to external events (slots/signals).
+Each controller is associated with a view (`self.ui`) and is optionally associated with a parent controller (`self.parent`), creating a hierarchy with the `OneTrainerController` at the root.
+
+At construction, each controller executes these operations:
+1. `BaseController.__init__`: initializes the view
+2. `_setup()`: sets up additional attributes (e.g., references to model classes)
+3. `_loadPresets()`: for controls that contain variable data (e.g., QComboBox), loads the list of values (typically from a `modules.util.enum` class, or from files)
+4. Connect static controls according to `self.state_ui_connections` dict: connects ui elements to `StateModel` variables bidirectionally (every time a control is changed, the `TrainConfig` is updated, and every time `stateChanged` is emitted, the control is updated)
+5. `_connectUIBehavior()`: forms static connections between signals and slots (e.g., button behaviors)
+6. `_connectInputValidation()`: associates complex validation functions (QValidators, slots, or other mechanisms) with each control (simple validations are defined in view files)
+7. Invalidation of controls connected with `update_after_connect=True`
+8. `self.__init__`: additional controller-specific initializations.
+
+The `state_ui_connections` dictionary contains pairs `{'train_config_variable': 'ui_element'}` for ease of connection, and a similar pattern is often used for other connections. This dictionary involves *only* connections with `StateModel`.
+Other models are connected to controls manually in `_connectUIBehavior()`, using a similar pattern on a local dictionary.
+Every interaction with non-GUI code (e.g., progress bar updates, training, etc.) is mediated by signals/slots which invoke model methods.
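+
+For example, `BackupController` (added later in this PR) declares, abridged:
+
+```python
+class BackupController(BaseController):
+    # TrainConfig variable -> objectName of the widget in backup.ui.
+    state_ui_connections = {
+        "backup_after": "backupSbx",
+        "backup_after_unit": "backupCmb",
+    }
+
+    def _loadPresets(self):
+        # Populate the data-agnostic QComboBox at runtime.
+        for e in TimeUnit.enabled_values():
+            self.ui.backupCmb.addItem(e.pretty_print(), userData=e)
+
+    def _connectUIBehavior(self):
+        # Manual connection; _connect() tracks it for later disconnection.
+        self._connect(self.ui.backupBtn.clicked, self.__startBackup())
+```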
+
+Controllers also have the responsibility of owning and handling additional threads. This guarantees better encapsulation and future-proofing: libraries or invocation patterns can change while the models stay untouched.
+
+### Views
+View files are created with QtCreator or QtDesigner and are expected to expose, whenever possible, data-agnostic controls (e.g., a QComboBox for data types, whose values are populated at runtime).
+
+Naming convention: each widget within a `*.ui` file is either a decoration (e.g., a static label or a spacer) with its default name (e.g. `label_42`), or is associated with a meaningful name in the form `camelCaseControlNameXxx`,
+where `Xxx` is a class identifier:
+- `Lbl`: QLabel
+- `Led`: QLineEdit
+- `Ted`: QTextEdit
+- `Cbx`: QCheckBox
+- `Sbx`: QSpinBox or QDoubleSpinBox
+- `Cmb`: QComboBox
+- `Lay`: QVBoxLayout, QHBoxLayout or QGridLayout
+- `Btn`: QPushButton or QToolButton.
+
+The only purpose of this convention is to let contributors quickly tell from the name which signals/slots a given UI element supports.
+
+Suggested development checklist:
+1. Create UI layout
+2. Assign widget attributes (name, text, size policy, tooltip, etc.)
+3. Assign buddies for labels
+4. Edit tab order
+5. Assign simple validations (e.g., QSpinBox min/max values, QLineEdit masks, etc.)
+
+Note that `*.ui` files allow simple signal-slot connections to be defined directly from the WYSIWYG editor; however, this can lead to maintenance headaches when a connection is accidentally made both in the View and in the Controller. I strongly advise connecting controls only in the `_connectUIBehavior()` and `_connectInputValidation()` methods of the Controller.
+
+Violations of the Model-View-Controller architecture:
+- The fields of the optimizer window are created dynamically from its controller. This was mostly to avoid a hard-to-maintain `.ui` file.
+
+### Utils
+Auxiliary, but QT-dependent, classes.
+
+- `FigureWidget`: Figure widget for plots and images. Can be instantiated with a toolbar (separate `MaskDrawingToolbar` class) for inspection or image editing (arbitrary tools are managed by the controller instantiating the widget).
+- `OneTrainerApplication`: Subclass of QApplication defining global signals which can be connected from any Controller
+- `WorkerPool`: Generic threaded processor executing functions on an automatically managed thread pool. Functions can be made reentrant by associating them with a name (i.e., they are executed once even when multiple calls are made, which is useful, for example, when a user attempts to scan the same folder before the previous operation has terminated).
+
+## QT6 Notions
+The following are some basic notions for useful QT6 features.
+
+Signal-slot connections: QT's interactions are asynchronous and based on message passing. Each widget exposes two types of methods:
+- Signals are fired when a particular event occurs (e.g., a QPushButton is clicked) or when explicitly `emit()`ed. Some signals are associated with data with known type (e.g., `QLineEdit.textChanged` also transmits the text in a string parameter).
+- Slots are functions receiving a signal and processing its data. For efficiency reasons, they should be annotated with a `@Slot(types)` decorator, but arbitrary Python functions can act as slots, as long as their parameters match the signal.
+- The `@Slot` decorator does not accept the idiom `type|None`; either use plain (undecorated) functions, or decorate them with `@Slot(object)` for nullable parameters.
+
+A signal-slot connection can be created (`connect()`) and destroyed (`disconnect()`) dynamically.
+Every time a signal is emitted, all the slots connected to it are executed.
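+
+A minimal illustration:
+
+```python
+from PySide6.QtCore import Slot
+from PySide6.QtWidgets import QApplication, QLabel, QLineEdit
+
+app = QApplication([])
+line_edit, label = QLineEdit(), QLabel()
+
+@Slot(str)
+def mirror(text):  # Receives the str payload of textChanged.
+    label.setText(text)
+
+conn = line_edit.textChanged.connect(mirror)  # Create the connection...
+line_edit.disconnect(conn)                    # ...and destroy it again.
+```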
+
+Important considerations:
+- While slots can also be anonymous lambdas, signals must be class members, so a QWidget subclass is required whenever new signals are needed (see the sketch after this list).
+- If a slot modifies a UI element, a new signal may be emitted, potentially causing an infinite signal-slot loop. To avoid such cases, a slot should invoke `widget.blockSignals(True)` before changing the widget's value (and `blockSignals(False)` afterwards).
+- QtCreator/QtDesigner allow signals and slots with matching signatures to be connected directly from the UI editor (e.g., `QLineEdit.textChanged(str)` and `QLabel.setText(str)` will automatically copy the text from the line edit to the label). This is convenient, but there is the risk of forgetting to connect something, or connecting it twice (once in the UI editor and again in Python code).
+- The order in which slots are executed is by default FIFO. This can be a source of bugs if code relies on slots being fired in a specific order.
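+
+A sketch of the first two points: a QWidget subclass exposing a new signal, and a slot that blocks echo signals while it writes:
+
+```python
+from PySide6.QtCore import Signal, Slot
+from PySide6.QtWidgets import QApplication, QSpinBox
+
+class StepSpinBox(QSpinBox):
+    stepCommitted = Signal(int)  # New signals must be class attributes of a QObject subclass.
+
+app = QApplication([])
+a, b = StepSpinBox(), StepSpinBox()
+
+@Slot(int)
+def mirror(value):
+    b.blockSignals(True)   # Prevent b.valueChanged from firing while we write...
+    b.setValue(value)
+    b.blockSignals(False)  # ...then re-enable signals.
+
+a.valueChanged.connect(mirror)
+```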
+
+Buddies: Events involving QLabels can be redirected to different controls (e.g., clicking on a label may activate a text box on its right), to improve the user experience.
+Buddies can be associated statically in `*.ui` files, or associated programmatically (e.g., when a label is created from python code).
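+
+Programmatically, this is a one-liner:
+
+```python
+from PySide6.QtWidgets import QApplication, QLabel, QLineEdit
+
+app = QApplication([])
+edit = QLineEdit()
+label = QLabel("&Learning rate")  # '&' defines the Alt+L accelerator.
+label.setBuddy(edit)              # Activating the label focuses the line edit.
+```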
+
+Widget promotion: Widgets can be subclassed to provide additional functionality without losing the WYSIWYG editor. It is sufficient to declare the widget as belonging to the subclass in the `.ui` file and to register the promotion at runtime, as sketched below.
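+
+The runtime half of a promotion, sketched with `QUiLoader` (assuming `FigureWidget` is the class the widget was promoted to in the `.ui` file; the path is a placeholder):
+
+```python
+from PySide6.QtUiTools import QUiLoader
+
+from modules.ui.utils.FigureWidget import FigureWidget
+
+loader = QUiLoader()
+loader.registerCustomWidget(FigureWidget)  # Register the promoted class before loading.
+ui = loader.load("some_view.ui")           # Promoted widgets are now instantiated as FigureWidget.
+```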
+
+Text masks and validators: Invalid QLineEdit input can be rejected automatically with either of two mechanisms:
+- [Masks](https://doc.qt.io/qtforpython-6/PySide6/QtWidgets/QLineEdit.html#PySide6.QtWidgets.QLineEdit.inputMask): force format adherence (e.g., imposing a `hh:mm:ss` format for times, or `#hhhhhh` for RGB colors) by altering the text as it is edited
+- Instances of QValidator: prevent the control from emitting `returnPressed` and `editingFinished` signals as long as the entered text does not pass the checks, and optionally expose methods to correct invalid text (default QValidators, such as QRegularExpressionValidator, use these additional methods to reject invalid characters as they are typed); both mechanisms are sketched below.
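+
+The sketch assumes PySide6; the scientific-notation regex mirrors `BaseController._connectScientificNotation()`:
+
+```python
+from PySide6.QtGui import QRegularExpressionValidator
+from PySide6.QtWidgets import QApplication, QLineEdit
+
+app = QApplication([])
+
+time_edit = QLineEdit()
+time_edit.setInputMask("99:99:99")  # Mask: only digits can be typed, colons are fixed.
+
+num_edit = QLineEdit()  # Validator: editingFinished fires only once the text is valid.
+num_edit.setValidator(QRegularExpressionValidator(r"[+-]?\d+\.?\d*([eE][+-]?\d+)?"))
+```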
+
+[Localization](https://doc.qt.io/qt-6/localization.html): Each string defined in `*.ui` files, as well as each string processed by QTranslator, `tr()` or `QCoreApplication.translate()`, can be automatically extracted into a `.ts` XML file by the `lupdate` tool, translated by native-language contributors, and loaded at runtime.
+Since `lupdate` is a static analyzer, it is important that each string can be inspected from the source file (i.e., `tr("A static string")` will be translatable, `my_var = "A not-so-static string"; tr(my_var)` will not).
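+
+For example:
+
+```python
+from PySide6.QtCore import QCoreApplication
+
+# Statically analyzable: lupdate can extract this string for translators.
+title = QCoreApplication.translate("main_window", "Start Training")
+```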
+
+## Concurrent Execution Model
+The application uses multiple approaches for concurrent execution.
+- QT6 objects implicitly use the internal `QThreadPool` model. This is transparent from the programmer's perspective, but using only Signals and Slots for every non-trivial communication mechanism is important to prevent unintended behavior arising from this execution model.
+- The `ImageModel` and `BulkModel` rely on the MapReduce paradigm and are therefore implemented with a standard `multiprocessing.Pool.map` approach. Note that since this internally relies on `pickle`, not every function can run on it (e.g., lambdas and locally defined functions are not picklable).
+- `WorkerPool` offers three execution mechanisms, all exposing the same Signal-Slot-based interface:
+ 1. Anonymous `QRunnable` functions automatically handled by `QThreadPool`, which can be enqueued arbitrarily many times.
+ 2. Named `QRunnable` functions automatically handled by `QThreadPool`, which are reentrant based on a name (if a `QRunnable` with the same name is already running, the new worker is not enqueued).
+  3. Traditional `threading.Thread` functions, launched manually. This addresses two limitations of `QRunnable` (exceptions reaching the underlying C++ code can crash the application, and there is no `join()`), at the expense of automatic load balancing (see the sketch below).
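+
+A sketch of how the controllers in this PR drive `WorkerPool` (`task_fn`, `on_done` and the `"scan_folder"` name are placeholders):
+
+```python
+from modules.ui.utils.WorkerPool import WorkerPool
+
+def task_fn(progress_fn=None):  # Injected because inject_progress_callback=True.
+    ...
+
+def on_done():  # Placeholder slot.
+    pass
+
+worker, name = WorkerPool.instance().createNamed(
+    task_fn, "scan_folder",        # Name-based reentrancy key.
+    poolless=True,                 # Mechanism 3: a plain threading.Thread.
+    inject_progress_callback=True)
+if worker is not None:             # None when a "scan_folder" worker is already running.
+    worker.connectCallbacks(init_fn=on_done, result_fn=None,
+                            finished_fn=on_done, errored_fn=on_done)
+    WorkerPool.instance().start(name)
+```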
+
+## Decisions and Caveats
+- Since the original OneTrainer code was strongly coupled with the user interface, many model classes were rewritten from scratch, with a high chance of introducing bugs.
+- Enums in `modules/util/enum` have been extended with methods for GUI pretty-printing (`modules.util.enum.BaseEnum.BaseEnum` class), without altering their existing functionality
+- I have more or less arbitrarily decided that strings should all be translated with `QCoreApplication.translate()`, because it groups them by context (e.g. `QCoreApplication.translate(context="model_tab", sourceText="Data Type")`), allowing explicit disambiguation every time, and providing translators with a somewhat ordered XML file (every string with the same context will be close together).
+- At the moment Enum values are non-translatable, because pretty printing often relies on string manipulation.
+- Signal-slot connections are wrapped by `BaseController._connect()` to easily manage reconnections of dynamic widgets, and the "low level" methods should not be called directly.
+- The application exposes global signals (e.g., `modelChanged`, `openedConcept(int)`, etc.), which are used to guarantee data consistency across all UI elements, by letting slots handle updates. This should be cleaner than asking the caller to modify UI elements other than its own.
+- For the time being, `modules.ui.models` classes simply handle the backend functionalities that were implemented in the old UI. In the future it may be reasonable to merge them with `modules.util.config` into thread-safe global states.
diff --git a/modules/ui/controllers/BaseController.py b/modules/ui/controllers/BaseController.py
new file mode 100644
index 000000000..d00b308aa
--- /dev/null
+++ b/modules/ui/controllers/BaseController.py
@@ -0,0 +1,320 @@
+import functools
+import os
+import re
+import webbrowser
+
+from modules.ui.models.StateModel import StateModel
+
+import PySide6.QtCore as QtC
+import PySide6.QtGui as QtG
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import Qt, Slot
+from showinfm import show_in_file_manager
+
+
+# Abstract controller with some utility methods. Each controller is a Finite-State Machine executing:
+# super().__init__ -> _setup -> _loadPresets -> _connectStateUI -> _connectUIBehavior -> _connectInputValidation -> _invalidateUI -> self.__init__
+# After it is initialized, a controller reacts to external signals by using the slots connected with self._connect(), possibly with some helper methods.
+# For legibility, the methods are grouped into: ###FSM###, ###Reactions###, ###Utils###
+class BaseController:
+ state_ui_connections = {} # Class attribute, but it will be overwritten by every subclass.
+
+ def __init__(self, loader, ui_file, name=None, parent=None, invalidate_once=True, **kwargs):
+ self.loader = loader
+ self.parent = parent
+ self.ui = loader.load(ui_file, parentWidget=parent.ui if parent is not None else None)
+ self.name = name
+
+ self.connections = {}
+ self.invalidation_callbacks = []
+
+ self._setup()
+ self._loadPresets()
+ self._connectStateUI(self.state_ui_connections, StateModel.instance(), signal=QtW.QApplication.instance().stateChanged, update_after_connect=True, **kwargs)
+ self._connectUIBehavior()
+ self._connectInputValidation()
+ if invalidate_once:
+ self._connect(QtW.QApplication.instance().initialized, lambda: self._invalidateUI())
+ else:
+ self._invalidateUI()
+
+ ###FSM###
+
+ # Override this method to initialize auxiliary attributes of each controller.
+ def _setup(self):
+ pass
+
+ # Override this method to load preset values for each control.
+ def _loadPresets(self):
+ pass
+
+ # DO NOT override this method. It connects UI elements to SingletonConfigModel internal variables.
+ # It will be automatically called for StateModel, but you should call it manually for other models.
+ def _connectStateUI(self, connection_dict, model, signal=None, update_after_connect=False, group="global", **kwargs):
+ for var, ui_names in connection_dict.items():
+ if len(kwargs) > 0:
+ var = var.format(**kwargs)
+
+ if isinstance(ui_names, str):
+ ui_names = [ui_names]
+ for ui_name in ui_names:
+ ui_elem = self.ui.findChild(QtC.QObject, ui_name)
+ if ui_elem is None:
+ self._log("error", f"ERROR: {ui_name} not found.")
+ else:
+ if isinstance(ui_elem, QtW.QCheckBox):
+ self._connect(ui_elem.stateChanged, self.__readCbx(ui_elem, var, model), group)
+ elif isinstance(ui_elem, QtW.QComboBox):
+ self._connect(ui_elem.activated, self.__readCbm(ui_elem, var, model), group)
+ elif isinstance(ui_elem, (QtW.QSpinBox, QtW.QDoubleSpinBox)):
+ self._connect(ui_elem.valueChanged, self.__readSbx(ui_elem, var, model), group)
+ elif isinstance(ui_elem, QtW.QLineEdit):
+ self._connect(ui_elem.editingFinished, self.__readLed(ui_elem, var, model), group)
+
+ callback = functools.partial(BaseController._writeControl, ui_elem, var, model)
+ if signal is not None:
+ self._connect(signal, callback)
+
+ if update_after_connect:
+ self.invalidation_callbacks.append((callback, None))
+
+ # Override this method to connect signals and slots intended for visual behavior (e.g., enable/disable controls).
+ def _connectUIBehavior(self):
+ pass
+
+ # Override this method to handle complex field validation OTHER than the automatic validations defined in *.ui files.
+ def _connectInputValidation(self):
+ pass
+
+ # DO NOT override this method. It triggers the UI updates queued by _connect(), at the end of super().__init__.
+ def _invalidateUI(self):
+ for fn, *args in self.invalidation_callbacks:
+ if len(args) > 0 and args[0] is not None:
+ fn(*args)
+ else:
+ fn()
+
+ ###Reactions###
+
+ # Connects a signal to a slot, possibly segregating it into a named category (for selectively disconnecting it later).
+ # If update_after_connect is true, notifies the controller that the slot must be fired at the end of __init__. initial_args is a list of values to be passed during this initial firing.
+ def _connect(self, signal_list, slot, key="global", update_after_connect=False, initial_args=None):
+ if not isinstance(signal_list, list):
+ signal_list = [signal_list]
+
+ for signal in signal_list:
+ c = signal.connect(slot)
+ if key not in self.connections:
+ self.connections[key] = []
+ self.connections[key].append(c)
+
+ # Schedule every update to be executed at the end of __init__
+ if update_after_connect:
+ if initial_args is None:
+ self.invalidation_callbacks.append((slot, None))
+ else:
+ self.invalidation_callbacks.append((slot, *initial_args))
+
+ # Disconnects all the UI connections.
+ def _disconnectAll(self):
+ for v in self.connections.values():
+ for c in v:
+ self.ui.disconnect(c)
+
+ self.connections = {}
+
+ # Selectively disconnects only the connections belonging to a specific key (e.g., concept indexes).
+ def _disconnectGroup(self, key):
+ if key in self.connections:
+ for c in self.connections[key]:
+ self.ui.disconnect(c)
+ del self.connections[key]
+
+ def _updateProgress(self, elem):
+ @Slot(dict)
+ def f(data):
+ if "value" in data and "max_value" in data:
+ if isinstance(elem, QtW.QProgressBar):
+ elem.setMaximum(data["max_value"])
+ elem.setValue(data["value"])
+ elif isinstance(elem, QtW.QLabel):
+ val = int(data["value"] / data["max_value"] * 100) if data["max_value"] > 0 else 0  # Multiply before truncating, otherwise the percentage collapses to 0.
+
+ elem.setText(f"{val}% ({data['value']}/{data['max_value']})")
+ return f
+
+ ###Utils###
+
+ # Force a QLineEdit to accept only scientific notation values, bounded by the provided parameters.
+ # For improved user experience, if the values are non-negative, the regex immediately rejects a minus sign, for every other range, checks are performed after editing is finished.
+ def _connectScientificNotation(self, edit_box, min=None, max=None, inf=False, neg_inf=False, include_min=True, include_max=True):
+ regex = r"\d+\.?\d*([eE][+-]?\d+)?" # Positive numbers.
+
+ if min is None or neg_inf or min < 0.0: # Allow negative numbers.
+ regex = r"[+-]?" + regex
+ if inf: # Allow positive infinity.
+ regex = r"inf|" + regex
+ if neg_inf: # Allow negative infinity.
+ regex = r"-inf|" + regex
+
+ edit_box.setValidator(QtG.QRegularExpressionValidator(regex, self.ui))
+
+ if (min is not None and not neg_inf) or (max is not None and not inf):
+ # Value capping after editing is finished (i.e., the user presses Enter, Tab or the QLineEdit looses focus).
+ self._connect(edit_box.editingFinished, self.__capValues(edit_box, min, max, include_min, include_max))
+
+ @staticmethod
+ def __capValues(elem, min=None, max=None, include_min=True, include_max=True):
+ @Slot()
+ def f():
+ val = float(elem.text())
+ if min is not None and include_min and val < min:
+ elem.setText(str(min))
+ if min is not None and not include_min and val <= min:
+ elem.setText(str(min + 0.01)) # Adding arbitrary delta.
+ if max is not None and not include_max and val >= max:
+ elem.setText(str(max - 0.01)) # Subtracting arbitrary delta.
+ if max is not None and include_max and val > max:
+ elem.setText(str(max))
+ return f
+
+ # Opens a file dialog window when tool_button is pressed, then populates edit_box with the returned value.
+ # Filters for file extensions follow QT6 syntax.
+ def _connectFileDialog(self, tool_button, edit_box, is_dir=False, save=False, title=None, filters=None):
+ def f(elem):
+ diag = QtW.QFileDialog()
+
+ if is_dir:
+ dir = None
+ if os.path.isdir(elem.text()):
+ dir = elem.text()
+ txt = diag.getExistingDirectory(parent=None, caption=title, dir=dir)
+ if txt != "":
+ elem.setText(self._removeWorkingDir(txt))
+ elem.editingFinished.emit()
+ else:
+ file = None
+ if os.path.exists(elem.text()):
+ file = self._removeWorkingDir(elem.text())
+
+ if save:
+ txt, flt = diag.getSaveFileName(parent=None, caption=title, dir=file, filter=filters)
+ if txt != "":
+ elem.setText(self._removeWorkingDir(self._appendExtension(txt, flt)))
+ elem.editingFinished.emit()
+ else:
+ txt, _ = diag.getOpenFileName(parent=None, caption=title, dir=file, filter=filters)
+ if txt != "":
+ elem.setText(self._removeWorkingDir(txt))
+ elem.editingFinished.emit()
+
+ self._connect(tool_button.clicked, functools.partial(f, edit_box))
+
+ # Log a message with the given severity.
+ def _log(self, severity, message):
+ # TODO: if you prefer a GUI text area, print on it instead: https://stackoverflow.com/questions/24469662/how-to-redirect-logger-output-into-pyqt-text-widget
+ # In that case it is important to register a global logger widget (e.g. on a window with different tabs for each severity level)
+ # For high severity, maybe an alertbox can also be opened automatically
+ StateModel.instance().log(severity, message)
+
+
+ # Open a subwindow.
+ def _openWindow(self, controller, fixed_size=False):
+ if fixed_size:
+ controller.ui.setWindowFlag(Qt.WindowCloseButtonHint)
+ controller.ui.setWindowFlag(Qt.WindowMaximizeButtonHint, on=False)
+ controller.ui.setFixedSize(controller.ui.size())
+ controller.ui.show()
+
+ # Open an alert window. Remember to translate the messages.
+ def _openAlert(self, title, message, type="about", buttons=QtW.QMessageBox.StandardButton.Ok):
+ wnd = None
+ if type == "about":
+ QtW.QMessageBox.about(self.ui, title, message) # About has no buttons nor return values.
+ elif type == "critical":
+ wnd = QtW.QMessageBox.critical(self.ui, title, message, buttons=buttons)
+ elif type == "information":
+ wnd = QtW.QMessageBox.information(self.ui, title, message, buttons=buttons)
+ elif type == "question":
+ wnd = QtW.QMessageBox.question(self.ui, title, message, buttons=buttons)
+ elif type == "warning":
+ wnd = QtW.QMessageBox.warning(self.ui, title, message, buttons=buttons)
+
+ return wnd
+
+ # Open an URL in the default web browser.
+ def _openUrl(self, url):
+ webbrowser.open(url, new=0, autoraise=False)
+
+ # Open a directory in the OS' file browser.
+ def _browse(self, dir):
+ if os.path.isdir(dir):
+ show_in_file_manager(dir)
+
+
+ def _appendExtension(self, file, filter):
+ patterns = filter.split("(")[1].split(")")[0].split(", ")
+ for p in patterns:
+ if re.match(p.replace(".", "\\.").replace("*", ".*"), file): # If the file already has a valid extension, return it as is.
+ return file
+
+ if "*" not in patterns[0]: # The pattern is a fixed filename, returning it regardless of the user selected name.
+ return patterns[0] # TODO: maybe returning folder/patterns[0] is more reasonable? In original code there is: path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x (removes file and returns base folder instead)
+ else:
+ return "{}.{}".format(file, patterns[0].split("*.")[1]) # Append the first valid extension to file.
+
+ # These methods cannot use directly lambdas, because variable names would be reassigned within the loop.
+ @staticmethod
+ def __readCbx(ui_elem, var, model):
+ return lambda: model.set_state(var, ui_elem.isChecked())
+
+ @staticmethod
+ def __readCbm(ui_elem, var, model):
+ return lambda: model.set_state(var, ui_elem.currentData())
+
+ @staticmethod
+ def __readSbx(ui_elem, var, model):
+ return lambda x: model.set_state(var, x)
+
+ @staticmethod
+ def __readLed(ui_elem, var, model):
+ return lambda: model.set_state(var, ui_elem.text())
+
+ @staticmethod
+ def _writeControl(ui_elem, var, model, *args): # Discard possible signal arguments.
+ ui_elem.blockSignals(True)
+ val = model.get_state(var)
+ if val is not None:
+ if isinstance(ui_elem, QtW.QCheckBox):
+ ui_elem.setChecked(val)
+ elif isinstance(ui_elem, QtW.QComboBox):
+ idx = ui_elem.findData(val)
+ if idx != -1:
+ ui_elem.setCurrentIndex(idx)
+ elif isinstance(ui_elem, (QtW.QSpinBox, QtW.QDoubleSpinBox)):
+ ui_elem.setValue(float(val) if isinstance(ui_elem, QtW.QDoubleSpinBox) else int(val))  # QSpinBox.setValue() expects an int.
+ elif isinstance(ui_elem, QtW.QLineEdit):
+ ui_elem.setText(str(val))
+ ui_elem.blockSignals(False)
+
+ def _removeWorkingDir(self, txt):
+ cwd = os.getcwd()
+ if txt.startswith(cwd):
+ out = txt[len(cwd) + 1:]
+ if out == "":
+ out = "."
+ return out # Remove working directory and trailing slash.
+ else:
+ return txt
+
+ def _appendWidget(self, list_widget, controller, self_delete_fn=None, self_clone_fn=None):
+ item = QtW.QListWidgetItem(list_widget)
+ item.setSizeHint(controller.ui.size())
+ list_widget.addItem(item)
+ list_widget.setItemWidget(item, controller.ui)
+
+ if self_delete_fn is not None:
+ self._connect(controller.ui.deleteBtn.clicked, self_delete_fn)
+
+ if self_clone_fn is not None:
+ self._connect(controller.ui.cloneBtn.clicked, self_clone_fn)
diff --git a/modules/ui/controllers/OneTrainerController.py b/modules/ui/controllers/OneTrainerController.py
new file mode 100644
index 000000000..7163fea14
--- /dev/null
+++ b/modules/ui/controllers/OneTrainerController.py
@@ -0,0 +1,305 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.tabs.AdditionalEmbeddingsController import AdditionalEmbeddingsController
+from modules.ui.controllers.tabs.BackupController import BackupController
+from modules.ui.controllers.tabs.CloudController import CloudController
+from modules.ui.controllers.tabs.ConceptsController import ConceptsController
+from modules.ui.controllers.tabs.DataController import DataController
+from modules.ui.controllers.tabs.EmbeddingsController import EmbeddingsController
+from modules.ui.controllers.tabs.GeneralController import GeneralController
+from modules.ui.controllers.tabs.LoraController import LoraController
+from modules.ui.controllers.tabs.ModelController import ModelController
+from modules.ui.controllers.tabs.SamplingController import SamplingController
+from modules.ui.controllers.tabs.ToolsController import ToolsController
+from modules.ui.controllers.tabs.TrainingController import TrainingController
+from modules.ui.controllers.windows.SaveController import SaveController
+from modules.ui.models.BulkCaptionModel import BulkCaptionModel
+from modules.ui.models.BulkImageModel import BulkImageModel
+from modules.ui.models.CaptionModel import CaptionModel
+from modules.ui.models.MaskModel import MaskModel
+from modules.ui.models.StateModel import StateModel
+from modules.ui.models.TrainingModel import TrainingModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.ModelFlags import ModelFlags
+from modules.util.enum.ModelType import ModelType
+from modules.util.enum.TrainingMethod import TrainingMethod
+
+import PySide6.QtGui as QtG
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+# Main window.
+class OneTrainerController(BaseController):
+ state_ui_connections = {
+ "model_type": "modelTypeCmb",
+ "training_method": "trainingTypeCmb"
+ }
+
+ def __init__(self, loader):
+ super().__init__(loader, "modules/ui/views/windows/onetrainer.ui", name="OneTrainer", parent=None)
+
+ QtW.QApplication.instance().initialized.emit()
+
+ ###FSM###
+
+ def _setup(self):
+ self.save_window = SaveController(self.loader, parent=self)
+ self.children = {}
+ self.__createTabs()
+ self.training = False
+
+ # Non-editable QComboBoxes do not honor the maxVisibleItems property, unless the following style is enforced.
+ self.ui.configCmb.setStyleSheet("QComboBox { combobox-popup: 0; }")
+
+ self.__enableControls("enabled")()
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.wikiBtn.clicked, lambda: self._openUrl("https://github.com/Nerogar/OneTrainer/wiki"))
+ self._connect(self.ui.saveConfigBtn.clicked, lambda: self._openWindow(self.save_window, fixed_size=True))
+ self._connect(self.ui.exportBtn.clicked, lambda: self.__exportConfig())
+ self._connect(self.ui.startBtn.clicked, self.__toggleTrain())
+ self._connect(self.ui.debugBtn.clicked, self.__startDebug())
+ self._connect(self.ui.tensorboardBtn.clicked, self.__openTensorboard())
+
+ self._connect([self.ui.trainingTypeCmb.activated, self.ui.modelTypeCmb.activated, QtW.QApplication.instance().stateChanged],
+ self.__changeModel(), update_after_connect=True)
+
+ self._connect(self.ui.configCmb.activated, lambda idx: self.__loadConfig(self.ui.configCmb.currentData(), idx))
+
+ self._connect([self.ui.modelTypeCmb.activated, QtW.QApplication.instance().stateChanged],
+ self.__updateModel(), update_after_connect=True)
+
+ self._connect(QtW.QApplication.instance().stateChanged, self.__updateConfigs(), update_after_connect=True)
+ self._connect(QtW.QApplication.instance().savedConfig, self.__updateSelectedConfig())
+
+ self._connect(QtW.QApplication.instance().aboutToQuit, self.__onQuit())
+
+ self.__loadConfig("training_presets/#.json") # Load last config.
+ QtW.QApplication.instance().stateChanged.emit()
+
+ def _loadPresets(self):
+ for e in ModelType.enabled_values(context="main_window"):
+ self.ui.modelTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __openTensorboard(self):
+ @Slot()
+ def f():
+ self._openUrl("http://localhost:" + str(StateModel.instance().get_state("tensorboard_port")))
+ return f
+
+ def __startDebug(self):
+ @Slot()
+ def f():
+ diag = QtW.QFileDialog()
+ txt, _ = diag.getSaveFileName(parent=None,
+ dir="OneTrainer_debug_report.zip",
+ caption=QCA.translate("main_window", "Save Debug Package"),
+ filter=QCA.translate("filetype_filters", "Zip (*.zip)"))
+ if txt != "":
+ worker, name = WorkerPool.instance().createNamed(self.__generate_debug_package(txt), "generate_debug", poolless=True, inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableDebugControls(False), result_fn=None,
+ finished_fn=self.__enableDebugControls(True),
+ errored_fn=self.__enableDebugControls(True),
+ progress_fn=self.__updateStatus())
+ WorkerPool.instance().start(name)
+ return f
+
+ def __enableDebugControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.debugBtn.setEnabled(enabled)
+ return f
+
+ def __updateSelectedConfig(self):
+ @Slot(str)
+ def f(config):
+ self.ui.configCmb.setCurrentText(config)
+ return f
+
+ def __updateConfigs(self):
+ @Slot()
+ def f():
+ configs = StateModel.instance().load_available_config_names("training_presets")
+ self.ui.configCmb.clear()
+ self.save_window.ui.configCmb.clear()
+ for k, v in configs:
+ self.ui.configCmb.addItem(k, userData=v)
+ if not k.startswith("#"):
+ self.save_window.ui.configCmb.addItem(k, userData=v)
+ return f
+
+ def __updateModel(self):
+ @Slot()
+ def f():
+ flags = ModelFlags.getFlags(self.ui.modelTypeCmb.currentData(), self.ui.trainingTypeCmb.currentData())
+
+ old_training_type = self.ui.trainingTypeCmb.currentData()
+
+ self.ui.trainingTypeCmb.clear()
+
+ self.ui.trainingTypeCmb.addItem(QCA.translate("training_method", "Fine Tune"), userData=TrainingMethod.FINE_TUNE)
+ self.ui.trainingTypeCmb.addItem(QCA.translate("training_method", "LoRA"), userData=TrainingMethod.LORA)
+
+ if ModelFlags.CAN_TRAIN_EMBEDDING in flags:
+ self.ui.trainingTypeCmb.addItem(QCA.translate("training_method", "Embedding"), userData=TrainingMethod.EMBEDDING)
+ if ModelFlags.CAN_FINE_TUNE_VAE in flags:
+ self.ui.trainingTypeCmb.addItem(QCA.translate("training_method", "Fine Tune VAE"), userData=TrainingMethod.FINE_TUNE_VAE)
+
+ if old_training_type is not None:
+ self.ui.trainingTypeCmb.setCurrentIndex(self.ui.trainingTypeCmb.findData(old_training_type))
+ self.ui.trainingTypeCmb.activated.emit(self.ui.trainingTypeCmb.findData(old_training_type))
+ else:
+ old_training_type = StateModel.instance().get_state("training_method")
+ if old_training_type is not None:
+ self.ui.trainingTypeCmb.setCurrentIndex(self.ui.trainingTypeCmb.findData(old_training_type))
+ else:
+ self.ui.trainingTypeCmb.activated.emit(0)
+
+ return f
+
+ def __enableControls(self, state):
+ @Slot()
+ def f():
+ if state == "enabled": # Startup and successful termination.
+ self.training = False
+ self.ui.startBtn.setEnabled(True)
+ self.ui.startBtn.setText(QCA.translate("main_window", "Start Training"))
+ self.ui.startBtn.setPalette(self.ui.palette())
+ self.ui.stepPrg.setValue(0)
+ self.ui.epochPrg.setValue(0)
+ self.ui.etaLbl.setText("")
+ elif state == "running":
+ self.training = True
+ self.ui.startBtn.setEnabled(True)
+ self.ui.startBtn.setText(QCA.translate("main_window", "Stop Training"))
+ self.ui.startBtn.setPalette(QtG.QPalette(QtG.QColor("green")))
+ elif state == "stopping":
+ self.training = True
+ self.ui.startBtn.setEnabled(False)
+ self.ui.startBtn.setText(QCA.translate("main_window", "Stopping..."))
+ self.ui.startBtn.setPalette(QtG.QPalette(QtG.QColor("red")))
+ elif state == "cancelled": # Interrupted or errored termination. Do not update progress bars, as we might be interested in knowing in which epoch/step the error occurred.
+ self.training = False
+ self.ui.startBtn.setText(QCA.translate("main_window", "Start Training"))
+ self.ui.startBtn.setPalette(QtG.QPalette(QtG.QColor("darkred")))
+ self.ui.startBtn.setEnabled(True)
+ self.ui.etaLbl.setText("")
+ return f
+
+ def __updateStatus(self):
+ @Slot(dict)
+ def f(data):
+ if "status" in data:
+ self.ui.statusLbl.setText(data["status"])
+
+ if "eta" in data:
+ self.ui.etaLbl.setText(f"ETA: {data['eta']}")
+
+ if "step" in data and "max_steps" in data:
+ self.ui.stepPrg.setMaximum(data["max_steps"])
+ self.ui.stepPrg.setValue(data["step"])
+ if "epoch" in data and "max_epochs" in data:
+ self.ui.epochPrg.setMaximum(data["max_epochs"])
+ self.ui.epochPrg.setValue(data["epoch"])
+
+ if "event" in data:
+ self.__enableControls(data["event"])()
+
+ return f
+
+ def __changeModel(self):
+ @Slot()
+ def f():
+ model_type = self.ui.modelTypeCmb.currentData()
+ training_type = self.ui.trainingTypeCmb.currentData()
+ self.ui.tabWidget.setTabVisible(self.children["lora"]["index"], training_type == TrainingMethod.LORA)
+ self.ui.tabWidget.setTabVisible(self.children["embedding"]["index"], training_type == TrainingMethod.EMBEDDING)
+
+ QtW.QApplication.instance().modelChanged.emit(model_type, training_type)
+ return f
+
+ def __onQuit(self):
+ @Slot()
+ def f():
+ StateModel.instance().save_default()
+ StateModel.instance().stop_tensorboard()
+ CaptionModel.instance().release_model()
+ MaskModel.instance().release_model()
+ BulkImageModel.instance().terminate_pool()
+ BulkCaptionModel.instance().terminate_pool()
+ return f
+
+ def __toggleTrain(self):
+ @Slot()
+ def f():
+ if self.training:
+ self.__stopTrain()
+ else:
+ worker, name = WorkerPool.instance().createNamed(self.__train(), "train", poolless=True, daemon=True, inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls("running"), result_fn=None,
+ finished_fn=self.__enableControls("enabled"),
+ errored_fn=self.__enableControls("cancelled"), aborted_fn=self.__enableControls("cancelled"),
+ progress_fn=self.__updateStatus())
+ WorkerPool.instance().start(name)
+
+ return f
+
+ ###Utils###
+
+ def __generate_debug_package(self, zip_path):
+ def f(progress_fn=None):
+ StateModel.instance().generate_debug_package(zip_path, progress_fn=progress_fn)
+
+ return f
+
+ def __createTabs(self):
+ for name, controller in [
+ ("general", GeneralController),
+ ("model", ModelController),
+ ("data", DataController),
+ ("concepts", ConceptsController),
+ ("training", TrainingController),
+ ("sampling", SamplingController),
+ ("backup", BackupController),
+ ("tools", ToolsController),
+ ("additional_embeddings", AdditionalEmbeddingsController),
+ ("cloud", CloudController),
+ ("lora", LoraController),
+ ("embedding", EmbeddingsController)
+ ]:
+ c = controller(self.loader, parent=self)
+ self.children[name] = {"controller": c, "index": len(self.children)}
+ self.ui.tabWidget.addTab(c.ui, c.name)
+
+ self.ui.tabWidget.setTabVisible(self.children["lora"]["index"], False)
+ self.ui.tabWidget.setTabVisible(self.children["embedding"]["index"], False)
+
+ def __loadConfig(self, config, idx=None):
+ StateModel.instance().load_config(config)
+ QtW.QApplication.instance().stateChanged.emit()
+ QtW.QApplication.instance().embeddingsChanged.emit()
+ if idx is not None:
+ self.ui.configCmb.setCurrentIndex(idx)
+
+
+ def __exportConfig(self):
+ diag = QtW.QFileDialog()
+ txt, flt = diag.getSaveFileName(parent=None, caption=QCA.translate("dialog_window", "Save Config"), dir=None,
+ filter=QCA.translate("filetype_filters", "JSON (*.json)"))
+ if txt != "":
+ filename = self._appendExtension(txt, flt)
+ StateModel.instance().save_config(filename)
+
+
+ def __train(self):
+ def f(progress_fn=None):
+ TrainingModel.instance().train(progress_fn=progress_fn)
+ return f
+
+ def __stopTrain(self):
+ TrainingModel.instance().stop_training()
diff --git a/modules/ui/controllers/tabs/AdditionalEmbeddingsController.py b/modules/ui/controllers/tabs/AdditionalEmbeddingsController.py
new file mode 100644
index 000000000..c9811e0b8
--- /dev/null
+++ b/modules/ui/controllers/tabs/AdditionalEmbeddingsController.py
@@ -0,0 +1,69 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.widgets.EmbeddingController import EmbeddingController
+from modules.ui.models.StateModel import StateModel
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class AdditionalEmbeddingsController(BaseController):
+ children = []  # Reassigned per instance in __updateEmbeddings(); the class attribute only provides a safe default.
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/additional_embeddings.ui", name=QCA.translate("main_window_tabs", "Additional Embeddings"), parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.addEmbeddingBtn.clicked, self.__appendEmbedding())
+ self._connect(self.ui.enableBtn.clicked, self.__enableEmbeddings())
+
+ self._connect([QtW.QApplication.instance().embeddingsChanged, QtW.QApplication.instance().stateChanged],
+ self.__updateEmbeddings(), update_after_connect=True)
+
+
+ ###Reactions###
+
+ def __enableEmbeddings(self):
+ @Slot()
+ def f():
+ StateModel.instance().enable_embeddings()
+ QtW.QApplication.instance().embeddingsChanged.emit()
+ return f
+
+ def __updateEmbeddings(self):
+ @Slot()
+ def f():
+ for c in self.children:
+ c._disconnectAll()
+
+ self.ui.listWidget.clear()
+ self.children = []
+
+ for idx, _ in enumerate(StateModel.instance().get_state("additional_embeddings")):
+ wdg = EmbeddingController(self.loader, idx, parent=self)
+ self.children.append(wdg)
+ self._appendWidget(self.ui.listWidget, wdg, self_delete_fn=self.__deleteEmbedding(idx), self_clone_fn=self.__cloneEmbedding(idx))
+
+ return f
+
+ def __cloneEmbedding(self, idx):
+ @Slot()
+ def f():
+ StateModel.instance().clone_embedding(idx)
+ QtW.QApplication.instance().embeddingsChanged.emit()
+ return f
+
+ def __deleteEmbedding(self, idx):
+ @Slot()
+ def f():
+ StateModel.instance().delete_embedding(idx)
+ QtW.QApplication.instance().embeddingsChanged.emit()
+ return f
+
+ def __appendEmbedding(self):
+ @Slot()
+ def f():
+ StateModel.instance().create_new_embedding()
+ QtW.QApplication.instance().embeddingsChanged.emit()
+ return f
diff --git a/modules/ui/controllers/tabs/BackupController.py b/modules/ui/controllers/tabs/BackupController.py
new file mode 100644
index 000000000..4af12596e
--- /dev/null
+++ b/modules/ui/controllers/tabs/BackupController.py
@@ -0,0 +1,126 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.TrainingModel import TrainingModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.TimeUnit import TimeUnit
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class BackupController(BaseController):
+ state_ui_connections = {
+ "backup_after": "backupSbx",
+ "backup_after_unit": "backupCmb",
+ "rolling_backup": "rollingBackupCbx",
+ "backup_before_save": "backupBeforeSaveCbx",
+ "rolling_backup_count": "rollingCountSbx",
+ "save_every": "saveSbx",
+ "save_every_unit": "saveCmb",
+ "save_skip_first": "skipSbx",
+ "save_filename_prefix": "savePrefixLed"
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/backup.ui", name=QCA.translate("main_window_tabs", "Backup"), parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.backupBtn.clicked, self.__startBackup())
+ self._connect(self.ui.saveBtn.clicked, self.__startSave())
+
+ self._connect([QtW.QApplication.instance().stateChanged, self.ui.backupCmb.activated],
+ self.__updateBackup(), update_after_connect=True)
+
+ self._connect([QtW.QApplication.instance().stateChanged, self.ui.saveCmb.activated],
+ self.__updateSave(), update_after_connect=True)
+
+ self._connect([QtW.QApplication.instance().stateChanged, self.ui.rollingBackupCbx.toggled],
+ self.__updateRollingBackup(), update_after_connect=True)
+
+ def _loadPresets(self):
+ for e in TimeUnit.enabled_values():
+ self.ui.backupCmb.addItem(e.pretty_print(), userData=e)
+ for e in TimeUnit.enabled_values():
+ self.ui.saveCmb.addItem(e.pretty_print(), userData=e)
+
+
+ ###Reactions###
+
+ def __startBackup(self):
+ @Slot()
+ def f():
+ worker, name = WorkerPool.instance().createNamed(self.__backupNow(), "backup_operations", poolless=True, daemon=True,
+ inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True))
+ WorkerPool.instance().start(name)
+ return f
+
+ def __startSave(self):
+ @Slot()
+ def f():
+ worker, name = WorkerPool.instance().createNamed(self.__saveNow(), "backup_operations", poolless=True, daemon=True,
+ inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True))
+ WorkerPool.instance().start(name)
+
+ return f
+
+ def __enableControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.backupBtn.setEnabled(enabled)
+ self.ui.saveBtn.setEnabled(enabled)
+ return f
+
+ def __updateBackup(self):
+ @Slot()
+ def f():
+ enabled = self.ui.backupCmb.currentData() != TimeUnit.NEVER
+
+ self.ui.backupSbx.setEnabled(enabled)
+ self.ui.rollingBackupCbx.setEnabled(enabled)
+ self.ui.rollingCountSbx.setEnabled(enabled and self.ui.rollingBackupCbx.isChecked())
+ self.ui.rollingCountLbl.setEnabled(enabled and self.ui.rollingBackupCbx.isChecked())
+
+ return f
+
+ def __updateRollingBackup(self):
+ @Slot()
+ def f():
+ enabled = self.ui.rollingBackupCbx.isChecked()
+ self.ui.rollingCountSbx.setEnabled(enabled and self.ui.backupCmb.currentData() != TimeUnit.NEVER)
+ self.ui.rollingCountLbl.setEnabled(enabled and self.ui.backupCmb.currentData() != TimeUnit.NEVER)
+
+ return f
+
+ def __updateSave(self):
+ @Slot()
+ def f():
+ enabled = self.ui.saveCmb.currentData() != TimeUnit.NEVER
+
+ self.ui.saveSbx.setEnabled(enabled)
+ self.ui.skipSbx.setEnabled(enabled)
+ self.ui.rollingCountSbx.setEnabled(enabled)
+ self.ui.savePrefixLed.setEnabled(enabled)
+ self.ui.skipLbl.setEnabled(enabled)
+ self.ui.savePrefixLbl.setEnabled(enabled)
+
+ return f
+
+ ###Utils###
+
+ def __backupNow(self):
+ def f(progress_fn=None):
+ TrainingModel.instance().backup_now()
+ return f
+
+ def __saveNow(self):
+ def f(progress_fn=None):
+ TrainingModel.instance().save_now()
+ return f
diff --git a/modules/ui/controllers/tabs/CloudController.py b/modules/ui/controllers/tabs/CloudController.py
new file mode 100644
index 000000000..33b34097e
--- /dev/null
+++ b/modules/ui/controllers/tabs/CloudController.py
@@ -0,0 +1,131 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.StateModel import StateModel
+from modules.ui.models.TrainingModel import TrainingModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.CloudAction import CloudAction
+from modules.util.enum.CloudFileSync import CloudFileSync
+from modules.util.enum.CloudSubtype import CloudSubtype
+from modules.util.enum.CloudType import CloudType
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class CloudController(BaseController):
+ state_ui_connections = {
+ "cloud.enabled": "enabledCbx",
+ "cloud.type": "cloudTypeCmb",
+ "cloud.file_sync": "fileSyncMethodCmb",
+ "secrets.cloud.api_key": "apiKeyLed",
+ "secrets.cloud.host": "hostnameLed",
+ "secrets.cloud.port": "portSbx",
+ "secrets.cloud.user": "userLed",
+ "secrets.cloud.id": "cloudIdLed",
+ "cloud.tensorboard_tunnel": "tensorboardTcpTunnelCbx",
+ "cloud.detach_trainer": "detachRemoteTrainerCbx",
+ "cloud.run_id": "reattachIdLed",
+ "cloud.download_samples": "downloadSamplesCbx",
+ "cloud.download_output_model": "downloadOutputModelCbx",
+ "cloud.download_saves": "downloadSavedCheckpointsCbx",
+ "cloud.download_backups": "downloadBackupsCbx",
+ "cloud.download_tensorboard": "downloadTensorboardLogCbx",
+ "cloud.delete_workspace": "deleteRemoteWorkspaceCbx",
+ "cloud.remote_dir": "remoteDirectoryLed",
+ "cloud.onetrainer_dir": "onetrainerDirectoryLed",
+ "cloud.huggingface_cache_dir": "huggingfaceCacheLed",
+ "cloud.install_onetrainer": "installOnetrainerCbx",
+ "cloud.install_cmd": "installCommandLed",
+ "cloud.update_onetrainer": "updateOnetrainerCbx",
+ "cloud.create": "createCloudCbx",
+ "cloud.name": "cloudNameLed",
+ "cloud.sub_type": "subTypeCmb",
+ "cloud.gpu_type": "gpuCmb",
+ "cloud.volume_size": "volumeSizeSbx",
+ "cloud.min_download": "minDownloadSbx",
+ "cloud.on_finish": "onFinishCmb",
+ "cloud.on_error": "onErrorCmb",
+ "cloud.on_detached_finish": "onDetachedCmb",
+ "cloud.on_detached_error": "onDetachedErrorCmb",
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/cloud.ui", name=QCA.translate("main_window_tabs", "Cloud"), parent=parent)
+
+
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.createCloudBtn.clicked, self.__createCloud())
+ self._connect(self.ui.gpuBtn.clicked, self.__getGPUTypes())
+ self._connect(self.ui.reattachBtn.clicked, self.__reattach())
+
+ self._connect([self.ui.enabledCbx.toggled, QtW.QApplication.instance().stateChanged],
+ self.__enableCloud(), update_after_connect=True)
+
+
+
+ def _loadPresets(self):
+ for ctl in [self.ui.onFinishCmb, self.ui.onErrorCmb, self.ui.onDetachedCmb, self.ui.onDetachedErrorCmb]:
+ for e in CloudAction.enabled_values():
+ ctl.addItem(e.pretty_print(), userData=e)
+
+ for e in CloudType.enabled_values():
+ self.ui.cloudTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in CloudFileSync.enabled_values():
+ self.ui.fileSyncMethodCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in CloudSubtype.enabled_values():
+ self.ui.subTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __reattach(self):
+ @Slot()
+ def f():
+ worker, name = WorkerPool.instance().createNamed(self.__train(), "train", poolless=True, daemon=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableReattach(False), result_fn=None,
+ finished_fn=self.__enableReattach(True),
+ errored_fn=self.__enableReattach(True), aborted_fn=self.__enableReattach(True))
+ WorkerPool.instance().start(name)
+
+ return f
+
+ def __enableReattach(self, enabled):
+ @Slot()
+ def f():
+ self.ui.reattachBtn.setEnabled(enabled)
+ return f
+
+ def __enableCloud(self):
+ @Slot()
+ def f():
+ self.ui.frame.setEnabled(self.ui.enabledCbx.isChecked())
+ return f
+
+
+ def __getGPUTypes(self):
+ @Slot()
+ def f():
+ self.ui.gpuCmb.clear()
+ for gpu in StateModel.instance().get_gpus():
+ self.ui.gpuCmb.addItem(gpu.name, userData=gpu)
+
+ return f
+
+ def __createCloud(self):
+ @Slot()
+ def f():
+ if StateModel.instance().get_state("cloud.type") == CloudType.RUNPOD:
+ self._openUrl("https://www.runpod.io/console/deploy?template=1a33vbssq9&type=gpu")
+ return f
+
+ ###Utils###
+
+ def __train(self):
+ def f(progress_fn=None):
+ TrainingModel.instance().train(reattach=True, progress_fn=progress_fn)
+ return f
diff --git a/modules/ui/controllers/tabs/ConceptsController.py b/modules/ui/controllers/tabs/ConceptsController.py
new file mode 100644
index 000000000..3928d9430
--- /dev/null
+++ b/modules/ui/controllers/tabs/ConceptsController.py
@@ -0,0 +1,146 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.widgets.ConceptController import ConceptController as WidgetConceptController
+from modules.ui.controllers.windows.ConceptController import ConceptController as WinConceptController
+from modules.ui.models.ConceptModel import ConceptModel
+from modules.ui.models.StateModel import StateModel
+from modules.util.enum.ConceptType import ConceptType
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class ConceptsController(BaseController):
+ children = []  # Reassigned per instance in __updateConcepts(); the class attribute only provides a safe default.
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/concepts.ui", name=QCA.translate("main_window_tabs", "Concepts"), parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ self.concept_window = WinConceptController(self.loader, parent=self)
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.addConceptBtn.clicked, self.__appendConcept())
+ self._connect(self.ui.toggleBtn.clicked, self.__toggleConcepts())
+ self._connect(self.ui.clearBtn.clicked, self.__clearFilters())
+
+ self._connect(QtW.QApplication.instance().stateChanged, self.__updateConfigs(), update_after_connect=True)
+
+ self._connect([QtW.QApplication.instance().aboutToQuit, QtW.QApplication.instance().conceptsChanged],
+ self.__saveConfig())
+
+ self._connect([self.ui.searchLed.textChanged, self.ui.typeCmb.activated, self.ui.showDisabledCbx.toggled, QtW.QApplication.instance().stateChanged],
+ lambda: QtW.QApplication.instance().conceptsChanged.emit(False))
+
+ self._connect([QtW.QApplication.instance().conceptsChanged, QtW.QApplication.instance().stateChanged],
+ self.__updateConcepts(), update_after_connect=True)
+
+ self._connect([self.ui.presetCmb.textActivated, QtW.QApplication.instance().stateChanged], self.__loadConfig(), update_after_connect=True)
+
+ def _loadPresets(self):
+ for e in ConceptType.enabled_values(context="all"):
+ self.ui.typeCmb.addItem(e.pretty_print(), userData=e)
+
+ def _connectInputValidation(self):
+ self.ui.presetCmb.setValidator(QtGui.QRegularExpressionValidator(r"[a-zA-Z0-9_\-.][a-zA-Z0-9_\-. ]*", self.ui))
+
+ ###Reactions###
+
+ def __updateConfigs(self):
+ @Slot()
+ def f():
+ configs = ConceptModel.instance().load_available_config_names("training_concepts", include_default=False)
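+ # Fall back to the default concepts file when no saved configs exist.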
+ if len(configs) == 0:
+ configs.append(("concepts", "training_concepts/concepts.json"))
+
+ for c in self.children:
+ c._disconnectAll()
+
+ self.ui.presetCmb.clear()
+ for k, v in configs:
+ self.ui.presetCmb.addItem(k, userData=v)
+
+ self.ui.presetCmb.setCurrentIndex(self.ui.presetCmb.findData(StateModel.instance().get_state("concept_file_name")))
+ return f
+
+ def __loadConfig(self):
+ def f(filename=None):
+ if filename is None:
+ filename = self.ui.presetCmb.currentText()
+ ConceptModel.instance().load_config(filename)
+ QtW.QApplication.instance().conceptsChanged.emit(False)
+ return f
+
+ def __saveConfig(self):
+ @Slot(bool)
+ def f(save=True):
+ if save:
+ ConceptModel.instance().save_config()
+ return f
+
+ def __appendConcept(self):
+ @Slot()
+ def f():
+ ConceptModel.instance().create_new_concept()
+ QtW.QApplication.instance().conceptsChanged.emit(True)
+ return f
+
+ def __toggleConcepts(self):
+ @Slot()
+ def f():
+ ConceptModel.instance().toggle_concepts()
+ QtW.QApplication.instance().conceptsChanged.emit(True)
+ return f
+
+ def __clearFilters(self):
+ def f():
+ self.ui.searchLed.setText("")
+ self.ui.typeCmb.setCurrentIndex(self.ui.typeCmb.findData(ConceptType.ALL))
+ self.ui.showDisabledCbx.setChecked(True)
+
+ QtW.QApplication.instance().conceptsChanged.emit(False)
+
+ return f
+
+ def __updateConcepts(self):
+ @Slot(bool)
+ def f(save=False):
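+ # Rebuild the list from scratch: disconnect the old child widgets, then recreate one widget per filtered concept.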
+ for c in self.children:
+ c._disconnectAll()
+
+ self.ui.listWidget.clear()
+ self.children = []
+
+ for idx, _ in ConceptModel.instance().get_filtered_concepts(self.ui.searchLed.text(), self.ui.typeCmb.currentData(), self.ui.showDisabledCbx.isChecked()):
+ wdg = WidgetConceptController(self.loader, self.concept_window, idx, parent=self)
+ self.children.append(wdg)
+ self._appendWidget(self.ui.listWidget, wdg, self_delete_fn=self.__deleteConcept(idx), self_clone_fn=self.__cloneConcept(idx))
+
+ if ConceptModel.instance().some_concepts_enabled():
+ self.ui.toggleBtn.setText(QCA.translate("main_window_tabs", "Disable All"))
+ else:
+ self.ui.toggleBtn.setText(QCA.translate("main_window_tabs", "Enable All"))
+
+ return f
+
+ def __cloneConcept(self, idx):
+ @Slot()
+ def f():
+ ConceptModel.instance().clone_concept(idx)
+ QtW.QApplication.instance().conceptsChanged.emit(True)
+
+ return f
+
+ def __deleteConcept(self, idx):
+ @Slot()
+ def f():
+ ConceptModel.instance().delete_concept(idx)
+ QtW.QApplication.instance().conceptsChanged.emit(True)
+
+ return f
diff --git a/modules/ui/controllers/tabs/DataController.py b/modules/ui/controllers/tabs/DataController.py
new file mode 100644
index 000000000..0d3e28ad5
--- /dev/null
+++ b/modules/ui/controllers/tabs/DataController.py
@@ -0,0 +1,31 @@
+from modules.ui.controllers.BaseController import BaseController
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class DataController(BaseController):
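+ # Maps config state keys to widget object names in data.ui; the binding itself is handled by BaseController.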
+ state_ui_connections = {
+ "aspect_ratio_bucketing": "aspectBucketingCbx",
+ "latent_caching": "latentCachingCbx",
+ "clear_cache_before_training": "clearCacheCbx"
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/data.ui", name=QCA.translate("main_window_tabs", "Data"), parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect([self.ui.latentCachingCbx.toggled, QtW.QApplication.instance().stateChanged],
+ self.__updateCaching(), update_after_connect=True)
+
+ ###Reactions###
+
+ def __updateCaching(self):
+ @Slot()
+ def f():
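+ # Clearing the cache is only meaningful while latent caching is enabled.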
+ enabled = self.ui.latentCachingCbx.isChecked()
+ self.ui.clearCacheCbx.setEnabled(enabled)
+ return f
diff --git a/modules/ui/controllers/tabs/EmbeddingsController.py b/modules/ui/controllers/tabs/EmbeddingsController.py
new file mode 100644
index 000000000..a59c749ab
--- /dev/null
+++ b/modules/ui/controllers/tabs/EmbeddingsController.py
@@ -0,0 +1,29 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.util.enum.DataType import DataType
+
+from PySide6.QtCore import QCoreApplication as QCA
+
+
+class EmbeddingsController(BaseController):
+ state_ui_connections = {
+ "embedding.model_name": "baseEmbeddingLed",
+ "embedding.token_count": "tokenSbx",
+ "embedding.initial_embedding_text": "initialEmbeddingLed",
+ "embedding_weight_dtype": "embeddingDTypeCmb",
+ "embedding.placeholder": "placeholderLed",
+ "embedding.is_output_embedding": "outputCbx",
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/embeddings.ui", name=QCA.translate("main_window_tabs", "Embeddings"), parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.baseEmbeddingBtn, self.ui.baseEmbeddingLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open base embeddings"),
+ filters=QCA.translate("filetype_filters", "Safetensors (*.safetensors);;Diffusers (model_index.json);;Checkpoints (*.ckpt *.pt *.bin);;All Files (*.*)"))
+
+ def _loadPresets(self):
+ for e in DataType.enabled_values(context="embeddings"):
+ self.ui.embeddingDTypeCmb.addItem(e.pretty_print(), userData=e)
diff --git a/modules/ui/controllers/tabs/GeneralController.py b/modules/ui/controllers/tabs/GeneralController.py
new file mode 100644
index 000000000..f9959cce1
--- /dev/null
+++ b/modules/ui/controllers/tabs/GeneralController.py
@@ -0,0 +1,110 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.StateModel import StateModel
+from modules.util.enum.GradientReducePrecision import GradientReducePrecision
+from modules.util.enum.TimeUnit import TimeUnit
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class GeneralController(BaseController):
+ state_ui_connections = {
+ "workspace_dir": "workspaceLed",
+ "continue_last_backup": "continueCbx",
+ "debug_mode": "debugCbx",
+ "tensorboard": "tensorboardCbx",
+ "tensorboard_expose": "exposeTensorboardCbx",
+ "validation": "validateCbx",
+ "dataloader_threads": "dataloaderSbx",
+ "train_device": "trainDeviceLed",
+ "multi_gpu": "multiGpuCbx",
+ "gradient_reduce_precision": "gradientReduceCmb",
+ "async_gradient_reduce": "asyncGradientCbx",
+ "temp_device": "tempDeviceLed",
+ "cache_dir": "cacheLed",
+ "only_cache": "onlyCacheCbx",
+ "debug_dir": "debugLed",
+ "tensorboard_always_on": "alwaysOnTensorboardCbx",
+ "tensorboard_port": "tensorboardSbx",
+ "validate_after": "validateSbx",
+ "validate_after_unit": "validateCmb",
+ "device_indexes": "deviceIndexesLed",
+ "fused_gradient_reduce": "fusedGradientCbx",
+ "async_gradient_reduce_buffer": "bufferSbx",
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/general.ui", name=QCA.translate("main_window_tabs", "General"), parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.workspaceBtn, self.ui.workspaceLed, is_dir=True, save=False,
+ title=QCA.translate("dialog_window", "Open Workspace directory"))
+ self._connectFileDialog(self.ui.cacheBtn, self.ui.cacheLed, is_dir=True, save=False,
+ title=QCA.translate("dialog_window", "Open Cache directory"))
+ self._connectFileDialog(self.ui.debugBtn, self.ui.debugLed, is_dir=True, save=False,
+ title=QCA.translate("dialog_window", "Open Debug directory"))
+
+ self._connect([self.ui.alwaysOnTensorboardCbx.toggled, self.ui.workspaceLed.editingFinished, QtW.QApplication.instance().stateChanged],
+ self.__toggleTensorboard(), update_after_connect=True)
+
+ self._connect([self.ui.validateCbx.toggled, QtW.QApplication.instance().stateChanged],
+ self.__updateValidate(), update_after_connect=True)
+
+ self._connect([self.ui.tensorboardCbx.toggled, QtW.QApplication.instance().stateChanged],
+ self.__updateTensorboard(), update_after_connect=True)
+
+ self._connect([self.ui.debugCbx.toggled, QtW.QApplication.instance().stateChanged],
+ self.__updateDebug(), update_after_connect=True)
+
+ def _loadPresets(self):
+ for e in GradientReducePrecision.enabled_values():
+ self.ui.gradientReduceCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in TimeUnit.enabled_values():
+ self.ui.validateCmb.addItem(e.pretty_print(), userData=e)
+
+ def _connectInputValidation(self):
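+ # Device indexes are a comma-separated list of non-negative integers, e.g. "0,1,2".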
+ self.ui.deviceIndexesLed.setValidator(QtGui.QRegularExpressionValidator(r"(\d+(,\d+)*)?", self.ui))
+
+ ###Reactions###
+
+ def __updateValidate(self):
+ @Slot()
+ def f():
+ enabled = self.ui.validateCbx.isChecked()
+ self.ui.validateLbl.setEnabled(enabled)
+ self.ui.validateSbx.setEnabled(enabled)
+ self.ui.validateCmb.setEnabled(enabled)
+ return f
+
+ def __updateTensorboard(self):
+ @Slot()
+ def f():
+ enabled = self.ui.tensorboardCbx.isChecked()
+ self.ui.alwaysOnTensorboardCbx.setEnabled(enabled)
+ self.ui.exposeTensorboardCbx.setEnabled(enabled)
+ self.ui.tensorboardLbl.setEnabled(enabled)
+ self.ui.tensorboardSbx.setEnabled(enabled)
+ return f
+
+ def __updateDebug(self):
+ @Slot()
+ def f():
+ enabled = self.ui.debugCbx.isChecked()
+ self.ui.debugLbl.setEnabled(enabled)
+ self.ui.debugLed.setEnabled(enabled)
+ self.ui.debugBtn.setEnabled(enabled)
+ return f
+
+ def __toggleTensorboard(self):
+ @Slot()
+ def f():
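+ # Start or stop the tensorboard process as soon as "always on" is toggled or the workspace changes.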
+ if self.ui.alwaysOnTensorboardCbx.isChecked():
+ StateModel.instance().start_tensorboard()
+ else:
+ StateModel.instance().stop_tensorboard()
+ return f
diff --git a/modules/ui/controllers/tabs/LoraController.py b/modules/ui/controllers/tabs/LoraController.py
new file mode 100644
index 000000000..45546d446
--- /dev/null
+++ b/modules/ui/controllers/tabs/LoraController.py
@@ -0,0 +1,81 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.util.enum.DataType import DataType
+from modules.util.enum.ModelType import PeftType
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class LoraController(BaseController):
+ state_ui_connections = {
+ "peft_type": "typeCmb",
+ "lora_model_name": "baseModelLed",
+ "lora_rank": "rankSbx",
+ "lora_alpha": "alphaSbx",
+ "lora_decompose": "decomposeCbx",
+ "lora_decompose_norm_epsilon": "normCbx",
+ "lora_decompose_output_axis": "outputAxisCbx",
+ "lora_weight_dtype": "weightDTypeCmb",
+ "bundle_additional_embeddings": "bundleCbx",
+ "oft_block_size": "oftBlockSizeSbx",
+ "oft_coft": "coftCbx",
+ "coft_eps": "coftLed",
+ "oft_block_share": "blockShareCbx",
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/lora.ui", name=QCA.translate("main_window_tabs", "LoRA"), parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.baseModelBtn, self.ui.baseModelLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open LoRA/LoHA/OFT 2 base model"),
+ filters=QCA.translate("filetype_filters", "Safetensors (*.safetensors);;Diffusers (model_index.json);;Checkpoints (*.ckpt *.pt *.bin);;All Files (*.*)"))
+
+ self._connect([QtW.QApplication.instance().stateChanged, self.ui.typeCmb.activated],
+ self.__updateType(), update_after_connect=True)
+
+ self._connect([QtW.QApplication.instance().stateChanged, self.ui.decomposeCbx.toggled],
+ self.__updateDora(), update_after_connect=True)
+
+ def _connectInputValidation(self):
+ # Alpha cannot be higher than rank.
+ self._connect(self.ui.rankSbx.valueChanged, lambda x: self.ui.alphaSbx.setMaximum(x))
+
+ self._connectScientificNotation(self.ui.coftLed, min=0.0)
+
+ def _loadPresets(self):
+ for e in PeftType.enabled_values():
+ self.ui.typeCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in DataType.enabled_values(context="lora"):
+ self.ui.weightDTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __updateDora(self):
+ @Slot()
+ def f():
+ enabled = self.ui.decomposeCbx.isChecked()
+ self.ui.normCbx.setEnabled(enabled)
+ self.ui.outputAxisCbx.setEnabled(enabled)
+ return f
+
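+ # DoRA options only apply to plain LoRA; OFT 2 uses a block size instead of rank/alpha.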
+ def __updateType(self):
+ @Slot()
+ def f():
+ peft_type = self.ui.typeCmb.currentData()
+ is_oft = peft_type == PeftType.OFT_2
+
+ self.ui.doraFrm.setVisible(peft_type == PeftType.LORA)
+ self.ui.oftFrm.setVisible(is_oft)
+
+ self.ui.oftBlockSizeLbl.setVisible(is_oft)
+ self.ui.oftBlockSizeSbx.setVisible(is_oft)
+
+ self.ui.rankLbl.setVisible(not is_oft)
+ self.ui.rankSbx.setVisible(not is_oft)
+
+ self.ui.alphaLbl.setVisible(not is_oft)
+ self.ui.alphaSbx.setVisible(not is_oft)
+ return f
diff --git a/modules/ui/controllers/tabs/ModelController.py b/modules/ui/controllers/tabs/ModelController.py
new file mode 100644
index 000000000..4e32eff02
--- /dev/null
+++ b/modules/ui/controllers/tabs/ModelController.py
@@ -0,0 +1,193 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.StateModel import StateModel
+from modules.util.enum.ConfigPart import ConfigPart
+from modules.util.enum.DataType import DataType
+from modules.util.enum.ModelFlags import ModelFlags
+from modules.util.enum.ModelFormat import ModelFormat
+from modules.util.enum.ModelType import ModelType
+from modules.util.enum.TrainingMethod import TrainingMethod
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class ModelController(BaseController):
+ state_ui_connections = {
+ "secrets.huggingface_token": "huggingfaceLed",
+ "base_model_name": "baseModelLed",
+ "weight_dtype": "weightDTypeCmb",
+ "output_model_destination": "modelOutputLed",
+ "output_dtype": "outputDTypeCmb",
+ "output_model_format": "outputFormatCmb",
+ "include_train_config": "configCmb",
+ "unet.weight_dtype": "unetDTypeCmb",
+ "prior.weight_dtype": "priorDTypeCmb",
+ "text_encoder.weight_dtype": "te1DTypeCmb",
+ "text_encoder_2.weight_dtype": "te2DTypeCmb",
+ "text_encoder_3.weight_dtype": "te3DTypeCmb",
+ "text_encoder_4.weight_dtype": "te4DTypeCmb",
+ "vae.weight_dtype": "vaeDTypeCmb",
+ "effnet_encoder.weight_dtype": "effnetDTypeCmb",
+ "decoder.weight_dtype": "decDTypeCmb",
+ "decoder_text_encoder.weight_dtype": "decTeDTypeCmb",
+ "decoder_vqgan.weight_dtype": "vqganDTypeCmb",
+ "prior.model_name": "priorLed",
+ "text_encoder_4.model_name": "te4Led",
+ "vae.model_name": "vaeLed",
+ "effnet_encoder.model_name": "effnetLed",
+ "decoder.model_name": "decLed",
+ "compile": "compileTransformerCbx",
+ "transformer.model_name": "transformerLed",
+ "transformer.weight_dtype": "transformerDTypeCmb",
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/model.ui", name=QCA.translate("main_window_tabs", "Model"), parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ for ui_name in ["baseModel", "prior", "te4", "vae", "effnet", "dec"]:
+ btn = self.ui.findChild(QtW.QToolButton, f"{ui_name}Btn")
+ led = self.ui.findChild(QtW.QLineEdit, f"{ui_name}Led")
+ self._connectFileDialog(btn, led, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open model"),
+ filters=QCA.translate("filetype_filters",
+ "Safetensors (*.safetensors);;Diffusers (model_index.json);;Checkpoints (*.ckpt *.pt *.bin);;All Files (*.*)"))
+
+ self._connectFileDialog(self.ui.transformerBtn, self.ui.transformerLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open model"),
+ filters=QCA.translate("filetype_filters",
+ "Safetensors (*.safetensors);;Diffusers (model_index.json);;Checkpoints (*.ckpt *.pt *.bin);;GGUF (*.gguf);;All Files (*.*)"))
+
+ self._connectFileDialog(self.ui.modelOutputBtn, self.ui.modelOutputLed, is_dir=False, save=True,
+ title=QCA.translate("dialog_window", "Save output model"),
+ filters=QCA.translate("filetype_filters",
+ "Safetensors (*.safetensors);;Diffusers (model_index.json);;Checkpoints (*.ckpt *.pt *.bin);;All Files (*.*)"))
+
+ self._connect(QtW.QApplication.instance().modelChanged, self.__updateModel(), update_after_connect=True, initial_args=[StateModel.instance().get_state("model_type"), StateModel.instance().get_state("training_method")])
+
+ def _connectInputValidation(self):
+ self._connect(self.ui.transformerLed.editingFinished, self.__forceGGUF(from_line_edit=True))
+ self._connect(self.ui.transformerDTypeCmb.activated, self.__forceGGUF(from_line_edit=False))
+
+ def _loadPresets(self):
+ for e in ConfigPart.enabled_values():
+ self.ui.configCmb.addItem(e.pretty_print(), userData=e)
+
+ for ui_name in ["weightDTypeCmb", "unetDTypeCmb", "priorDTypeCmb", "te1DTypeCmb", "te2DTypeCmb", "te3DTypeCmb", "te4DTypeCmb",
+ "vaeDTypeCmb", "effnetDTypeCmb", "decDTypeCmb", "decTeDTypeCmb", "vqganDTypeCmb"]:
+ ui_elem = self.ui.findChild(QtW.QComboBox, ui_name)
+ for e in self.__createDTypes(include_none=(ui_name=="weightDTypeCmb")):
+ ui_elem.addItem(e.pretty_print(), userData=e)
+
+ for e in DataType.enabled_values(context="output_dtype"):
+ self.ui.outputDTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in DataType.enabled_values(context="transformer_dtype"):
+ self.ui.transformerDTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __forceGGUF(self, from_line_edit=False):
+ @Slot()
+ def f():
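+ # Keep the transformer path and dtype consistent: a .gguf file forces the GGUF dtype, and picking GGUF without a .gguf file clears the path.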
+ if from_line_edit:
+ if self.ui.transformerLed.text().endswith(".gguf"):
+ self.ui.transformerDTypeCmb.setCurrentIndex(self.ui.transformerDTypeCmb.findData(DataType.GGUF))
+ else:
+ self.ui.transformerDTypeCmb.setCurrentIndex(0)
+ else:
+ if self.ui.transformerDTypeCmb.currentData() == DataType.GGUF and not self.ui.transformerLed.text().endswith(".gguf"):
+ self.ui.transformerLed.setText("")
+ return f
+
+ def __updateModel(self):
+ @Slot(ModelType, TrainingMethod)
+ def f(model_type, training_method):
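+ # Show only the output formats and component widgets supported by the selected model type and training method.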
+ flags = ModelFlags.getFlags(model_type, training_method)
+
+ self.ui.outputFormatCmb.clear()
+ if ModelFlags.ALLOW_SAFETENSORS in flags:
+ self.ui.outputFormatCmb.addItem("Safetensors", userData=ModelFormat.SAFETENSORS)
+ if ModelFlags.ALLOW_DIFFUSERS in flags:
+ self.ui.outputFormatCmb.addItem("Diffusers", userData=ModelFormat.DIFFUSERS)
+ #if ModelFlags.ALLOW_LEGACY_SAFETENSORS in flags:
+ # self.ui.outputFormatCmb.addItem("Legacy Safetensors", userData=ModelFormat.LEGACY_SAFETENSORS)
+
+ self.ui.unetDTypeLbl.setVisible(ModelFlags.UNET in flags)
+ self.ui.unetDTypeCmb.setVisible(ModelFlags.UNET in flags)
+
+ self.ui.priorDTypeLbl.setVisible(ModelFlags.PRIOR in flags)
+ self.ui.priorDTypeCmb.setVisible(ModelFlags.PRIOR in flags)
+
+ self.ui.transformerDTypeLbl.setVisible(ModelFlags.TRANSFORMER in flags)
+ self.ui.transformerDTypeCmb.setVisible(ModelFlags.TRANSFORMER in flags)
+
+ self.ui.te1DTypeLbl.setVisible(ModelFlags.TE1 in flags)
+ self.ui.te1DTypeCmb.setVisible(ModelFlags.TE1 in flags)
+
+ self.ui.te2DTypeLbl.setVisible(ModelFlags.TE2 in flags)
+ self.ui.te2DTypeCmb.setVisible(ModelFlags.TE2 in flags)
+
+ self.ui.te3DTypeLbl.setVisible(ModelFlags.TE3 in flags)
+ self.ui.te3DTypeCmb.setVisible(ModelFlags.TE3 in flags)
+
+ self.ui.te4DTypeLbl.setVisible(ModelFlags.TE4 in flags)
+ self.ui.te4DTypeCmb.setVisible(ModelFlags.TE4 in flags)
+
+ self.ui.vaeDTypeLbl.setVisible(ModelFlags.VAE in flags)
+ self.ui.vaeDTypeCmb.setVisible(ModelFlags.VAE in flags)
+
+ self.ui.effnetLbl.setVisible(ModelFlags.EFFNET in flags)
+ self.ui.effnetLed.setVisible(ModelFlags.EFFNET in flags)
+ self.ui.effnetBtn.setVisible(ModelFlags.EFFNET in flags)
+ self.ui.effnetDTypeLbl.setVisible(ModelFlags.EFFNET in flags)
+ self.ui.effnetDTypeCmb.setVisible(ModelFlags.EFFNET in flags)
+
+ self.ui.decLbl.setVisible(ModelFlags.DEC in flags)
+ self.ui.decLed.setVisible(ModelFlags.DEC in flags)
+ self.ui.decBtn.setVisible(ModelFlags.DEC in flags)
+ self.ui.decDTypeLbl.setVisible(ModelFlags.DEC in flags)
+ self.ui.decDTypeCmb.setVisible(ModelFlags.DEC in flags)
+ self.ui.vqganDTypeLbl.setVisible(ModelFlags.DEC in flags)
+ self.ui.vqganDTypeCmb.setVisible(ModelFlags.DEC in flags)
+
+ self.ui.decTeDTypeLbl.setVisible(ModelFlags.DEC_TE in flags)
+ self.ui.decTeDTypeCmb.setVisible(ModelFlags.DEC_TE in flags)
+
+ self.ui.priorLbl.setVisible(ModelFlags.OVERRIDE_PRIOR in flags)
+ self.ui.priorLed.setVisible(ModelFlags.OVERRIDE_PRIOR in flags)
+ self.ui.priorBtn.setVisible(ModelFlags.OVERRIDE_PRIOR in flags)
+
+ self.ui.transformerLbl.setVisible(ModelFlags.OVERRIDE_TRANSFORMER in flags)
+ self.ui.transformerLed.setVisible(ModelFlags.OVERRIDE_TRANSFORMER in flags)
+ self.ui.transformerBtn.setVisible(ModelFlags.OVERRIDE_TRANSFORMER in flags)
+
+ self.ui.te4Lbl.setVisible(ModelFlags.OVERRIDE_TE4 in flags)
+ self.ui.te4Led.setVisible(ModelFlags.OVERRIDE_TE4 in flags)
+ self.ui.te4Btn.setVisible(ModelFlags.OVERRIDE_TE4 in flags)
+
+ if ModelFlags.TE1 in flags and ModelFlags.TE2 not in flags:
+ self.ui.te1DTypeLbl.setText(
+ QCA.translate("model_tab_label", "Override Text Encoder Data Type")
+ )
+ else:
+ self.ui.te1DTypeLbl.setText(
+ QCA.translate("model_tab_label", "Override Text Encoder 1 Data Type")
+ )
+
+ return f
+
+ ###Utils###
+
+ def __createDTypes(self, include_none=True):
+ options = DataType.enabled_values(context="model_dtypes")
+
+ if include_none:
+ options.insert(0, DataType.NONE)
+
+ return options
diff --git a/modules/ui/controllers/tabs/SamplingController.py b/modules/ui/controllers/tabs/SamplingController.py
new file mode 100644
index 000000000..1545ebaa7
--- /dev/null
+++ b/modules/ui/controllers/tabs/SamplingController.py
@@ -0,0 +1,176 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.widgets.SampleController import SampleController
+from modules.ui.controllers.windows.NewSampleController import NewSampleController
+from modules.ui.controllers.windows.SampleController import SampleController as SampleControllerWindow
+from modules.ui.models.SampleModel import SampleModel
+from modules.ui.models.StateModel import StateModel
+from modules.ui.models.TrainingModel import TrainingModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.ImageFormat import ImageFormat
+from modules.util.enum.TimeUnit import TimeUnit
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class SamplingController(BaseController):
+ state_ui_connections = {
+ "sample_after": "sampleAfterSbx",
+ "sample_after_unit": "sampleAfterCmb",
+ "sample_skip_first": "skipSbx",
+ "sample_image_format": "formatCmb",
+ "non_ema_sampling": "nonEmaCbx",
+ "samples_to_tensorboard": "tensorboardCbx",
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/sampling.ui", name=QCA.translate("main_window_tabs", "Sampling"), parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ self.children = []
+ self.sample_params_window = NewSampleController(self.loader, parent=self)
+ self.manual_sample_window = SampleControllerWindow(self.loader, parent=None)
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.addSampleBtn.clicked, self.__appendSample())
+ self._connect(self.ui.toggleBtn.clicked, self.__toggleSamples())
+ self._connect(self.ui.manualSampleBtn.clicked, self.__openSampleWindow())
+ self._connect(self.ui.sampleNowBtn.clicked, self.__startSample())
+
+ self._connect(QtW.QApplication.instance().stateChanged, self.__updateConfigs(), update_after_connect=True)
+
+ self._connect([QtW.QApplication.instance().samplesChanged, QtW.QApplication.instance().stateChanged],
+ self.__updateSamples(), update_after_connect=True)
+
+ self._connect([self.ui.configCmb.textActivated, QtW.QApplication.instance().stateChanged],
+ self.__loadConfig(), update_after_connect=True, initial_args=[self.ui.configCmb.currentText()])
+
+ self._connect([QtW.QApplication.instance().aboutToQuit, QtW.QApplication.instance().samplesChanged],
+ self.__saveConfig())
+
+ def _connectInputValidation(self):
+ self.ui.configCmb.setValidator(QtGui.QRegularExpressionValidator(r"[a-zA-Z0-9_\-.][a-zA-Z0-9_\-. ]*", self.ui))
+
+ def _loadPresets(self):
+ for e in TimeUnit.enabled_values():
+ self.ui.sampleAfterCmb.addItem(e.pretty_print(), userData=e)
+ for e in ImageFormat.enabled_values():
+ self.ui.formatCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __openSampleWindow(self):
+ @Slot()
+ def f():
+ self._openWindow(self.manual_sample_window, fixed_size=True)
+ return f
+
+ def __loadConfig(self):
+ def f(filename=None):
+ if filename is None:
+ filename = self.ui.configCmb.currentText()
+ SampleModel.instance().load_config(filename)
+ QtW.QApplication.instance().samplesChanged.emit()
+ return f
+
+ def __saveConfig(self):
+ @Slot()
+ def f():
+ SampleModel.instance().save_config()
+ return f
+
+ def __updateConfigs(self):
+ @Slot()
+ def f():
+ configs = SampleModel.instance().load_available_config_names("training_samples", include_default=False)
+ if len(configs) == 0:
+ configs.append(("samples", "training_samples/samples.json"))
+
+ self.ui.configCmb.clear()
+ for k, v in configs:
+ self.ui.configCmb.addItem(k, userData=v)
+
+ self.ui.configCmb.setCurrentIndex(self.ui.configCmb.findData(StateModel.instance().get_state("sample_definition_file_name")))
+ return f
+
+ def __updateSamples(self):
+ @Slot()
+ def f():
+ for c in self.children:
+ c._disconnectAll()
+
+ self.ui.listWidget.clear()
+ self.children = []
+
+ for idx, _ in enumerate(SampleModel.instance().get_state("")):
+ wdg = SampleController(self.loader, self.sample_params_window, idx, parent=self)
+ self.children.append(wdg)
+ self._appendWidget(self.ui.listWidget, wdg, self_delete_fn=self.__deleteSample(idx), self_clone_fn=self.__cloneSample(idx))
+
+ if SampleModel.instance().some_samples_enabled():
+ self.ui.toggleBtn.setText(QCA.translate("main_window_tabs", "Disable All"))
+ else:
+ self.ui.toggleBtn.setText(QCA.translate("main_window_tabs", "Enable All"))
+
+ return f
+
+ def __appendSample(self):
+ @Slot()
+ def f():
+ SampleModel.instance().create_new_sample()
+ QtW.QApplication.instance().samplesChanged.emit()
+ return f
+
+ def __toggleSamples(self):
+ @Slot()
+ def f():
+ SampleModel.instance().toggle_samples()
+ QtW.QApplication.instance().samplesChanged.emit()
+ return f
+
+ def __cloneSample(self, idx):
+ @Slot()
+ def f():
+ SampleModel.instance().clone_sample(idx)
+ QtW.QApplication.instance().samplesChanged.emit()
+
+ return f
+
+ def __deleteSample(self, idx):
+ @Slot()
+ def f():
+ SampleModel.instance().delete_sample(idx)
+ QtW.QApplication.instance().samplesChanged.emit()
+
+ return f
+
+ def __startSample(self):
+ @Slot()
+ def f():
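+ # Sample on a detached daemon worker so the UI stays responsive; the button is disabled until the worker finishes.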
+ worker, name = WorkerPool.instance().createNamed(self.__sampleNow(), "sampling_operations", poolless=True, daemon=True,
+ inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True))
+ WorkerPool.instance().start(name)
+
+ return f
+
+ def __enableControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.sampleNowBtn.setEnabled(enabled)
+ return f
+
+ ###Utils###
+
+ def __sampleNow(self):
+ def f(progress_fn=None):
+ TrainingModel.instance().sample_now()
+ return f
diff --git a/modules/ui/controllers/tabs/ToolsController.py b/modules/ui/controllers/tabs/ToolsController.py
new file mode 100644
index 000000000..5f110ce58
--- /dev/null
+++ b/modules/ui/controllers/tabs/ToolsController.py
@@ -0,0 +1,50 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.windows.BulkCaptionController import BulkCaptionController
+from modules.ui.controllers.windows.BulkImageController import BulkImageController
+from modules.ui.controllers.windows.CaptionController import CaptionController
+from modules.ui.controllers.windows.ConvertController import ConvertController
+from modules.ui.controllers.windows.DatasetController import DatasetController
+from modules.ui.controllers.windows.MaskController import MaskController
+from modules.ui.controllers.windows.ProfileController import ProfileController
+from modules.ui.controllers.windows.SampleController import SampleController
+from modules.ui.controllers.windows.VideoController import VideoController
+
+from PySide6.QtCore import QCoreApplication as QCA
+
+
+class ToolsController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/tools.ui", name=QCA.translate("main_window_tabs", "Tools"), parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ self.children = {"dataset": DatasetController(self.loader, parent=None),
+ "caption": CaptionController(self.loader, parent=None),
+ "mask": MaskController(self.loader, parent=None),
+ "image": BulkImageController(self.loader, parent=None),
+ "bulk_caption": BulkCaptionController(self.loader, parent=None),
+ "video": VideoController(self.loader, parent=None),
+ "convert": ConvertController(self.loader, parent=None),
+ "sample": SampleController(self.loader, parent=None),
+ "profile": ProfileController(self.loader, parent=None)}
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.datasetBtn.clicked, lambda: self.__open("dataset"))
+ self._connect(self.ui.imageBtn.clicked, lambda: self.__open("image"))
+ self._connect(self.ui.maskBtn.clicked, lambda: self.__open("mask"))
+ self._connect(self.ui.captionBtn.clicked, lambda: self.__open("caption"))
+ self._connect(self.ui.bulkCaptionBtn.clicked, lambda: self.__open("bulk_caption"))
+ self._connect(self.ui.videoBtn.clicked, lambda: self.__open("video"))
+ self._connect(self.ui.convertBtn.clicked, lambda: self.__open("convert"))
+ self._connect(self.ui.samplingBtn.clicked, lambda: self.__open("sample"))
+ self._connect(self.ui.profilingBtn.clicked, lambda: self.__open("profile"))
+
+ ###Utils###
+
+ def __open(self, window):
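+ # Bring an already-open window to the front instead of reopening it; only the dataset tool is resizable.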
+ if self.children[window].ui.isHidden():
+ self._openWindow(self.children[window], fixed_size=window != "dataset")
+ else:
+ self.children[window].ui.activateWindow()
diff --git a/modules/ui/controllers/tabs/TrainingController.py b/modules/ui/controllers/tabs/TrainingController.py
new file mode 100644
index 000000000..e77f8ee60
--- /dev/null
+++ b/modules/ui/controllers/tabs/TrainingController.py
@@ -0,0 +1,414 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.windows.OptimizerController import OptimizerController
+from modules.ui.models.StateModel import StateModel
+from modules.ui.models.TimestepGenerator import TimestepGenerator
+from modules.ui.utils.FigureWidget import FigureWidget
+from modules.util.enum.DataType import DataType
+from modules.util.enum.EMAMode import EMAMode
+from modules.util.enum.GradientCheckpointingMethod import GradientCheckpointingMethod
+from modules.util.enum.LearningRateScaler import LearningRateScaler
+from modules.util.enum.LearningRateScheduler import LearningRateScheduler
+from modules.util.enum.LossScaler import LossScaler
+from modules.util.enum.LossWeight import LossWeight
+from modules.util.enum.ModelFlags import ModelFlags
+from modules.util.enum.ModelType import ModelType
+from modules.util.enum.Optimizer import Optimizer
+from modules.util.enum.TimestepDistribution import TimestepDistribution
+from modules.util.enum.TimeUnit import TimeUnit
+from modules.util.enum.TrainingMethod import TrainingMethod
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from matplotlib import pyplot as plt
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class TrainingController(BaseController):
+ state_ui_connections = {
+ "optimizer.optimizer": "optimizerCmb",
+ "learning_rate_scheduler": "schedulerCmb",
+ "learning_rate": "learningRateLed",
+ "learning_rate_warmup_steps": "warmupStepsSbx",
+ "learning_rate_min_factor": "minFactorSbx",
+ "learning_rate_cycles": "cyclesSbx",
+ "epochs": "epochsSbx",
+ "batch_size": "batchSizeSbx",
+ "gradient_accumulation_steps": "accumulationStepsSbx",
+ "learning_rate_scaler": "scalerCmb",
+ "clip_grad_norm": "clipGradNormSbx",
+
+ "ema": "emaCmb",
+ "ema_decay": "emaDecaySbx",
+ "ema_update_step_interval": "emaUpdateIntervalSbx",
+ "gradient_checkpointing": "gradientCheckpointingCmb",
+ "enable_async_offloading": "asyncOffloadCbx",
+ "enable_activation_offloading": "offloadActivationsCbx",
+ "layer_offload_fraction": "layerOffloadFractionSbx",
+ "train_dtype": "trainDTypeCmb",
+ "fallback_train_dtype": "fallbackDTypeCmb",
+ "enable_autocast_cache": "autocastCacheCbx",
+ "resolution": "resolutionLed",
+ "frames": "framesSbx",
+ "force_circular_padding": "circularPaddingCbx",
+
+ "masked_training": "maskedTrainingCbx",
+ "unmasked_probability": "unmaskedProbabilitySbx",
+ "unmasked_weight": "unmaskedWeightSbx",
+ "normalize_masked_area_loss": "normalizeMaskedAreaCbx",
+ "masked_prior_preservation_weight": "maskedPriorPreservationSbx",
+ "custom_conditioning_image": "customConditioningImageCbx",
+ "mse_strength": "mseSbx",
+ "mae_strength": "maeSbx",
+ "huber_strength": "huberStrengthSbx",
+ "huber_delta": "huberDeltaSbx",
+ "log_cosh_strength": "logcoshSbx",
+ "vb_loss_strength": "vbLossSbx",
+ "loss_weight_fn": "lossWeightFunctionCmb",
+ "loss_weight_strength": "gammaSbx",
+ "loss_scaler": "lossScalerCmb",
+
+ "layer_filter_preset": "layerFilterCmb",
+ "layer_filter": "layerFilterLed",
+ "layer_filter_regex": "layerFilterRegexCbx",
+
+ "embedding_learning_rate": "embeddingLearningRateLed",
+ "preserve_embedding_norm": "embeddingNormCbx",
+
+ "offset_noise_weight": "offsetNoiseWeightSbx",
+ "generalized_offset_noise": "generalizedOffsetNoiseCbx",
+ "perturbation_noise_weight": "perturbationNoiseWeightSbx",
+ "timestep_distribution": "timestepDistributionCmb",
+ "min_noising_strength": "minNoisingStrengthSbx",
+ "max_noising_strength": "maxNoisingStrengthSbx",
+ "noising_weight": "noisingWeightSbx",
+ "noising_bias": "noisingBiasSbx",
+ "timestep_shift": "timestepShiftSbx",
+ "dynamic_timestep_shifting": "dynamicTimestepShiftingCbx",
+
+ "text_encoder.include": "te1IncludeCbx",
+ "text_encoder.train": "te1TrainCbx",
+ "text_encoder.train_embedding": "te1TrainEmbCbx",
+ "text_encoder.dropout_probability": "te1DropoutSbx",
+ "text_encoder.stop_training_after": "te1StopTrainingSbx",
+ "text_encoder.stop_training_after_unit": "te1StopTrainingCmb",
+ "text_encoder.learning_rate": "te1LearningRateLed",
+ "text_encoder_layer_skip": "te1ClipSkipSbx",
+
+ "text_encoder_2.include": "te2IncludeCbx",
+ "text_encoder_2.train": "te2TrainCbx",
+ "text_encoder_2.train_embedding": "te2TrainEmbCbx",
+ "text_encoder_2.dropout_probability": "te2DropoutSbx",
+ "text_encoder_2.stop_training_after": "te2StopTrainingSbx",
+ "text_encoder_2.stop_training_after_unit": "te2StopTrainingCmb",
+ "text_encoder_2.learning_rate": "te2LearningRateLed",
+ "text_encoder_2_layer_skip": "te2ClipSkipSbx",
+ "text_encoder_2_sequence_length": "te2SeqLenSbx",
+
+ "text_encoder_3.include": "te3IncludeCbx",
+ "text_encoder_3.train": "te3TrainCbx",
+ "text_encoder_3.train_embedding": "te3TrainEmbCbx",
+ "text_encoder_3.dropout_probability": "te3DropoutSbx",
+ "text_encoder_3.stop_training_after": "te3StopTrainingSbx",
+ "text_encoder_3.stop_training_after_unit": "te3StopTrainingCmb",
+ "text_encoder_3.learning_rate": "te3LearningRateLed",
+ "text_encoder_3_layer_skip": "te3ClipSkipSbx",
+
+ "text_encoder_4.include": "te4IncludeCbx",
+ "text_encoder_4.train": "te4TrainCbx",
+ "text_encoder_4.train_embedding": "te4TrainEmbCbx",
+ "text_encoder_4.dropout_probability": "te4DropoutSbx",
+ "text_encoder_4.stop_training_after": "te4StopTrainingSbx",
+ "text_encoder_4.stop_training_after_unit": "te4StopTrainingCmb",
+ "text_encoder_4.learning_rate": "te4LearningRateLed",
+ "text_encoder_4_layer_skip": "te4ClipSkipSbx",
+
+ "unet.train": "unetTrainCbx",
+ "unet.stop_training_after": "unetStopSbx",
+ "unet.stop_training_after_unit": "unetStopCmb",
+ "unet.learning_rate": "unetLearningRateLed",
+ "rescale_noise_scheduler_to_zero_terminal_snr": "unetRescaleCbx",
+
+ "prior.train": "priorTrainCbx",
+ "prior.stop_training_after": "priorStopSbx",
+ "prior.stop_training_after_unit": "priorStopCmb",
+ "prior.learning_rate": "priorLearningRateLed",
+
+ "transformer.train": "transformerTrainCbx",
+ "transformer.stop_training_after": "transformerStopSbx",
+ "transformer.stop_training_after_unit": "transformerStopCmb",
+ "transformer.learning_rate": "transformerLearningRateLed",
+ "transformer.attention_mask": "transformerAttnMaskCbx",
+ "transformer.guidance_scale": "transformerGuidanceSbx",
+
+ "custom_learning_rate_scheduler": "schedulerClassLed",
+ }
+
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/tabs/training.ui", name=QCA.translate("main_window_tabs", "Training"), parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ self.optimizer_window = OptimizerController(self.loader, parent=self)
+
+ plt.set_loglevel('WARNING') # silence matplotlib's info-level messages about data types in the bar chart
+
+ self.canvas = FigureWidget(parent=self.ui, width=4, height=4, zoom_tools=True)
+ self.canvas.setFixedHeight(300)
+
+ self.ax = self.canvas.figure.subplots()
+ self.ui.previewLay.addWidget(self.canvas.toolbar) # Matplotlib toolbar so the user can zoom into the preview.
+ self.ui.previewLay.addWidget(self.canvas)
+
+ self.ax.tick_params(axis='x', which="both")
+ self.ax.tick_params(axis='y', which="both")
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.optimizerBtn.clicked, lambda: self._openWindow(self.optimizer_window, fixed_size=True))
+
+ self._connect([self.ui.layerFilterCmb.activated, QtW.QApplication.instance().stateChanged],
+ self.__connectLayerFilter())
+
+ self._connect(QtW.QApplication.instance().stateChanged, self.__updateSchedulerParams(), update_after_connect=True)
+ self._connect(self.ui.tableWidget.currentCellChanged, self.__changeCell())
+ self._connect(self.ui.updatePreviewBtn.clicked, self.__updatePreview())
+
+ self._connect(QtW.QApplication.instance().modelChanged, self.__updateModel(), update_after_connect=True,
+ initial_args=[StateModel.instance().get_state("model_type"), StateModel.instance().get_state("training_method")])
+
+ self._connect([self.ui.schedulerCmb.activated, QtW.QApplication.instance().stateChanged],
+ self.__enableCustomScheduler(), update_after_connect=True)
+
+ self._connect([self.ui.maskedTrainingCbx.toggled, QtW.QApplication.instance().stateChanged],
+ self.__enableMaskedTraining(), update_after_connect=True)
+
+ self._connect([self.ui.optimizerCmb.activated, QtW.QApplication.instance().stateChanged], self.__updateOptimizer(), update_after_connect=True)
+
+ # Invalidate the GUI once at startup so the optimizer window matches the current optimizer selection.
+ self.optimizer_window.ui.optimizerCmb.setCurrentIndex(self.ui.optimizerCmb.currentIndex())
+
+ def _connectInputValidation(self):
+ self.ui.resolutionLed.setValidator(QtGui.QRegularExpressionValidator(r"\d+(x\d+(,\d+x\d+)*)?", self.ui))
+ self._connect(self.ui.minNoisingStrengthSbx.valueChanged, self.__validateNoisingStrength("min"))
+ self._connect(self.ui.maxNoisingStrengthSbx.valueChanged, self.__validateNoisingStrength("max"))
+
+ self._connectScientificNotation(self.ui.learningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.te1LearningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.te2LearningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.te3LearningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.te4LearningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.unetLearningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.transformerLearningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.priorLearningRateLed, min=0.0)
+ self._connectScientificNotation(self.ui.embeddingLearningRateLed, min=0.0)
+
+ def _loadPresets(self):
+ for ui_name in ["unetStopCmb", "te1StopTrainingCmb", "te2StopTrainingCmb", "te3StopTrainingCmb", "te4StopTrainingCmb",
+ "priorStopCmb", "transformerStopCmb"]:
+ ui_elem = self.ui.findChild(QtW.QComboBox, ui_name)
+ for e in TimeUnit.enabled_values():
+ ui_elem.addItem(e.pretty_print(), userData=e)
+
+ for e in TimestepDistribution.enabled_values():
+ self.ui.timestepDistributionCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in LossScaler.enabled_values():
+ self.ui.lossScalerCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in GradientCheckpointingMethod.enabled_values():
+ self.ui.gradientCheckpointingCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in EMAMode.enabled_values():
+ self.ui.emaCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in LearningRateScaler.enabled_values():
+ self.ui.scalerCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in Optimizer.enabled_values():
+ self.ui.optimizerCmb.addItem(e.pretty_print(), userData=e)
+ self.optimizer_window.ui.optimizerCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in LearningRateScheduler.enabled_values():
+ self.ui.schedulerCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in DataType.enabled_values(context="training_dtype"):
+ self.ui.trainDTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in DataType.enabled_values(context="training_fallback"):
+ self.ui.fallbackDTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __updatePreview(self):
+ @Slot()
+ def f():
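+ # Redraw the timestep histogram from the current noising settings, with one bin per timestep in [0, 999].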
+ generator = TimestepGenerator(
+ timestep_distribution=StateModel.instance().get_state("timestep_distribution"),
+ min_noising_strength=StateModel.instance().get_state("min_noising_strength"),
+ max_noising_strength=StateModel.instance().get_state("max_noising_strength"),
+ noising_weight=StateModel.instance().get_state("noising_weight"),
+ noising_bias=StateModel.instance().get_state("noising_bias"),
+ timestep_shift=StateModel.instance().get_state("timestep_shift"),
+ )
+
+ self.ax.cla()
+ self.ax.hist(generator.generate(), bins=1000, range=(0, 999))
+ self.canvas.draw_idle()
+ return f
+
+ def __updateOptimizer(self):
+ @Slot()
+ def f():
+ self.optimizer_window.ui.optimizerCmb.setCurrentIndex(self.ui.optimizerCmb.currentIndex())
+ QtW.QApplication.instance().optimizerChanged.emit(self.ui.optimizerCmb.currentData())
+ return f
+
+ def __updateModel(self):
+ @Slot(ModelType, TrainingMethod)
+ def f(model_type, training_method):
+ flags = ModelFlags.getFlags(model_type, training_method)
+ presets = ModelFlags.getPresets(model_type)
+
+ self.ui.layerFilterCmb.clear()
+ for k, v in presets.items():
+ self.ui.layerFilterCmb.addItem(k, userData=v)
+ self.ui.layerFilterCmb.addItem("custom", userData=[])
+
+ self.ui.lossWeightFunctionCmb.clear()
+ for e in LossWeight.enabled_values("flow_matching" if model_type.is_flow_matching() else ""):
+ self.ui.lossWeightFunctionCmb.addItem(e.pretty_print(), userData=e)
+
+ self.ui.gammaSbx.setVisible(not model_type.is_flow_matching())
+
+ self.ui.te2SeqLenLbl.setVisible(ModelFlags.OVERRIDE_SEQUENCE_LENGTH_TE2 in flags)
+ self.ui.te2SeqLenSbx.setVisible(ModelFlags.OVERRIDE_SEQUENCE_LENGTH_TE2 in flags)
+
+ self.ui.te1Gbx.setVisible(ModelFlags.TE1 in flags)
+ self.ui.te2Gbx.setVisible(ModelFlags.TE2 in flags)
+ self.ui.te3Gbx.setVisible(ModelFlags.TE3 in flags)
+ self.ui.te4Gbx.setVisible(ModelFlags.TE4 in flags)
+
+ self.ui.unetGbx.setVisible(ModelFlags.UNET in flags)
+ self.ui.transformerGbx.setVisible(ModelFlags.TRANSFORMER in flags)
+ self.ui.priorGbx.setVisible(ModelFlags.TRAIN_TRANSFORMER in flags)
+
+ self.ui.generalizedOffsetNoiseCbx.setVisible(ModelFlags.GENERALIZED_OFFSET_NOISE in flags)
+
+ self.ui.te1IncludeCbx.setVisible(ModelFlags.TE_INCLUDE in flags)
+ self.ui.te2IncludeCbx.setVisible(ModelFlags.TE_INCLUDE in flags)
+ self.ui.te3IncludeCbx.setVisible(ModelFlags.TE_INCLUDE in flags)
+ self.ui.te4IncludeCbx.setVisible(ModelFlags.TE_INCLUDE in flags)
+
+ self.ui.vbLossLbl.setVisible(ModelFlags.VB_LOSS in flags)
+ self.ui.vbLossSbx.setVisible(ModelFlags.VB_LOSS in flags)
+
+ self.ui.transformerGuidanceLbl.setVisible(ModelFlags.GUIDANCE_SCALE in flags)
+ self.ui.transformerGuidanceSbx.setVisible(ModelFlags.GUIDANCE_SCALE in flags)
+
+ self.ui.dynamicTimestepShiftingCbx.setVisible(ModelFlags.DYNAMIC_TIMESTEP_SHIFTING in flags)
+
+ self.ui.transformerAttnMaskCbx.setVisible(ModelFlags.DISABLE_FORCE_ATTN_MASK not in flags)
+
+ self.ui.te1ClipSkipSbx.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+ self.ui.te2ClipSkipSbx.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+ self.ui.te3ClipSkipSbx.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+ self.ui.te4ClipSkipSbx.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+ self.ui.te1ClipSkipLbl.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+ self.ui.te2ClipSkipLbl.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+ self.ui.te3ClipSkipLbl.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+ self.ui.te4ClipSkipLbl.setVisible(ModelFlags.DISABLE_CLIP_SKIP not in flags)
+
+ self.ui.framesLbl.setVisible(ModelFlags.VIDEO_TRAINING in flags)
+ self.ui.framesSbx.setVisible(ModelFlags.VIDEO_TRAINING in flags)
+
+ self.ui.te4ClipSkipLbl.setVisible(ModelFlags.DISABLE_TE4_LAYER_SKIP not in flags)
+ self.ui.te4ClipSkipSbx.setVisible(ModelFlags.DISABLE_TE4_LAYER_SKIP not in flags)
+
+ return f
+
+ def __enableMaskedTraining(self):
+ @Slot()
+ def f():
+ enabled = self.ui.maskedTrainingCbx.isChecked()
+ self.ui.unmaskedProbabilityLbl.setEnabled(enabled)
+ self.ui.unmaskedProbabilitySbx.setEnabled(enabled)
+ self.ui.unmaskedWeightLbl.setEnabled(enabled)
+ self.ui.unmaskedWeightSbx.setEnabled(enabled)
+ self.ui.normalizeMaskedAreaCbx.setEnabled(enabled)
+ self.ui.maskedPriorPreservationLbl.setEnabled(enabled)
+ self.ui.maskedPriorPreservationSbx.setEnabled(enabled)
+ self.ui.customConditioningImageCbx.setEnabled(enabled)
+
+ return f
+
+ def __validateNoisingStrength(self, direction):
+ @Slot(float)
+ def f(value):
+ min_strength = self.ui.minNoisingStrengthSbx.value()
+ max_strength = self.ui.maxNoisingStrengthSbx.value()
+
+ if direction == "min" and min_strength > max_strength:
+ self.ui.minNoisingStrengthSbx.setValue(max_strength)
+
+ if direction == "max" and max_strength < min_strength:
+ self.ui.maxNoisingStrengthSbx.setValue(min_strength)
+
+ return f
+
+ def __changeCell(self):
+ @Slot(int, int, int, int)
+ def f(currentRow, currentColumn, previousRow, previousColumn):
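+ # Persist the cell that was just left, and append an empty row once the last value cell has been edited.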
+ total_rows = self.ui.tableWidget.rowCount()
+
+ key = self.ui.tableWidget.item(previousRow, 0)
+ value = self.ui.tableWidget.item(previousRow, 1)
+
+ if key is not None and value is not None and key.text() != "" and value.text() != "":
+ StateModel.instance().set_scheduler_params(previousRow, key.text(), value.text())
+
+ if previousRow == total_rows - 1 and previousColumn == 1:
+ self.ui.tableWidget.insertRow(total_rows)
+ self.ui.tableWidget.editItem(self.ui.tableWidget.item(total_rows, 0))
+ self.ui.tableWidget.setCurrentCell(total_rows, 0) # TODO: a new row is inserted correctly, but tab selection returns to the first cell.
+
+ return f
+
+ def __updateSchedulerParams(self):
+ @Slot()
+ def f():
+ param_dict = StateModel.instance().get_state("scheduler_params")
+
+ self.ui.tableWidget.clearContents()
+ for idx, param in enumerate(param_dict):
+ self.ui.tableWidget.insertRow(idx)
+ self.ui.tableWidget.setItem(idx, 0, QtW.QTableWidgetItem(param["key"]))
+ self.ui.tableWidget.setItem(idx, 1, QtW.QTableWidgetItem(param["value"]))
+ return f
+
+ def __enableCustomScheduler(self):
+ @Slot()
+ def f():
+ self.ui.tableWidget.setEnabled(self.ui.schedulerCmb.currentData() == LearningRateScheduler.CUSTOM)
+ self.ui.schedulerClassLed.setEnabled(self.ui.schedulerCmb.currentData() == LearningRateScheduler.CUSTOM)
+ self.ui.schedulerLbl.setEnabled(self.ui.schedulerCmb.currentData() == LearningRateScheduler.CUSTOM)
+ return f
+
+ def __connectLayerFilter(self):
+ @Slot()
+ def f():
+ self.ui.layerFilterRegexCbx.setEnabled(self.ui.layerFilterCmb.currentText() == "custom")
+ if self.ui.layerFilterCmb.currentData() is not None:
+ self.ui.layerFilterLed.setText(",".join(self.ui.layerFilterCmb.currentData()))
+ return f
diff --git a/modules/ui/controllers/widgets/ConceptController.py b/modules/ui/controllers/widgets/ConceptController.py
new file mode 100644
index 000000000..05ff40ab6
--- /dev/null
+++ b/modules/ui/controllers/widgets/ConceptController.py
@@ -0,0 +1,48 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.ConceptModel import ConceptModel
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from PIL.ImageQt import ImageQt
+from PySide6.QtCore import Slot
+
+
+class ConceptController(BaseController):
+ def __init__(self, loader, concept_window, idx, parent=None):
+ self.concept_window = concept_window
+ self.idx = idx
+ super().__init__(loader, "modules/ui/views/widgets/concept.ui", invalidate_once=False, name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.conceptBtn.clicked, self.__openConceptWindow())
+ self._connect(self.ui.enableCbx.clicked, self.__enableConcept())
+
+ self._connect(QtW.QApplication.instance().conceptsChanged, self.__updateConcept(), update_after_connect=True)
+
+ ###Reactions###
+
+ def __openConceptWindow(self):
+ @Slot()
+ def f():
+ self._openWindow(self.concept_window, fixed_size=False)
+ QtW.QApplication.instance().openConcept.emit(self.idx)
+ return f
+
+ def __enableConcept(self):
+ @Slot()
+ def f():
+ ConceptModel.instance().set_state(f"{self.idx}.enabled", self.ui.enableCbx.isChecked())
+ QtW.QApplication.instance().conceptsChanged.emit(True)
+ return f
+
+ def __updateConcept(self):
+ @Slot()
+ def f():
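+ # Sync the checkbox state and label with the model and refresh the preview thumbnail.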
+ self.ui.enableCbx.setChecked(ConceptModel.instance().get_state(f"{self.idx}.enabled"))
+ self.ui.enableCbx.setText(ConceptModel.instance().get_concept_name(self.idx))
+
+ img = ConceptModel.instance().get_preview_icon(self.idx)
+ self.ui.conceptBtn.setIcon(QtGui.QIcon(QtGui.QPixmap.fromImage(ImageQt(img))))
+ return f
diff --git a/modules/ui/controllers/widgets/EmbeddingController.py b/modules/ui/controllers/widgets/EmbeddingController.py
new file mode 100644
index 000000000..9791207b9
--- /dev/null
+++ b/modules/ui/controllers/widgets/EmbeddingController.py
@@ -0,0 +1,39 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.StateModel import StateModel
+from modules.util.enum.TimeUnit import TimeUnit
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+
+
+class EmbeddingController(BaseController):
+ def __init__(self, loader, idx, parent=None):
+ self.idx = idx
+ super().__init__(loader, "modules/ui/views/widgets/embedding.ui", invalidate_once=False, name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.baseEmbeddingBtn, self.ui.baseEmbeddingLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open base embeddings"),
+ filters=QCA.translate("filetype_filters",
+ "Safetensors (*.safetensors);;Diffusers (model_index.json);;Checkpoints (*.ckpt *.pt *.bin);;All Files (*.*)"))
+
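+ # The {idx} placeholder is expanded with this widget's index by _connectStateUI, so each row binds to its own embedding entry.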
+ self.dynamic_state_ui_connections = {
+ "additional_embeddings.{idx}.model_name": "baseEmbeddingLed",
+ "additional_embeddings.{idx}.placeholder": "placeholderLed",
+ "additional_embeddings.{idx}.token_count": "tokenSbx",
+ "additional_embeddings.{idx}.train": "trainCbx",
+ "additional_embeddings.{idx}.is_output_embedding": "outputEmbeddingCbx",
+ "additional_embeddings.{idx}.stop_training_after": "stopTrainingSbx",
+ "additional_embeddings.{idx}.stop_training_after_unit": "stopTrainingCmb",
+ "additional_embeddings.{idx}.initial_embedding_text": "initialEmbeddingLed",
+ }
+
+ self._connectStateUI(self.dynamic_state_ui_connections, StateModel.instance(),
+ signal=QtW.QApplication.instance().embeddingsChanged, update_after_connect=True,
+ idx=self.idx)
+
+ def _loadPresets(self):
+ for e in TimeUnit.enabled_values():
+ self.ui.stopTrainingCmb.addItem(e.pretty_print(), userData=e)
diff --git a/modules/ui/controllers/widgets/SampleController.py b/modules/ui/controllers/widgets/SampleController.py
new file mode 100644
index 000000000..e02a55f4d
--- /dev/null
+++ b/modules/ui/controllers/widgets/SampleController.py
@@ -0,0 +1,43 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.SampleModel import SampleModel
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import Slot
+
+
+class SampleController(BaseController):
+ def __init__(self, loader, sample_window, idx, parent=None):
+ self.idx = idx
+ self.sample_window = sample_window
+
+ super().__init__(loader, "modules/ui/views/widgets/sample.ui", invalidate_once=False, name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.openWindowBtn.clicked, self.__openSampleWindow())
+
+ self.dynamic_state_ui_connections = {
+ "{idx}.enabled": "enabledCbx",
+ "{idx}.width": "widthSbx",
+ "{idx}.height": "heightSbx",
+ "{idx}.seed": "seedLed",
+ "{idx}.prompt": "promptLed",
+ }
+
+ self._connectStateUI(self.dynamic_state_ui_connections, SampleModel.instance(),
+ signal=QtW.QApplication.instance().samplesChanged, update_after_connect=True, idx=self.idx)
+
+ def _connectInputValidation(self):
+ # Use a regular expression instead of QIntValidator, so seeds above QIntValidator's maximum value are still accepted.
+ self.ui.seedLed.setValidator(QtGui.QRegularExpressionValidator(r"-1|0|[1-9]\d*", self.ui))
+
+ ###Reactions###
+
+ def __openSampleWindow(self):
+ @Slot()
+ def f():
+ self._openWindow(self.sample_window, fixed_size=True)
+ QtW.QApplication.instance().openSample.emit(self.idx)
+ return f
diff --git a/modules/ui/controllers/widgets/SampleParamsController.py b/modules/ui/controllers/widgets/SampleParamsController.py
new file mode 100644
index 000000000..b5b9b082c
--- /dev/null
+++ b/modules/ui/controllers/widgets/SampleParamsController.py
@@ -0,0 +1,111 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.util.enum.NoiseScheduler import NoiseScheduler
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import QObject, Slot
+
+
+class SampleParamsController(BaseController):
+ def __init__(self, loader, model_instance, read_signal=None, write_signal=None, parent=None):
+ self.model_instance = model_instance
+ self.read_signal = read_signal
+ self.write_signal = write_signal
+ self.idx = None
+
+ super().__init__(loader, "modules/ui/views/widgets/sampling_params.ui", invalidate_once=False, name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.imagePathBtn, self.ui.imagePathLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open Base Image"),
+ filters=QCA.translate("filetype_filters", "Image (*.jpg *.jpeg *.tif *.png *.webp)"))
+ self._connectFileDialog(self.ui.maskPathBtn, self.ui.maskPathLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open Mask Image"),
+ filters=QCA.translate("filetype_filters",
+ "Image (*.jpg *.jpeg *.tif *.png *.webp)"))
+
+ self.dynamic_state_ui_connections = {
+ "prompt": "promptLed",
+ "negative_prompt": "negativePromptLed",
+ "width": "widthSbx",
+ "height": "heightSbx",
+ "frames": "framesSbx",
+ "length": "lengthSbx",
+ "seed": "seedLed",
+ "random_seed": "randomSeedCbx",
+ "cfg_scale": "cfgSbx",
+ "diffusion_steps": "stepsSbx",
+ "noise_scheduler": "samplerCmb",
+ "sample_inpainting": "inpaintingCbx",
+ "base_image_path": "imagePathLed",
+ "mask_image_path": "maskPathLed",
+ }
+
+ # Data is read and written in bulk, driven by the parent widget's signals, instead of connecting each UI element individually.
+ if self.read_signal is not None: # Refresh the controls when the parent requests a read.
+ self._connect(self.read_signal, self.__readControls())
+ if self.write_signal is not None: # Persist the controls when the parent requests a write.
+ self._connect(self.write_signal, self.__writeControls())
+
+ def _loadPresets(self):
+ for e in NoiseScheduler.enabled_values():
+ self.ui.samplerCmb.addItem(e.pretty_print(), userData=e)
+
+ def _connectInputValidation(self):
+ # Use a regular expression instead of QIntValidator, so seeds above QIntValidator's maximum value are still accepted.
+ self.ui.seedLed.setValidator(QtGui.QRegularExpressionValidator(r"-1|0|[1-9]\d*", self.ui))
+
+ ###Reactions###
+
+ def __readControls(self):
+ def f(idx=None):
+ self.idx = idx
+ if idx is None:
+ data = self.model_instance.bulk_read(*self.dynamic_state_ui_connections, as_dict=True)
+ else:
+ data = self.model_instance.bulk_read(*[f"{self.idx}.{k}" for k in self.dynamic_state_ui_connections], as_dict=True)
+
+ for k, v in self.dynamic_state_ui_connections.items():
+ if self.idx is not None:
+ k = f"{self.idx}.{k}"
+
+ wdg = self.ui.findChild(QObject, v)
+ if data[k] is not None:
+ if isinstance(wdg, QtW.QCheckBox):
+ wdg.setChecked(data[k])
+ elif isinstance(wdg, QtW.QComboBox):
+ i = wdg.findData(data[k])
+ if i != -1:
+ wdg.setCurrentIndex(i)
+ elif isinstance(wdg, (QtW.QSpinBox, QtW.QDoubleSpinBox)):
+ wdg.setValue(float(data[k]))
+ elif isinstance(wdg, QtW.QLineEdit):
+ wdg.setText(str(data[k]))
+
+ return f
+
+ def __writeControls(self):
+ @Slot()
+ def f():
+ data = {}
+ for k, v in self.dynamic_state_ui_connections.items():
+ if self.idx is not None:
+ k = f"{self.idx}.{k}"
+
+ wdg = self.ui.findChild(QObject, v)
+ if isinstance(wdg, QtW.QCheckBox):
+ data[k] = wdg.isChecked()
+ elif isinstance(wdg, QtW.QComboBox):
+ data[k] = wdg.currentData()
+ elif isinstance(wdg, (QtW.QSpinBox, QtW.QDoubleSpinBox)):
+ data[k] = wdg.value()
+ elif isinstance(wdg, QtW.QLineEdit):
+ data[k] = wdg.text()
+
+
+ self.model_instance.bulk_write(data)
+ return f
diff --git a/modules/ui/controllers/windows/BulkCaptionController.py b/modules/ui/controllers/windows/BulkCaptionController.py
new file mode 100644
index 000000000..d352fc2ad
--- /dev/null
+++ b/modules/ui/controllers/windows/BulkCaptionController.py
@@ -0,0 +1,91 @@
+import os
+
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.BulkCaptionModel import BulkCaptionModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.BulkEditMode import BulkEditMode
+
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class BulkCaptionController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/bulk_caption.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+        self._connectFileDialog(self.ui.directoryBtn, self.ui.directoryLed, is_dir=True, save=False,
+                                title=QCA.translate("dialog_window", "Open Dataset directory"))
+
+ state_ui_connections = {
+ "directory": "directoryLed",
+ "add_text": "addLed",
+ "add_mode": "addCmb",
+ "remove_text": "removeLed",
+ "replace_text": "replaceLed",
+ "replace_with": "replaceWithLed",
+ "regex_pattern": "regexLed",
+ "regex_replace": "regexWithLed",
+ }
+
+ self._connectStateUI(state_ui_connections, BulkCaptionModel.instance(), update_after_connect=True)
+ self._connect(self.ui.applyBtn.clicked, self.__startProcessFiles(read_only=False))
+ self._connect(self.ui.previewBtn.clicked, self.__startProcessFiles(read_only=True))
+
+ self.__enableControls(True)()
+
+
+ def _loadPresets(self):
+ for e in BulkEditMode.enabled_values():
+ self.ui.addCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+ def __updateStatus(self):
+ @Slot(dict)
+ def f(data):
+ if "status" in data:
+ self.ui.statusLbl.setText(data["status"])
+
+ if "data" in data:
+ self.ui.previewTed.setPlainText(data["data"])
+ return f
+
+ def __startProcessFiles(self, read_only):
+ @Slot()
+ def f():
+ if self.ui.directoryLed.text() != "":
+ if os.path.isdir(self.ui.directoryLed.text()):
+ worker, name = WorkerPool.instance().createNamed(self.__processFiles(read_only), "process_bulk_captions", inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True),
+ errored_fn=self.__enableControls(True), aborted_fn=self.__enableControls(True),
+ progress_fn=self.__updateStatus())
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("image_window", "Invalid Folder"),
+ QCA.translate("image_window", "The selected input folder does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("image_window", "No Folder Selected"),
+ QCA.translate("image_window", "Please select an input folder"))
+
+ return f
+
+ def __enableControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.applyBtn.setEnabled(enabled)
+ self.ui.previewBtn.setEnabled(enabled)
+ return f
+
+
+ ###Utils###
+
+ def __processFiles(self, read_only):
+ def f(progress_fn=None):
+ return BulkCaptionModel.instance().bulk_edit(read_only=read_only, preview_n=10, progress_fn=progress_fn)
+
+ return f
diff --git a/modules/ui/controllers/windows/BulkImageController.py b/modules/ui/controllers/windows/BulkImageController.py
new file mode 100644
index 000000000..ce5ffa872
--- /dev/null
+++ b/modules/ui/controllers/windows/BulkImageController.py
@@ -0,0 +1,119 @@
+import os
+
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.BulkImageModel import BulkImageModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.ImageMegapixels import ImageMegapixels
+from modules.util.enum.ImageOptimization import ImageOptimization
+
+import PySide6.QtGui as QtGui
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class BulkImageController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/bulk_image.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+        self._connectFileDialog(self.ui.directoryBtn, self.ui.directoryLed, is_dir=True, save=False,
+                                title=QCA.translate("dialog_window", "Open Dataset directory"))
+
+ state_ui_connections = {
+ "directory": "directoryLed",
+ "verify_images": "verifyCbx",
+ "sequential_rename": "renameCbx",
+ "process_alpha": "replaceColorCbx",
+ "resize_large_images": "resizeCbx",
+ "resize_megapixels": "resizeCmb",
+ "alpha_bg_color": "colorLed",
+ "optimization_type": "optimizationCmb",
+ "resize_custom_megapixels": "customSbx",
+ }
+
+ self._connectStateUI(state_ui_connections, BulkImageModel.instance(), update_after_connect=True)
+ self._connect(self.ui.processBtn.clicked, self.__startProcessFiles())
+ self._connect(self.ui.cancelBtn.clicked, self.__stopProcessFiles())
+ self._connect(self.ui.resizeCmb.activated, self.__enableCustom(), update_after_connect=True)
+
+ self.__enableControls(True)()
+
+ def _loadPresets(self):
+ for e in ImageOptimization.enabled_values():
+ self.ui.optimizationCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in ImageMegapixels.enabled_values():
+ self.ui.resizeCmb.addItem(e.pretty_print(), userData=e)
+
+
+ def _connectInputValidation(self):
+ self.ui.colorLed.setValidator(QtGui.QRegularExpressionValidator("random|-1|#[0-9a-f]{6}|[a-z]+", self.ui))
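+        # The pattern accepts "random", "-1", a hex color such as "#a0b1c2", or a lowercase
+        # color name such as "white".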
+
+ ###Reactions###
+
+ def __startProcessFiles(self):
+ @Slot()
+ def f():
+ self.ui.statusTed.setPlainText("")
+
+ if self.ui.directoryLed.text() != "":
+ if os.path.isdir(self.ui.directoryLed.text()):
+ worker, name = WorkerPool.instance().createNamed(self.__processFiles(), "process_images", abort_flag=BulkImageModel.instance().abort_flag, inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True),
+ errored_fn=self.__enableControls(True), aborted_fn=self.__enableControls(True),
+ progress_fn=self.__updateStatus())
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("image_window", "Invalid Folder"),
+ QCA.translate("image_window", "The selected input folder does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("image_window", "No Folder Selected"),
+ QCA.translate("image_window", "Please select an input folder"))
+
+ return f
+
+ def __stopProcessFiles(self):
+ @Slot()
+ def f():
+ BulkImageModel.instance().abort_flag.set()
+ return f
+
+ def __enableControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.processBtn.setEnabled(enabled)
+ self.ui.cancelBtn.setEnabled(not enabled)
+ if enabled:
+ self.ui.progressBar.setValue(0)
+ return f
+
+ def __enableCustom(self):
+ @Slot()
+        def f():
+            is_custom = self.ui.resizeCmb.currentData() == ImageMegapixels.CUSTOM
+            self.ui.customLbl.setEnabled(is_custom)
+            self.ui.customSbx.setEnabled(is_custom)
+ return f
+
+ ###Utils###
+
+ def __processFiles(self):
+ def f(progress_fn=None):
+ return BulkImageModel.instance().process_files(progress_fn=progress_fn)
+ return f
+
+ def __updateStatus(self):
+ progress_fn = self._updateProgress(self.ui.progressBar)
+ def f(data):
+ if "status" in data:
+ self.ui.statusLbl.setText(data["status"])
+
+ if "data" in data:
+ self.ui.statusTed.setPlainText(self.ui.statusTed.toPlainText() + "\n" + data["data"])
+
+ progress_fn(data)
+ return f
diff --git a/modules/ui/controllers/windows/CaptionController.py b/modules/ui/controllers/windows/CaptionController.py
new file mode 100644
index 000000000..961b63ac3
--- /dev/null
+++ b/modules/ui/controllers/windows/CaptionController.py
@@ -0,0 +1,81 @@
+import os
+
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.CaptionModel import CaptionModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.GenerateCaptionsModel import GenerateCaptionsAction, GenerateCaptionsModel
+
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class CaptionController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/generate_caption.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+        self._connectFileDialog(self.ui.folderBtn, self.ui.folderLed, is_dir=True, save=False,
+                                title=QCA.translate("dialog_window", "Open Dataset directory"))
+
+ state_ui_connections = {
+ "model": "modelCmb",
+ "path": "folderLed",
+ "caption": "initialCaptionLed",
+ "prefix": "captionPrefixLed",
+ "postfix": "captionPostfixLed",
+ "mode": "modeCmb",
+ "include_subdirectories": "includeSubfolderCbx"
+ }
+
+ self._connectStateUI(state_ui_connections, CaptionModel.instance(), update_after_connect=True)
+ self._connect(self.ui.createMaskBtn.clicked, self.__startCaption())
+
+ self.__enableControls(True)()
+
+ def _loadPresets(self):
+ for e in GenerateCaptionsModel.enabled_values():
+ self.ui.modelCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in GenerateCaptionsAction.enabled_values():
+ self.ui.modeCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __startCaption(self):
+ @Slot()
+ def f():
+ if self.ui.folderLed.text() != "":
+ if os.path.isdir(self.ui.folderLed.text()):
+ worker, name = WorkerPool.instance().createNamed(self.__createCaption(), "create_caption", inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True),
+ errored_fn=self.__enableControls(True), aborted_fn=self.__enableControls(True),
+ progress_fn=self._updateProgress(self.ui.progressBar))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("caption_window", "Invalid Folder"),
+ QCA.translate("caption_window", "The selected input folder does not exist"), type="critical")
+ else:
+ self._openAlert(QCA.translate("caption_window", "No Folder Selected"),
+ QCA.translate("caption_window", "Please select an input folder"))
+
+ return f
+
+ def __enableControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.createMaskBtn.setEnabled(enabled)
+ if enabled:
+ self.ui.progressBar.setValue(0)
+ return f
+
+ ###Utils###
+
+ def __createCaption(self):
+ def f(progress_fn=None):
+ return CaptionModel.instance().create_captions(progress_fn=progress_fn)
+
+ return f
diff --git a/modules/ui/controllers/windows/ConceptController.py b/modules/ui/controllers/windows/ConceptController.py
new file mode 100644
index 000000000..84f174de8
--- /dev/null
+++ b/modules/ui/controllers/windows/ConceptController.py
@@ -0,0 +1,316 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.ConceptModel import ConceptModel
+from modules.ui.utils.FigureWidget import FigureWidget
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.BalancingStrategy import BalancingStrategy
+from modules.util.enum.ConceptType import ConceptType
+from modules.util.enum.DropoutMode import DropoutMode
+from modules.util.enum.PromptSource import PromptSource
+from modules.util.enum.SpecialDropoutTags import SpecialDropoutTags
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from matplotlib import pyplot as plt
+from PIL.ImageQt import ImageQt
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class ConceptController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/concept.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ self.idx = 0
+ self.file_index = 0
+
+ plt.set_loglevel('WARNING') # suppress errors about data type in bar chart
+
+ self.canvas = FigureWidget(parent=self.ui, width=7, height=3, zoom_tools=True)
+ self.bucket_ax = self.canvas.figure.subplots()
+ self.ui.histogramLay.addWidget(self.canvas.toolbar) # Matplotlib toolbar, in case we want the user to zoom in.
+ self.ui.histogramLay.addWidget(self.canvas)
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.pathBtn, self.ui.pathLed, is_dir=True, title=QCA.translate("dialog_window", "Open Dataset directory"))
+ self._connectFileDialog(self.ui.promptSourceBtn, self.ui.promptSourceLed, is_dir=False,
+ title=QCA.translate("dialog_window", "Open Prompt Source"),
+ filters=QCA.translate("filetype_filters",
+ "Text (*.txt)"))
+
+ self._connect(QtW.QApplication.instance().openConcept, self.__updateConcept())
+ self._connect(QtW.QApplication.instance().openConcept, self.__updateStats())
+ self._connect(self.ui.okBtn.clicked, self.__saveConcept())
+
+
+ self._connect(QtW.QApplication.instance().openConcept, self.__updateImage())
+
+ self.dynamic_state_ui_connections = {
+ # General tab.
+ "{idx}.enabled": "enabledCbx",
+ "{idx}.type": "conceptTypeCmb",
+ "{idx}.path": "pathLed",
+ "{idx}.text.prompt_source": "promptSourceCmb",
+ "{idx}.text.prompt_path": "promptSourceLed",
+ "{idx}.include_subdirectories": "includeSubdirectoriesCbx",
+ "{idx}.image_variations": "imageVariationsSbx",
+ "{idx}.text_variations": "textVariationsSbx",
+ "{idx}.balancing": "balancingSbx",
+ "{idx}.balancing_strategy": "balancingCmb",
+ "{idx}.loss_weight": "lossWeightSbx",
+ # Image augmentation tab.
+ "{idx}.image.enable_crop_jitter": "rndJitterCbx",
+ "{idx}.image.enable_random_flip": "rndFlipCbx",
+ "{idx}.image.enable_fixed_flip": "fixFlipCbx",
+ "{idx}.image.enable_random_rotate": "rndRotationCbx",
+ "{idx}.image.enable_fixed_rotate": "fixRotationCbx",
+ "{idx}.image.random_rotate_max_angle": "rotationSbx",
+ "{idx}.image.enable_random_brightness": "rndBrightnessCbx",
+ "{idx}.image.enable_fixed_brightness": "fixBrightnessCbx",
+ "{idx}.image.random_brightness_max_strength": "brightnessSbx",
+ "{idx}.image.enable_random_contrast": "rndContrastCbx",
+ "{idx}.image.enable_fixed_contrast": "fixContrastCbx",
+ "{idx}.image.random_contrast_max_strength": "contrastSbx",
+ "{idx}.image.enable_random_saturation": "rndSaturationCbx",
+ "{idx}.image.enable_fixed_saturation": "fixSaturationCbx",
+ "{idx}.image.random_saturation_max_strength": "saturationSbx",
+ "{idx}.image.enable_random_hue": "rndHueCbx",
+ "{idx}.image.enable_fixed_hue": "fixHueCbx",
+ "{idx}.image.random_hue_max_strength": "hueSbx",
+ "{idx}.image.enable_resolution_override": "fixResolutionOverrideCbx",
+ "{idx}.image.resolution_override": "resolutionOverrideLed",
+ "{idx}.image.enable_random_circular_mask_shrink": "rndCircularMaskCbx",
+ "{idx}.image.enable_random_mask_rotate_crop": "rndRotateCropCbx",
+ # Text augmentation tab.
+ "{idx}.text.enable_tag_shuffling": "tagShufflingCbx",
+ "{idx}.text.tag_delimiter": "tagDelimiterLed",
+ "{idx}.text.keep_tags_count": "keepTagCountSbx",
+ "{idx}.text.tag_dropout_enable": "tagDropoutCbx",
+ "{idx}.text.tag_dropout_mode": "dropoutModeCmb",
+ "{idx}.text.tag_dropout_probability": "dropoutProbabilitySbx",
+ "{idx}.text.tag_dropout_special_tags_mode": "specialDropoutTagsCmb",
+ "{idx}.text.tag_dropout_special_tags": "specialDropoutTagsLed",
+ "{idx}.text.tag_dropout_special_tags_regex": "specialTagsRegexCbx",
+ "{idx}.text.caps_randomize_enable": "randomizeCapitalizationCbx",
+ "{idx}.text.caps_randomize_probability": "capitalizationProbabilitySbx",
+ "{idx}.text.caps_randomize_mode": "capitalizationModeLed",
+ "{idx}.text.caps_randomize_lowercase": "forceLowercaseCbx",
+ }
+
+ self._connect(QtW.QApplication.instance().openConcept, self.__reconnectControls())
+
+ self._connect(self.ui.promptSourceCmb.activated, self.__enablePromptSource())
+ self._connect(self.ui.refreshBasicBtn.clicked, self.__startScan(advanced_scanning=False))
+ self._connect(self.ui.refreshAdvancedBtn.clicked, self.__startScan(advanced_scanning=True))
+ self._connect(self.ui.abortScanBtn.clicked, self.__abortScan())
+ self._connect(self.ui.downloadNowBtn.clicked, self.__startDownload())
+ self._connect(self.ui.prevBtn.clicked, self.__prevImage())
+ self._connect(self.ui.nextBtn.clicked, self.__nextImage())
+
+ self._connect([self.ui.prevBtn.clicked, self.ui.nextBtn.clicked, self.ui.updatePreviewBtn.clicked],
+ self.__updateImage())
+
+ self.__enableDownloadBtn(True)()
+ self.__enableScanBtn(True)()
+
+
+ def _loadPresets(self):
+ for e in PromptSource.enabled_values():
+ self.ui.promptSourceCmb.addItem(e.pretty_print(), userData=str(e)) # ConceptConfig serializes string, not enum
+
+ for e in DropoutMode.enabled_values():
+ self.ui.dropoutModeCmb.addItem(e.pretty_print(), userData=str(e)) # ConceptConfig serializes string, not enum
+
+ for e in SpecialDropoutTags.enabled_values():
+ self.ui.specialDropoutTagsCmb.addItem(e.pretty_print(), userData=str(e)) # ConceptConfig serializes string, not enum
+
+ for e in BalancingStrategy.enabled_values():
+ self.ui.balancingCmb.addItem(e.pretty_print(), userData=e)
+
+        # This always allows Prior Validation concepts, even when LoRA is not selected.
+        # (The behavior matches the original OneTrainer, which delegates these checks to non-UI methods.)
+ for e in ConceptType.enabled_values(context="prior_pred_enabled"):
+ self.ui.conceptTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ def _connectInputValidation(self):
+ self.ui.resolutionOverrideLed.setValidator(QtGui.QRegularExpressionValidator(r"\d+(x\d+(,\d+x\d+)*)?", self.ui))
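+        # The pattern accepts a single value ("512"), one WxH pair ("512x768"), or a
+        # comma-separated list of pairs ("512x512,768x1024").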
+
+ ###Reactions###
+
+ def __reconnectControls(self):
+ @Slot()
+ def f():
+ self._disconnectGroup("idx")
+ self._connectStateUI(self.dynamic_state_ui_connections, ConceptModel.instance(), signal=QtW.QApplication.instance().openConcept, group="idx", update_after_connect=True, idx=self.idx)
+ self._invalidateUI()
+ return f
+
+ def __updateStats(self):
+ @Slot()
+ def f():
+ self.__enableScanBtn(True)()
+ stats_dict = ConceptModel.instance().pretty_print_stats(self.idx)
+
+ for k, v in {
+ "fileSizeLbl": "file_size",
+ "processingTimeLbl": "processing_time",
+ "dirCountLbl": "dir_count",
+ "imageCountLbl": "image_count",
+ "imageCountMaskLbl": "image_count_mask",
+ "imageCountCaptionLbl": "image_count_caption",
+ "videoCountLbl": "video_count",
+ "videoCountCaptionLbl": "video_count_caption",
+ "maskCountLbl": "mask_count",
+ "maskCountUnpairedLbl": "mask_count_unpaired",
+ "captionCountLbl": "caption_count",
+ "unpairedCaptionsLbl": "unpaired_captions",
+ "maxPixelsLbl": "max_pixels",
+ "avgPixelsLbl": "avg_pixels",
+ "minPixelsLbl": "min_pixels",
+ "lengthMaxLbl": "length_max",
+ "lengthAvgLbl": "length_avg",
+ "lengthMinLbl": "length_min",
+ "fpsMaxLbl": "fps_max",
+ "fpsAvgLbl": "fps_avg",
+ "fpsMinLbl": "fps_min",
+ "captionMaxLbl": "caption_max",
+ "captionAvgLbl": "caption_avg",
+ "captionMinLbl": "caption_min",
+ "smallBucketLbl": "small_bucket",
+ }.items():
+ self.ui.findChild(QtW.QLabel, k).setText(str(stats_dict[v]))
+
+ self.__updateHistogram(stats_dict)
+ return f
+
+ def __startScan(self, advanced_scanning):
+ @Slot()
+ def f():
+ worker, name = WorkerPool.instance().createNamed(self.__scanConcept(), "scan_concept", abort_flag=ConceptModel.instance().cancel_scan_flag, advanced_scanning=advanced_scanning)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableScanBtn(False), result_fn=None, finished_fn=self.__updateStats(), errored_fn=self.__enableScanBtn(True), aborted_fn=self.__enableScanBtn(True))
+ WorkerPool.instance().start(name)
+
+ return f
+
+ def __startDownload(self):
+ @Slot()
+ def f():
+ worker, name = WorkerPool.instance().createNamed(self.__downloadConcept(), "download_concept")
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableDownloadBtn(False), result_fn=None, finished_fn=self.__enableDownloadBtn(True),
+ errored_fn=self.__enableDownloadBtn(True), aborted_fn=self.__enableDownloadBtn(True))
+ WorkerPool.instance().start(name)
+ return f
+
+ def __downloadConcept(self):
+ @Slot()
+ def f():
+ ConceptModel.instance().download_dataset(self.idx)
+ return f
+
+ def __abortScan(self):
+ @Slot()
+ def f():
+ ConceptModel.instance().cancel_scan_flag.set()
+ return f
+
+    def __enablePromptSource(self):
+        @Slot(int)
+        def f(value):
+            # TODO: Compare against PromptSource.CONCEPT once ConceptConfig accepts an enum instead of a string.
+            enabled = self.ui.promptSourceCmb.currentData() == "concept"
+            self.ui.promptSourceLed.setEnabled(enabled)
+            self.ui.promptSourceBtn.setEnabled(enabled)
+        return f
+
+ def __prevImage(self):
+ @Slot()
+ def f():
+ image_count = ConceptModel.instance().get_state(f"{self.idx}.concept_stats.image_count")
+ if image_count is not None and image_count > 0:
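+                # Adding image_count - 1 before taking the modulo steps back one image
+                # without producing a negative index.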
+ self.file_index = (self.file_index + image_count - 1) % image_count
+ else:
+ self.file_index = max(0, self.file_index - 1)
+ return f
+
+ def __nextImage(self):
+ @Slot()
+ def f():
+ image_count = ConceptModel.instance().get_state(f"{self.idx}.concept_stats.image_count")
+ if image_count is not None and image_count > 0:
+ self.file_index = (self.file_index + 1) % image_count
+ else:
+ self.file_index += 1
+ return f
+
+ def __updateImage(self):
+ @Slot()
+ def f():
+ img, filename, caption = ConceptModel.instance().get_image(self.idx, self.file_index, show_augmentations=self.ui.showAugmentationsCbx.isChecked())
+ self.ui.previewLbl.setPixmap(QtGui.QPixmap.fromImage(ImageQt(img)))
+ self.ui.filenameLbl.setText(filename)
+ self.ui.promptTed.setPlainText(caption)
+ return f
+
+ def __updateConcept(self):
+ @Slot(int)
+ def f(idx):
+ self.idx = idx
+ self.file_index = 0
+
+            self.ui.nameLed.setText(ConceptModel.instance().get_concept_name(self.idx))  # The name field follows different logic than the other controls, so it cannot use the connection dictionary.
+ return f
+
+ def __saveConcept(self):
+ @Slot()
+ def f():
+ ConceptModel.instance().set_state(f"{self.idx}.name", self.ui.nameLed.text())
+
+ # No need to store statistics, as they are handled directly by the model.
+ QtW.QApplication.instance().conceptsChanged.emit(True)
+ self.ui.hide()
+ return f
+
+ def __enableDownloadBtn(self, enabled):
+ @Slot()
+ def f():
+ self.ui.downloadNowBtn.setEnabled(enabled)
+ return f
+
+ def __enableScanBtn(self, enabled):
+ @Slot()
+ def f():
+ self.ui.refreshBasicBtn.setEnabled(enabled)
+ self.ui.refreshAdvancedBtn.setEnabled(enabled)
+ self.ui.abortScanBtn.setEnabled(not enabled)
+ return f
+
+ ###Utils###
+
+ def __updateHistogram(self, stats_dict):
+ self.bucket_ax.cla()
+ self.canvas.figure.tight_layout()
+ self.canvas.figure.subplots_adjust(bottom=0.15)
+ self.bucket_ax.spines['top'].set_visible(False)
+ self.bucket_ax.tick_params(axis='x', which="both")
+ self.bucket_ax.tick_params(axis='y', which="both")
+        aspects = [str(x) for x in stats_dict["aspect_buckets"]]
+        aspect_ratios = [ConceptModel.instance().decimal_to_aspect_ratio(x) for x in stats_dict["aspect_buckets"]]
+ counts = list(stats_dict["aspect_buckets"].values())
+ b = self.bucket_ax.bar(aspect_ratios, counts)
+ self.bucket_ax.bar_label(b)
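+        # A secondary x-axis below the bars adds a qualitative Wide/Square/Tall guide to the
+        # numeric aspect-ratio tick labels.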
+ sec = self.bucket_ax.secondary_xaxis(location=-0.1)
+ sec.spines["bottom"].set_linewidth(0)
+ sec.set_xticks([0, (len(aspects) - 1) / 2, len(aspects) - 1], labels=["Wide", "Square", "Tall"])
+ sec.tick_params('x', length=0)
+ self.canvas.draw_idle()
+
+ def __scanConcept(self):
+ def f(advanced_scanning):
+ return ConceptModel.instance().get_concept_stats(self.idx, advanced_scanning)
+ return f
diff --git a/modules/ui/controllers/windows/ConvertController.py b/modules/ui/controllers/windows/ConvertController.py
new file mode 100644
index 000000000..4a1b446bc
--- /dev/null
+++ b/modules/ui/controllers/windows/ConvertController.py
@@ -0,0 +1,88 @@
+import os
+
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.ConvertModel import ConvertModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.DataType import DataType
+from modules.util.enum.ModelFormat import ModelFormat
+from modules.util.enum.ModelType import ModelType
+from modules.util.enum.TrainingMethod import TrainingMethod
+
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class ConvertController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/convert.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.inputBtn, self.ui.inputLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open Input model"),
+ filters=QCA.translate("filetype_filters", "Safetensors (*.safetensors);;Diffusers (model_index.json);;Checkpoints (*.ckpt *.pt *.bin);;All Files (*.*)"))
+ self._connectFileDialog(self.ui.outputBtn, self.ui.outputLed, is_dir=False, save=True,
+ title=QCA.translate("dialog_window", "Save Output model"),
+ filters=QCA.translate("filetype_filters", "Safetensors (*.safetensors);;Diffusers (model_index.json)"))
+
+ state_ui_connections = {
+ "model_type": "modelTypeCmb",
+ "training_method": "trainingMethodCmb",
+ "input_name": "inputLed",
+ "output_model_destination": "outputLed",
+ "output_model_format": "outputFormatCmb",
+ "output_dtype": "outputDTypeCmb",
+ }
+
+ self._connectStateUI(state_ui_connections, ConvertModel.instance(), update_after_connect=True)
+
+ self._connect(self.ui.convertBtn.clicked, self.__startConvert())
+
+ def _loadPresets(self):
+ for e in ModelType.enabled_values(context="convert_window"):
+ self.ui.modelTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in TrainingMethod.enabled_values(context="convert_window"):
+ self.ui.trainingMethodCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in DataType.enabled_values(context="convert_window"):
+ self.ui.outputDTypeCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in ModelFormat.enabled_values(context="convert_window"):
+ self.ui.outputFormatCmb.addItem(e.pretty_print(), userData=e)
+
+
+ ###Reactions###
+
+ def __startConvert(self):
+ @Slot()
+ def f():
+ if self.ui.outputLed.text() != "" and self.ui.inputLed.text() != "":
+ if os.path.exists(self.ui.inputLed.text()):
+ worker, name = WorkerPool.instance().createNamed(self.__convert(), "convert_model")
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableButton(False), result_fn=None, finished_fn=self.__enableButton(True),
+ errored_fn=self.__enableButton(True), aborted_fn=self.__enableButton(True))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("convert_window", "Cannot Open Input Model"),
+ QCA.translate("convert_window", "The selected input model does not exist"), type="critical")
+ else:
+ self._openAlert(QCA.translate("convert_window", "No Model Selected"),
+ QCA.translate("convert_window", "Please select input and output model files"))
+ return f
+
+ def __enableButton(self, enabled):
+ @Slot()
+ def f():
+ self.ui.convertBtn.setEnabled(enabled)
+ return f
+
+ ###Utils###
+
+ def __convert(self):
+ def f():
+ return ConvertModel.instance().convert_model()
+
+ return f
diff --git a/modules/ui/controllers/windows/DatasetController.py b/modules/ui/controllers/windows/DatasetController.py
new file mode 100644
index 000000000..9dd937eb5
--- /dev/null
+++ b/modules/ui/controllers/windows/DatasetController.py
@@ -0,0 +1,690 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.DatasetModel import DatasetModel
+from modules.ui.models.MaskHistoryModel import MaskHistoryModel
+from modules.ui.utils.FigureWidget import FigureWidget
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.CaptionFilter import CaptionFilter
+from modules.util.enum.EditMode import EditMode
+from modules.util.enum.FileFilter import FileFilter
+from modules.util.enum.MouseButton import MouseButton
+from modules.util.enum.ToolType import ToolType
+
+import numpy as np
+import PySide6.QtGui as QtG
+import PySide6.QtWidgets as QtW
+from matplotlib.transforms import Bbox
+from PIL import Image
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import QEvent, QObject, Qt, Signal, Slot
+
+
+# Event filter for low-level events that have no native signal equivalents.
+class DatasetEventFilter(QObject):
+ ctrlPressed = Signal()
+ ctrlReleased = Signal()
+ close = Signal()
+
+ def eventFilter(self, obj, event):
+ if event.type() == QEvent.Type.Close:
+ self.close.emit()
+        # Key events arrive as QKeyEvent instances, so query them directly
+        # (Qt 6 event classes are not copy-constructible).
+        elif event.type() == QEvent.Type.KeyPress and event.key() == Qt.Key_Control:
+            self.ctrlPressed.emit()
+        elif event.type() == QEvent.Type.KeyRelease and event.key() == Qt.Key_Control:
+            self.ctrlReleased.emit()
+
+ return QObject.eventFilter(self, obj, event)
+
+class DatasetController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/dataset.ui", name=None, parent=parent)
+ self.swapped = False
+ self.isCtrlPressed = False
+ self.path = None
+
+ ###FSM###
+
+ def _setup(self):
+ self.theme = "dark" if QtG.QGuiApplication.styleHints().colorScheme() == QtG.Qt.ColorScheme.Dark else "light"
+
+        # Each tool is a dict with keys: fn, type, text, name, icon, tooltip, shortcut, spinbox_range, value.
+        # spinbox_range = (min, max, step); the default is given by the separate "value" key.
+ self.tools = [
+ {
+ "fn": self.__prevImg(),
+ "type": ToolType.BUTTON,
+ "icon": f"resources/icons/buttons/{self.theme}/arrow-left.svg",
+ "tooltip": QCA.translate("toolbar_item", "Previous image (Left Arrow)"),
+ "shortcut": "Left"
+ },
+ {
+ "fn": self.__nextImg(),
+ "type": ToolType.BUTTON,
+ "icon": f"resources/icons/buttons/{self.theme}/arrow-right.svg",
+ "tooltip": QCA.translate("toolbar_item", "Next image (Right Arrow)"),
+ "shortcut": "Right"
+ },
+ {"type": ToolType.SEPARATOR},
+ {
+ "type": ToolType.CHECKABLE_BUTTON,
+ "tool": EditMode.DRAW,
+ "icon": f"resources/icons/buttons/{self.theme}/brush.svg",
+ "tooltip": QCA.translate("toolbar_item", "Draw (Left Click) or Erase (Right Click) mask (CTRL+E)"),
+ "shortcut": "Ctrl+E"
+ },
+ {
+ "type": ToolType.CHECKABLE_BUTTON,
+ "tool": EditMode.FILL,
+ "icon": f"resources/icons/buttons/{self.theme}/paint-bucket.svg",
+ "tooltip": QCA.translate("toolbar_item", "Fill (Left Click) or Erase-fill (Right Click) mask (CTRL+F)"),
+ "shortcut": "Ctrl+F"
+ },
+ {
+ "fn": self.__setBrushSize(),
+ "type": ToolType.SPINBOX,
+ "name": "brush_sbx",
+ "text": QCA.translate("toolbar_item", "Brush size"),
+ "tooltip": QCA.translate("toolbar_item", "Brush size (Mouse Wheel Up/Down)"),
+ "spinbox_range": (1, 256, 1),
+ "value": 10
+ },
+ {
+ "fn": self.__setAlpha(),
+ "type": ToolType.DOUBLE_SPINBOX,
+ "name": "alpha_sbx",
+ "text": QCA.translate("toolbar_item", "Mask opacity"),
+ "tooltip": QCA.translate("toolbar_item", "Mask opacity for preview (CTRL+Mouse Wheel Up/Down)"),
+ "spinbox_range": (0.05, 1.0, 0.05),
+ "value": 0.5
+ },
+ {
+ "fn": self.__swapMouse(),
+ "type": ToolType.CHECKABLE_BUTTON,
+ "icon": f"resources/icons/buttons/{self.theme}/mouse.svg",
+ "tooltip": QCA.translate("toolbar_item", "Invert left/right mouse buttons behavior (CTRL+I)"),
+ "shortcut": "Ctrl+I"
+ },
+ {"type": ToolType.SEPARATOR},
+ {
+ "fn": self.__clearMask(),
+ "type": ToolType.BUTTON,
+ "text": QCA.translate("toolbar_item", "Clear Mask"),
+ "tooltip": QCA.translate("toolbar_item", "Clear mask (Del, or Middle Click)"),
+ "shortcut": "Del"
+ },
+ {
+ "fn": self.__clearAll(),
+ "type": ToolType.BUTTON,
+ "text": QCA.translate("toolbar_item", "Clear All"),
+ "tooltip": QCA.translate("toolbar_item", "Clear mask and caption (CTRL+Del)"),
+ "shortcut": "Ctrl+Del"
+ },
+ {
+ "fn": self.__resetMask(),
+ "type": ToolType.BUTTON,
+ "text": QCA.translate("toolbar_item", "Reset Mask"),
+ "tooltip": QCA.translate("toolbar_item", "Reset mask (CTRL+R)"),
+ "shortcut": "Ctrl+R"
+ },
+ {"type": ToolType.SEPARATOR},
+ {
+ "fn": self.__saveMask(),
+ "type": ToolType.BUTTON,
+ "icon": f"resources/icons/buttons/{self.theme}/save.svg",
+ "tooltip": QCA.translate("toolbar_item", "Save mask (CTRL+S)"),
+ "shortcut": "Ctrl+S"
+ },
+ {
+ "fn": self.__undo(),
+ "type": ToolType.BUTTON,
+ "icon": f"resources/icons/buttons/{self.theme}/undo.svg",
+ "tooltip": QCA.translate("toolbar_item", "Undo (CTRL+Z)"),
+ "shortcut": "Ctrl+Z"
+ },
+ {
+ "fn": self.__redo(),
+ "type": ToolType.BUTTON,
+ "icon": f"resources/icons/buttons/{self.theme}/redo.svg",
+ "tooltip": QCA.translate("toolbar_item", "Redo (CTRL+Y)"),
+ "shortcut": "Ctrl+Y"
+ },
+ {"type": ToolType.SEPARATOR},
+ {
+ "fn": self.__deleteSample(),
+ "type": ToolType.BUTTON,
+ "icon": f"resources/icons/buttons/{self.theme}/trash-2.svg",
+ "tooltip": QCA.translate("toolbar_item", "Delete image, mask and caption (CTRL+SHIFT+Del)"),
+ "shortcut": "Ctrl+Shift+Del"
+ },
+ ]
+
+ self.num_files = 0
+ self.current_index = 0
+ self.alpha = 1.0
+ self.brush = 1
+ self.im = None
+ self.image = None
+ self.current_image_path = None
+ self.current_caption = ""
+
+ self.canvas = FigureWidget(parent=self.ui, width=7, height=5, zoom_tools=True, other_tools=self.tools, emit_clicked=True, emit_moved=True, emit_wheel=True, emit_released=True, use_data_coordinates=True)
+ self.ax = self.canvas.figure.subplots()
+ self.ax.set_axis_off()
+
+ self.ui.canvasLay.addWidget(self.canvas.toolbar)
+ self.ui.canvasLay.addWidget(self.canvas)
+
+ self.leafWidgets = {}
+
+
+ self.custom_event_filter = DatasetEventFilter(self.ui)
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.openBtn.clicked, self.__openDataset())
+ self._connect(self.ui.browseBtn.clicked, self.__browse())
+
+ self._connect(self.ui.saveCaptionBtn.clicked, self.__saveCaption())
+ self._connect(self.ui.deleteCaptionBtn.clicked, self.__deleteCaption())
+ self._connect(self.ui.resetCaptionBtn.clicked, self.__resetCaption())
+
+ self._connect(self.ui.fileTreeWdg.itemSelectionChanged, self.__selectFile())
+
+ self._connect([self.ui.fileFilterLed.editingFinished, self.ui.fileFilterCmb.activated,
+ self.ui.captionFilterLed.editingFinished, self.ui.captionFilterCmb.activated,
+ self.ui.maskFilterCbx.toggled, self.ui.captionFilterCbx.toggled],
+ self.__updateDataset())
+
+ self._connect(self.ui.includeSubdirCbx.toggled, self.__startScan())
+
+ self._connect(self.ui.captionTed.textChanged, self.__updateCaption())
+
+
+ self._connect(self.canvas.clicked, self.__onClicked())
+ self._connect(self.canvas.released, self.__onReleased())
+ self._connect(self.canvas.wheelUp, self.__onWheelUp())
+ self._connect(self.canvas.wheelDown, self.__onWheelDown())
+
+ self.canvas.registerTool(EditMode.DRAW, moved_fn=self.__onDrawMoved(), use_mpl_event=False)
+ self.canvas.registerTool(EditMode.FILL, clicked_fn=self.__onMaskClicked(), use_mpl_event=False)
+
+ self.ui.installEventFilter(self.custom_event_filter)
+
+ self._connect(self.custom_event_filter.ctrlPressed, self.__onCtrlPressed())
+ self._connect(self.custom_event_filter.ctrlReleased, self.__onCtrlReleased())
+ self._connect(self.custom_event_filter.close, self.__onClose())
+
+
+ def _loadPresets(self):
+ for e in FileFilter.enabled_values():
+ self.ui.fileFilterCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in CaptionFilter.enabled_values():
+ self.ui.captionFilterCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __updateCaption(self):
+ @Slot()
+ def f():
+ self.current_caption = self.ui.captionTed.toPlainText()
+ return f
+
+ def __openDataset(self):
+ @Slot()
+ def f():
+ if self.__saveChanged():
+ diag = QtW.QFileDialog()
+ self.path = diag.getExistingDirectory(parent=None, caption=QCA.translate("dialog_window", "Open Dataset directory"), dir=DatasetModel.instance().get_state("path"))
+ self.__startScan()()
+
+ return f
+
+ def __startScan(self):
+ @Slot()
+ def f():
+ if self.path is not None:
+ DatasetModel.instance().set_state("include_subdirectories", self.ui.includeSubdirCbx.isChecked())
+
+ worker, name = WorkerPool.instance().createNamed(self.__scan(), name="open_dataset", dir=self.path)
+ if worker is not None:
+ worker.connectCallbacks(finished_fn=self.__updateDataset())
+ WorkerPool.instance().start(name)
+ return f
+
+ def __updateDataset(self):
+ @Slot()
+ def f():
+ data = {
+ "file_filter": self.ui.fileFilterLed.text(),
+ "file_filter_mode": self.ui.fileFilterCmb.currentData(),
+ "caption_filter": self.ui.captionFilterLed.text(),
+ "caption_filter_mode": self.ui.captionFilterCmb.currentData(),
+ "filter_mask_exists": self.ui.maskFilterCbx.isChecked(),
+ "filter_caption_exists": self.ui.captionFilterCbx.isChecked(),
+ }
+ DatasetModel.instance().bulk_write(data)
+
+ files = DatasetModel.instance().get_filtered_files()
+
+
+ self.current_index = 0
+ self.num_files = len(files)
+ if self.num_files == 0:
+ self.ui.numFilesLbl.setText(QCA.translate("dataset_window", "No image found"))
+ else:
+ self.ui.numFilesLbl.setText(QCA.translate("dataset_window", "Dataset loaded"))
+
+ file_tree = {}
+ for i, file in enumerate(files):
+ self.__buildTree(file, file_tree, i)
+
+ self.ui.fileTreeWdg.clear()
+ self.leafWidgets = {}
+ self.__drawTree(self.ui.fileTreeWdg, file_tree)
+ return f
+
+
+ def __prevImg(self):
+ @Slot()
+ def f():
+ if self.num_files > 0 and self.__saveChanged():
+ self.current_index = (self.current_index + self.num_files - 1) % self.num_files
+ self.ui.fileTreeWdg.setCurrentItem(self.leafWidgets[self.current_index])
+ return f
+
+ def __nextImg(self):
+ @Slot()
+ def f():
+ if self.num_files > 0 and self.__saveChanged():
+ self.current_index = (self.current_index + 1) % self.num_files
+ self.ui.fileTreeWdg.setCurrentItem(self.leafWidgets[self.current_index])
+ return f
+
+ def __setBrushSize(self):
+ @Slot(int)
+ def f(val):
+ self.brush = val
+ return f
+
+ def __setAlpha(self):
+ @Slot(float)
+ def f(val):
+ self.alpha = val
+
+ self.__updateCanvas()
+ return f
+
+ def __swapMouse(self):
+ @Slot(bool)
+ def f(checked):
+ self.swapped = checked
+ return f
+
+ def __onCtrlPressed(self):
+ @Slot()
+ def f():
+ self.isCtrlPressed = True
+ return f
+
+ def __onCtrlReleased(self):
+ @Slot()
+ def f():
+ self.isCtrlPressed = False
+ return f
+
+ def __onClose(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ choice, new_caption = self.__checkCaptionChanged(cancel=False)
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ DatasetModel.instance().save_caption(self.current_image_path, new_caption)
+ else:
+ self.__resetCaption()()
+
+ choice, new_mask, mask_path = self.__checkMaskChanged(cancel=False)
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ Image.fromarray(new_mask, "L").convert("RGB").save(mask_path)
+ MaskHistoryModel.instance().load_mask(new_mask)
+ else:
+ MaskHistoryModel.instance().clear_history()
+ self.__updateCanvas()
+ return f
+
+ def __clearAll(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ choice = self._openAlert(QCA.translate("dataset_window", "Clear Mask and Caption"),
+ QCA.translate("dataset_window",
+ "Do you want to clear mask and caption? This operation will not change files on disk."),
+ type="question",
+ buttons=QtW.QMessageBox.StandardButton.Yes | QtW.QMessageBox.StandardButton.No)
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ MaskHistoryModel.instance().clear_history()
+ MaskHistoryModel.instance().delete_mask()
+ MaskHistoryModel.instance().commit()
+ self.ui.captionTed.setPlainText("")
+
+ self.__updateCanvas()
+ return f
+
+ def __resetMask(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ MaskHistoryModel.instance().clear_history()
+
+ self.__updateCanvas()
+ return f
+
+ def __clearMask(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ MaskHistoryModel.instance().delete_mask()
+ MaskHistoryModel.instance().commit()
+
+ self.__updateCanvas()
+ return f
+
+ def __undo(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ MaskHistoryModel.instance().undo()
+
+ self.__updateCanvas()
+ return f
+
+ def __redo(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ MaskHistoryModel.instance().redo()
+
+ self.__updateCanvas()
+ return f
+
+ def __saveMask(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ choice, new_mask, mask_path = self.__checkMaskChanged()
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ new_mask_img = Image.fromarray(new_mask, "L")
+ new_mask_img.convert("RGB").save(mask_path)
+ MaskHistoryModel.instance().load_mask(new_mask)
+
+ return f
+
+ def __saveCaption(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ choice, new_caption = self.__checkCaptionChanged()
+
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ DatasetModel.instance().save_caption(self.current_image_path, new_caption)
+ return f
+
+ def __deleteCaption(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ if self.ui.captionTed.toPlainText().strip() != "":
+ choice = self._openAlert(QCA.translate("dataset_window", "Delete Caption"),
+ QCA.translate("dataset_window", "Do you want to delete caption?"),
+ type="question",
+ buttons=QtW.QMessageBox.StandardButton.Yes | QtW.QMessageBox.StandardButton.No)
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ DatasetModel.instance().delete_caption(self.current_image_path)
+ self.ui.captionTed.setPlainText("")
+ return f
+
+ def __resetCaption(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ _, _, caption = DatasetModel.instance().get_sample(self.current_image_path)
+ if caption is not None:
+ self.ui.captionTed.setPlainText(caption.strip())
+ return f
+
+ def __deleteSample(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ choice = self._openAlert(QCA.translate("dataset_window", "Delete Sample"),
+ QCA.translate("dataset_window", "Do you really want to delete the sample (image, mask and caption)? This is not reversible."),
+ type="warning",
+ buttons=QtW.QMessageBox.StandardButton.Yes | QtW.QMessageBox.StandardButton.No)
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ DatasetModel.instance().delete_sample(self.current_image_path)
+ self.__updateDataset()()
+ self.__selectFile()()
+ self.ui.fileTreeWdg.setCurrentItem(self.leafWidgets[self.current_index])
+
+ return f
+
+ def __selectFile(self):
+ @Slot()
+ def f():
+ selected_wdg = self.ui.fileTreeWdg.selectedItems()
+ if len(selected_wdg) > 0:
+ if self.__saveChanged():
+ self.current_image_path = selected_wdg[0].fullpath
+
+ idx = selected_wdg[0].idx
+ if self.current_image_path is not None:
+ if self.num_files > 0:
+ if idx is not None:
+ self.current_index = idx
+ self.ui.numFilesLbl.setText(f"{self.current_index + 1}/{self.num_files}")
+
+ self.image, mask, caption = DatasetModel.instance().get_sample(self.current_image_path)
+
+ if caption is not None:
+ self.ui.captionTed.setPlainText(caption)
+
+ if mask is None:
+ mask = Image.new("L", self.image.size, 255)
+
+ MaskHistoryModel.instance().load_mask(np.asarray(mask))
+ self.im = self.ax.imshow(self.image)
+
+ self.__updateCanvas()
+ else:
+ self.ui.fileTreeWdg.setCurrentItem(self.leafWidgets[self.current_index])
+
+
+ return f
+
+ def __browse(self):
+ @Slot()
+ def f():
+ path = DatasetModel.instance().get_state("path")
+ if path is not None:
+ self._browse(path)
+ else:
+ self._openAlert(QCA.translate("dataset_window", "No Dataset Loaded"),
+ QCA.translate("dataset_window", "Please open a dataset first"))
+ return f
+
+ def __onClicked(self):
+ @Slot(MouseButton, int, int)
+ def f(btn, x, y):
+ if self.current_image_path is not None and btn == MouseButton.MIDDLE:
+ MaskHistoryModel.instance().delete_mask() # This click is also associated with a release, which will commit the change and update the canvas.
+ return f
+
+ def __onReleased(self):
+ @Slot(MouseButton, int, int)
+ def f(btn, x, y):
+ if self.current_image_path is not None:
+ MaskHistoryModel.instance().commit()
+
+ self.__updateCanvas()
+ return f
+
+ def __onWheelUp(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ if self.isCtrlPressed:
+ wdg = self.canvas.toolbar.findChild(QtW.QDoubleSpinBox, "alpha_sbx")
+ else:
+ wdg = self.canvas.toolbar.findChild(QtW.QSpinBox, "brush_sbx")
+                new_val = wdg.value() + wdg.singleStep()
+                wdg.setValue(new_val)  # Emits valueChanged, which is connected to __setBrushSize() or __setAlpha().
+ return f
+
+ def __onWheelDown(self):
+ @Slot()
+ def f():
+ if self.current_image_path is not None:
+ if self.isCtrlPressed:
+ wdg = self.canvas.toolbar.findChild(QtW.QDoubleSpinBox, "alpha_sbx")
+ else:
+ wdg = self.canvas.toolbar.findChild(QtW.QSpinBox, "brush_sbx")
+ new_val = wdg.value() - wdg.singleStep()
+ wdg.setValue(new_val)
+ return f
+
+ def __onMaskClicked(self):
+ @Slot(MouseButton, int, int)
+ def f(btn, x, y):
+ if self.current_image_path is not None:
+ if btn == MouseButton.LEFT:
+ MaskHistoryModel.instance().fill(x, y, 0 if not self.swapped else 255)
+ elif btn == MouseButton.RIGHT:
+ MaskHistoryModel.instance().fill(x, y, 255 if not self.swapped else 0)
+ self.__updateCanvas()
+ return f
+
+ def __onDrawMoved(self):
+ @Slot(MouseButton, int, int, int, int)
+ def f(btn, x0, y0, x1, y1):
+ if self.current_image_path is not None and x0 >= 0 and y0 >= 0 and x1 >= 0 and y1 >= 0:
+ if btn == MouseButton.LEFT:
+                    MaskHistoryModel.instance().paint_stroke(x0, y0, x1, y1, int(self.brush), 0 if not self.swapped else 255, commit=False)  # Paint a stroke of value 0 (255 when swapped) from (x0, y0) to (x1, y1).
+ self.__updateCanvas(blitbb=(x0 - self.brush, x1 + self.brush, y0 - self.brush, y1 + self.brush))
+ elif btn == MouseButton.RIGHT:
+ MaskHistoryModel.instance().paint_stroke(x0, y0, x1, y1, int(self.brush), 255 if not self.swapped else 0, commit=False)
+ self.__updateCanvas(blitbb=(x0 - self.brush, x1 + self.brush, y0 - self.brush, y1 + self.brush))
+
+ return f
+
+ ###Utils###
+
+ def __scan(self):
+ def f(dir):
+ DatasetModel.instance().set_state("path", dir)
+ DatasetModel.instance().scan()
+ return f
+
+ def __saveChanged(self):
+ if self.current_image_path is not None:
+ choice, new_caption = self.__checkCaptionChanged(cancel=True)
+ if choice != QtW.QMessageBox.StandardButton.Cancel:
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ DatasetModel.instance().save_caption(self.current_image_path, new_caption)
+
+ choice, new_mask, mask_path = self.__checkMaskChanged(cancel=True)
+ if choice != QtW.QMessageBox.StandardButton.Cancel:
+ if choice == QtW.QMessageBox.StandardButton.Yes:
+ Image.fromarray(new_mask, "L").convert("RGB").save(mask_path)
+
+ return choice != QtW.QMessageBox.StandardButton.Cancel
+ else:
+ return True
+
+ def __updateCanvas(self, blitbb=None):
+ if self.im is not None:
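+            # Mask pixels are 0 (masked out) or 255 (kept); the clip maps them to 1 - alpha
+            # and 1 respectively, so the multiplication dims masked-out regions by the mask opacity.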
+ mask = np.clip(MaskHistoryModel.instance().get_state("current_mask")[..., np.newaxis].astype(float), 1 - self.alpha, 1)
+ self.im.set_data((np.asarray(self.image) * mask).astype(np.uint8))
+
+ if blitbb is not None:
+ self.canvas.blit(Bbox.from_extents(*blitbb))
+
+ self.canvas.draw_idle()
+
+
+ def __buildTree(self, file, tree, idx, name=None):
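+        # Recursively turns a "/"-separated path into a nested dict, e.g. "a/b/img.png"
+        # becomes {"a": {"b": {"img.png": (idx, file)}}}; leaves hold (idx, file) tuples.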
+ if name is None:
+ name = file[0]
+ path = name.split("/")
+ if len(path) == 1:
+ tree[path[0]] = (idx, file)
+ elif len(path) > 1:
+ if path[0] not in tree:
+ tree[path[0]] = {}
+ self.__buildTree(file, tree[path[0]], idx, "/".join(path[1:]))
+
+ def __drawTree(self, parent, tree):
+ for k in sorted(tree.keys(), key=lambda x: DatasetModel.natural_sort_key(x)):
+ v = tree[k]
+ wdg = QtW.QTreeWidgetItem(parent, [k])
+ if isinstance(v, dict):
+ wdg.setIcon(0, QtG.QIcon(f"resources/icons/buttons/{self.theme}/folder-open.svg"))
+ wdg.fullpath = None
+ wdg.idx = None
+ self.__drawTree(wdg, v)
+ else:
+ has_caption = v[1][1]
+ has_mask = v[1][2]
+
+ if has_caption and has_mask:
+ wdg.setIcon(0, QtG.QIcon(f"resources/icons/buttons/{self.theme}/file-check-corner.svg"))
+ wdg.setToolTip(0, QCA.translate("dataset_window", "Caption and mask available"))
+ elif has_caption:
+ wdg.setIcon(0, QtG.QIcon(f"resources/icons/buttons/{self.theme}/file-minus-corner.svg"))
+ wdg.setToolTip(0, QCA.translate("dataset_window", "Missing mask"))
+ elif has_mask:
+ wdg.setIcon(0, QtG.QIcon(f"resources/icons/buttons/{self.theme}/file-scan.svg"))
+ wdg.setToolTip(0, QCA.translate("dataset_window", "Missing caption"))
+ else:
+ wdg.setIcon(0, QtG.QIcon(f"resources/icons/buttons/{self.theme}/file-x-corner.svg"))
+ wdg.setToolTip(0, QCA.translate("dataset_window", "No caption or mask found"))
+
+ wdg.fullpath = v[1][0]
+ wdg.idx = v[0]
+ self.leafWidgets[v[0]] = wdg
+
+ def __checkMaskChanged(self, cancel=False):
+ buttons = QtW.QMessageBox.StandardButton.Yes | QtW.QMessageBox.StandardButton.No
+ if cancel:
+ buttons |= QtW.QMessageBox.StandardButton.Cancel
+
+ mask = MaskHistoryModel.instance().get_state("original_mask")
+ new_mask = MaskHistoryModel.instance().get_state("current_mask")
+ mask_path, mask_exists = DatasetModel.instance().get_mask_path(self.current_image_path)
+
+ choice = QtW.QMessageBox.StandardButton.No
+ if not mask_exists:
+ choice = QtW.QMessageBox.StandardButton.Yes
+ elif np.not_equal(mask, new_mask).any():
+ choice = self._openAlert(QCA.translate("dataset_window", "Save Mask"),
+ QCA.translate("dataset_window", "Mask has changed. Do you want to save it?"),
+ type="question",
+ buttons=buttons)
+ return choice, new_mask, mask_path
+
+ def __checkCaptionChanged(self, cancel=False):
+ buttons = QtW.QMessageBox.StandardButton.Yes | QtW.QMessageBox.StandardButton.No
+ if cancel:
+ buttons |= QtW.QMessageBox.StandardButton.Cancel
+
+ _, _, caption = DatasetModel.instance().get_sample(self.current_image_path)
+
+ choice = QtW.QMessageBox.StandardButton.No
+
+ if caption is None:
+ choice = QtW.QMessageBox.StandardButton.Yes
+ elif caption.strip() != self.current_caption.strip():
+ choice = self._openAlert(QCA.translate("dataset_window", "Save Caption"),
+ QCA.translate("dataset_window", "Caption has changed. Do you want to save it?"),
+ type="question",
+ buttons=buttons)
+
+ return choice, self.current_caption
diff --git a/modules/ui/controllers/windows/MaskController.py b/modules/ui/controllers/windows/MaskController.py
new file mode 100644
index 000000000..602a3f651
--- /dev/null
+++ b/modules/ui/controllers/windows/MaskController.py
@@ -0,0 +1,86 @@
+import os
+
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.MaskModel import MaskModel
+from modules.ui.utils.WorkerPool import WorkerPool
+from modules.util.enum.GenerateMasksModel import GenerateMasksAction, GenerateMasksModel
+
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class MaskController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/generate_mask.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+        self._connectFileDialog(self.ui.folderBtn, self.ui.folderLed, is_dir=True, save=False,
+                                title=QCA.translate("dialog_window", "Open Dataset directory"))
+
+ state_ui_connections = {
+ "model": "modelCmb",
+ "path": "folderLed",
+ "prompt": "promptLed",
+ "mode": "modeCmb",
+ "alpha": "alphaSbx",
+ "threshold": "thresholdSbx",
+ "smooth": "smoothSbx",
+ "expand": "expandSbx",
+ "include_subdirectories": "includeSubfolderCbx"
+ }
+
+ self._connectStateUI(state_ui_connections, MaskModel.instance(), update_after_connect=True)
+
+ self.__enableControls(True)()
+
+ self._connect(self.ui.createMaskBtn.clicked, self.__startMask())
+
+ def _loadPresets(self):
+ for e in GenerateMasksModel.enabled_values():
+ self.ui.modelCmb.addItem(e.pretty_print(), userData=e)
+
+ for e in GenerateMasksAction.enabled_values():
+ self.ui.modeCmb.addItem(e.pretty_print(), userData=e)
+
+ ###Reactions###
+
+ def __startMask(self):
+ @Slot()
+ def f():
+ if self.ui.folderLed.text() != "":
+ if os.path.isdir(self.ui.folderLed.text()):
+ worker, name = WorkerPool.instance().createNamed(self.__createMask(), "create_mask", inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True),
+ errored_fn=self.__enableControls(True), aborted_fn=self.__enableControls(True),
+ progress_fn=self._updateProgress(self.ui.progressBar))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("mask_window", "Invalid Folder"),
+ QCA.translate("mask_window", "The selected input folder does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("mask_window", "No Folder Selected"),
+ QCA.translate("mask_window", "Please select an input folder"))
+
+ return f
+
+ def __enableControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.createMaskBtn.setEnabled(enabled)
+ if enabled:
+ self.ui.progressBar.setValue(0)
+ return f
+
+
+ ###Utils###
+
+ def __createMask(self):
+ def f(progress_fn=None):
+ return MaskModel.instance().create_masks(progress_fn=progress_fn)
+
+ return f
diff --git a/modules/ui/controllers/windows/NewSampleController.py b/modules/ui/controllers/windows/NewSampleController.py
new file mode 100644
index 000000000..3b49e6934
--- /dev/null
+++ b/modules/ui/controllers/windows/NewSampleController.py
@@ -0,0 +1,30 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.widgets.SampleParamsController import SampleParamsController
+from modules.ui.models.SampleModel import SampleModel
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import Slot
+
+
+class NewSampleController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/new_sample.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ self.samplingParams = SampleParamsController(self.loader, model_instance=SampleModel.instance(), write_signal=self.ui.okBtn.clicked, read_signal=QtW.QApplication.instance().openSample, parent=self.parent)
+ self.ui.paramsLay.addWidget(self.samplingParams.ui)
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.okBtn.clicked, self.__saveSample())
+
+ ###Reactions###
+
+ def __saveSample(self):
+ @Slot()
+ def f():
+ QtW.QApplication.instance().samplesChanged.emit()
+ self.ui.hide()
+
+ return f
diff --git a/modules/ui/controllers/windows/OptimizerController.py b/modules/ui/controllers/windows/OptimizerController.py
new file mode 100644
index 000000000..f4e6905d6
--- /dev/null
+++ b/modules/ui/controllers/windows/OptimizerController.py
@@ -0,0 +1,211 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.StateModel import StateModel
+
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import QCoreApplication as QCA
+
+
+class OptimizerController(BaseController):
+ # @formatter:off
+ optimizer_params = {
+ "adam_w_mode": {"title": QCA.translate("optimizer_parameter", "Adam W Mode"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use weight decay correction for Adam optimizer."), "type": "bool"},
+ "alpha": {"title": QCA.translate("optimizer_parameter", "Alpha"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Smoothing parameter for RMSprop and others."), "type": "float"},
+ "amsgrad": {"title": QCA.translate("optimizer_parameter", "AMSGrad"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use the AMSGrad variant for Adam."), "type": "bool"},
+ "beta1": {"title": QCA.translate("optimizer_parameter", "Beta1"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "optimizer_momentum term."), "type": "float"},
+ "beta2": {"title": QCA.translate("optimizer_parameter", "Beta2"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Coefficients for computing running averages of gradient."), "type": "float"},
+ "beta3": {"title": QCA.translate("optimizer_parameter", "Beta3"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Coefficient for computing the Prodigy stepsize."), "type": "float"},
+ "bias_correction": {"title": QCA.translate("optimizer_parameter", "Bias Correction"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use bias correction in optimization algorithms like Adam."), "type": "bool"},
+ "block_wise": {"title": QCA.translate("optimizer_parameter", "Block Wise"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to perform block-wise model update."), "type": "bool"},
+ "capturable": {"title": QCA.translate("optimizer_parameter", "Capturable"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether some property of the optimizer can be captured."), "type": "bool"},
+ "centered": {"title": QCA.translate("optimizer_parameter", "Centered"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to center the gradient before scaling. Great for stabilizing the training process."), "type": "bool"},
+ "clip_threshold": {"title": QCA.translate("optimizer_parameter", "Clip Threshold"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Clipping value for gradients."), "type": "float"},
+ "d0": {"title": QCA.translate("optimizer_parameter", "Initial D"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Initial D estimate for D-adaptation."), "type": "float"},
+ "d_coef": {"title": QCA.translate("optimizer_parameter", "D Coefficient"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Coefficient in the expression for the estimate of d."), "type": "float"},
+ "dampening": {"title": QCA.translate("optimizer_parameter", "Dampening"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Dampening for optimizer_momentum."), "type": "float"},
+ "decay_rate": {"title": QCA.translate("optimizer_parameter", "Decay Rate"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Rate of decay for moment estimation."), "type": "float"},
+ "decouple": {"title": QCA.translate("optimizer_parameter", "Decouple"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use AdamW style optimizer_decoupled weight decay."), "type": "bool"},
+ "differentiable": {"title": QCA.translate("optimizer_parameter", "Differentiable"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether the optimization function is optimizer_differentiable."), "type": "bool"},
+ "eps": {"title": QCA.translate("optimizer_parameter", "EPS"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "A small value to prevent division by zero."), "type": "float"},
+ "eps2": {"title": QCA.translate("optimizer_parameter", "EPS 2"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "A small value to prevent division by zero."), "type": "float"},
+ "foreach": {"title": QCA.translate("optimizer_parameter", "ForEach"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use a foreach implementation if available. This implementation is usually faster."), "type": "bool"},
+ "fsdp_in_use": {"title": QCA.translate("optimizer_parameter", "FSDP in Use"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Flag for using sharded parameters."), "type": "bool"},
+ "fused": {"title": QCA.translate("optimizer_parameter", "Fused"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use a fused implementation if available. This implementation is usually faster and requires less memory."), "type": "bool"},
+ "fused_back_pass": {"title": QCA.translate("optimizer_parameter", "Fused Back Pass"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to fuse the back propagation pass with the optimizer step. This reduces VRAM usage, but is not compatible with gradient accumulation."), "type": "bool"},
+ "growth_rate": {"title": QCA.translate("optimizer_parameter", "Growth Rate"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Limit for D estimate growth rate."), "type": "float"},
+ "initial_accumulator_value": {"title": QCA.translate("optimizer_parameter", "Initial Accumulator Value"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Initial value for Adagrad optimizer."), "type": "float"},
+ "initial_accumulator": {"title": QCA.translate("optimizer_parameter", "Initial Accumulator"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Sets the starting value for both moment estimates to ensure numerical stability and balanced adaptive updates early in training."), "type": "float"},
+ "is_paged": {"title": QCA.translate("optimizer_parameter", "Is Paged"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether the optimizer's internal state should be paged to CPU."), "type": "bool"},
+ "log_every": {"title": QCA.translate("optimizer_parameter", "Log Every"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Intervals at which logging should occur."), "type": "int"},
+ "lr_decay": {"title": QCA.translate("optimizer_parameter", "LR Decay"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Rate at which learning rate decreases."), "type": "float"},
+ "max_unorm": {"title": QCA.translate("optimizer_parameter", "Max Unorm"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Maximum value for gradient clipping by norms."), "type": "float"},
+ "maximize": {"title": QCA.translate("optimizer_parameter", "Maximize"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to optimizer_maximize the optimization function."), "type": "bool"},
+ "min_8bit_size": {"title": QCA.translate("optimizer_parameter", "Min 8bit Size"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Minimum tensor size for 8-bit quantization."), "type": "int"},
+ "quant_block_size": {"title": QCA.translate("optimizer_parameter", "Quant Block Size"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Size of a block of normalized 8-bit quantization data. Larger values increase memory efficiency at the cost of data precision."), "type": "int"},
+ "momentum": {"title": QCA.translate("optimizer_parameter", "optimizer_momentum"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Factor to accelerate SGD in relevant direction."), "type": "float"},
+ "nesterov": {"title": QCA.translate("optimizer_parameter", "Nesterov"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to enable Nesterov optimizer_momentum."), "type": "bool"},
+ "no_prox": {"title": QCA.translate("optimizer_parameter", "No Prox"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use proximity updates or not."), "type": "bool"},
+ "optim_bits": {"title": QCA.translate("optimizer_parameter", "Optim Bits"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Number of bits used for optimization."), "type": "int"},
+ "percentile_clipping": {"title": QCA.translate("optimizer_parameter", "Percentile Clipping"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Gradient clipping based on percentile values."), "type": "int"},
+ "relative_step": {"title": QCA.translate("optimizer_parameter", "Relative Step"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use a relative step size."), "type": "bool"},
+ "safeguard_warmup": {"title": QCA.translate("optimizer_parameter", "Safeguard Warmup"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Avoid issues during warm-up stage."), "type": "bool"},
+ "scale_parameter": {"title": QCA.translate("optimizer_parameter", "Scale Parameter"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to scale the parameter or not."), "type": "bool"},
+ "stochastic_rounding": {"title": QCA.translate("optimizer_parameter", "Stochastic Rounding"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Stochastic rounding for weight updates. Improves quality when using bfloat16 weights."), "type": "bool"},
+ "use_bias_correction": {"title": QCA.translate("optimizer_parameter", "Bias Correction"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Turn on Adam's bias correction."), "type": "bool"},
+ "use_triton": {"title": QCA.translate("optimizer_parameter", "Use Triton"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether Triton optimization should be used."), "type": "bool"},
+ "warmup_init": {"title": QCA.translate("optimizer_parameter", "Warmup Initialization"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to warm-up the optimizer initialization."), "type": "bool"},
+ "weight_decay": {"title": QCA.translate("optimizer_parameter", "Weight Decay"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Regularization to prevent overfitting."), "type": "float"},
+ "weight_lr_power": {"title": QCA.translate("optimizer_parameter", "Weight LR Power"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "During warmup, the weights in the average will be equal to lr raised to this power. Set to 0 for no weighting."), "type": "float"},
+ "decoupled_decay": {"title": QCA.translate("optimizer_parameter", "Decoupled Decay"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "If set as True, then the optimizer uses decoupled weight decay as in AdamW."), "type": "bool"},
+ "fixed_decay": {"title": QCA.translate("optimizer_parameter", "Fixed Decay"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "(When Decoupled Decay is True:) Applies fixed weight decay when True; scales decay with learning rate when False."), "type": "bool"},
+ "rectify": {"title": QCA.translate("optimizer_parameter", "Rectify"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Perform the rectified update similar to RAdam."), "type": "bool"},
+ "degenerated_to_sgd": {"title": QCA.translate("optimizer_parameter", "Degenerated to SGD"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Performs SGD update when gradient variance is high."), "type": "bool"},
+ "k": {"title": QCA.translate("optimizer_parameter", "K"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Number of vector projected per iteration."), "type": "int"},
+ "xi": {"title": QCA.translate("optimizer_parameter", "Xi"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Term used in vector projections to avoid division by zero."), "type": "float"},
+ "n_sma_threshold": {"title": QCA.translate("optimizer_parameter", "N SMA Threshold"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Number of SMA threshold."), "type": "int"},
+ "ams_bound": {"title": QCA.translate("optimizer_parameter", "AMS Bound"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use the AMSBound variant."), "type": "bool"},
+ "r": {"title": QCA.translate("optimizer_parameter", "R"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "EMA factor."), "type": "float"},
+ "adanorm": {"title": QCA.translate("optimizer_parameter", "AdaNorm"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use the AdaNorm variant"), "type": "bool"},
+ "adam_debias": {"title": QCA.translate("optimizer_parameter", "Adam Debias"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Only correct the denominator to avoid inflating step sizes early in training."), "type": "bool"},
+ "slice_p": {"title": QCA.translate("optimizer_parameter", "Slice parameters"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Reduce memory usage by calculating LR adaptation statistics on only every pth entry of each tensor. For values greater than 1 this is an approximation to standard Prodigy. Values ~11 are reasonable."), "type": "int"},
+ "cautious": {"title": QCA.translate("optimizer_parameter", "Cautious"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Whether to use the Cautious variant"), "type": "bool"},
+ "weight_decay_by_lr": {"title": QCA.translate("optimizer_parameter", "weight_decay_by_lr"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Automatically adjust weight decay based on lr"), "type": "bool"},
+ "prodigy_steps": {"title": QCA.translate("optimizer_parameter", "prodigy_steps"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Turn off Prodigy after N steps"), "type": "int"},
+ "use_speed": {"title": QCA.translate("optimizer_parameter", "use_speed"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "use_speed method"), "type": "bool"},
+ "split_groups": {"title": QCA.translate("optimizer_parameter", "split_groups"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use split groups when training multiple params(uNet,TE..)"), "type": "bool"},
+ "split_groups_mean": {"title": QCA.translate("optimizer_parameter", "split_groups_mean"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use mean for split groups"), "type": "bool"},
+ "factored": {"title": QCA.translate("optimizer_parameter", "factored"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use factored"), "type": "bool"},
+ "factored_fp32": {"title": QCA.translate("optimizer_parameter", "factored_fp32"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use factored_fp32"), "type": "bool"},
+ "use_stableadamw": {"title": QCA.translate("optimizer_parameter", "use_stableadamw"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use use_stableadamw for gradient scaling"), "type": "bool"},
+ "use_cautious": {"title": QCA.translate("optimizer_parameter", "use_cautious"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use cautious method"), "type": "bool"},
+ "use_grams": {"title": QCA.translate("optimizer_parameter", "use_grams"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use grams method"), "type": "bool"},
+ "use_adopt": {"title": QCA.translate("optimizer_parameter", "use_adopt"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use adopt method"), "type": "bool"},
+ "d_limiter": {"title": QCA.translate("optimizer_parameter", "d_limiter"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Prevent over-estimated LRs when gradients and EMA are still stabilizing"), "type": "bool"},
+ "use_schedulefree": {"title": QCA.translate("optimizer_parameter", "use_schedulefree"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use Schedulefree method"), "type": "bool"},
+ "use_orthograd": {"title": QCA.translate("optimizer_parameter", "use_orthograd"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Use orthograd method"), "type": "bool"},
+ "nnmf_factor": {"title": QCA.translate("optimizer_parameter", "Factored Optimizer"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Enables a memory-efficient mode by applying fast low-rank factorization to the optimizers states. It combines factorization for magnitudes with 1-bit compression for signs, drastically reducing VRAM usage and allowing for larger models or batch sizes. This is an approximation which may slightly alter training dynamics."), "type": "bool"},
+ "orthogonal_gradient": {"title": QCA.translate("optimizer_parameter", "OrthoGrad"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Reduces overfitting by removing the gradient component parallel to the weight, thus improving generalization."), "type": "bool"},
+ "use_atan2": {"title": QCA.translate("optimizer_parameter", "Atan2 Scaling"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "A robust replacement for eps, which also incorporates gradient clipping, bounding and stabilizing the optimizer updates."), "type": "bool"},
+ "cautious_mask": {"title": QCA.translate("optimizer_parameter", "Cautious Variant"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Applies a mask to dampen or zero-out momentum components that disagree with the current gradients direction."), "type": "bool"},
+ "grams_moment": {"title": QCA.translate("optimizer_parameter", "GRAMS Variant"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Aligns the momentum direction with the current gradient direction while preserving its accumulated magnitude."), "type": "bool"},
+ "use_AdEMAMix": {"title": QCA.translate("optimizer_parameter", "AdEMAMix EMA"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Adds a second, slow-moving EMA, which is combined with the primary momentum to stabilize updates, and accelerate the training."), "type": "bool"},
+ "beta3_ema": {"title": QCA.translate("optimizer_parameter", "Beta3 EMA"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Coefficient for slow-moving EMA of AdEMAMix."), "type": "float"},
+ "beta1_warmup": {"title": QCA.translate("optimizer_parameter", "Beta1 Warmup Steps"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Number of warmup steps to gradually increase beta1 from Minimum Beta1 Value to its final value. During warmup, beta1 increases linearly. leave it empty to disable warmup and use constant beta1."), "type": "int"},
+ "min_beta1": {"title": QCA.translate("optimizer_parameter", "Minimum Beta1"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Starting beta1 value for warmup scheduling. Used only when beta1 warmup is enabled. Lower values allow faster initial adaptation, while higher values provide more smoothing. The final beta1 value is specified in the beta1 parameter."), "type": "float"},
+ "Simplified_AdEMAMix": {"title": QCA.translate("optimizer_parameter", "Simplified AdEMAMix"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Enables a simplified, single-EMA variant of AdEMAMix. Instead of blending two moving averages (fast and slow momentum), this version combines the raw current gradient (controlled by 'Grad α') directly with a single theory-based momentum. This makes the optimizer highly responsive to recent gradient information, which can accelerate training in all batch size scenarios when tuned correctly."), "type": "bool"},
+ "alpha_grad": {"title": QCA.translate("optimizer_parameter", "Grad α"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Controls the mixing coefficient between raw gradients and momentum gradients in Simplified AdEMAMix. Higher values (e.g., 10-100) emphasize recent gradients, suitable for small batch sizes to reduce noise. Lower values (e.g., 0-1) emphasize historical gradients, suitable for large batch sizes for stability. Setting to 0 uses only momentum gradients without raw gradient contribution."), "type": "float"},
+        "kourkoutas_beta": {"title": QCA.translate("optimizer_parameter", "Kourkoutas Beta"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Enables a layer-wise dynamic β₂ adaptation. This feature makes the optimizer more responsive to 'spiky' gradients by lowering β₂ during periods of high variance, and more stable during calm periods by raising β₂ towards its maximum. It can significantly improve training stability and final loss."), "type": "bool"},
+        "k_warmup_steps": {"title": QCA.translate("optimizer_parameter", "K-β Warmup Steps"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "When using Kourkoutas Beta, the number of initial training steps during which the dynamic β₂ logic is held off. In this period, β₂ is set to its fixed value to allow for initial training stability before the adaptive mechanism activates."), "type": "int"},
+        "schedulefree_c": {"title": QCA.translate("optimizer_parameter", "Schedule-Free Averaging Strength"), "tooltip": QCA.translate("optimizer_parameter_tooltip", "Larger values = more responsive (shorter averaging window); smaller values = smoother (longer window). Set to 0 to disable and use the original Schedule-Free rule. Short/small batches: ≈6-12; long/large batches: ≈50-200."), "type": "float"},
+ }
+
+ # Quick way of connecting controls without redefining every single one of them.
+ state_ui_connections = {
+ f"optimizer.{k}": "{}{}".format(k, "Cbx" if v["type"] == "bool" else ("Sbx" if v["type"] == "int" else "Led")) for k, v in optimizer_params.items()
+ }
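+    # e.g. "optimizer.beta1" -> "beta1Led" (float), "optimizer.amsgrad" -> "amsgradCbx" (bool), "optimizer.log_every" -> "log_everySbx" (int).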
+
+ def __init__(self, loader, parent=None):
+ # Deferred import to avoid a circular import error.
+ from modules.util.optimizer_util import OPTIMIZER_DEFAULT_PARAMETERS
+ self.OPTIMIZER_DEFAULT_PARAMETERS = OPTIMIZER_DEFAULT_PARAMETERS
+
+ super().__init__(loader, "modules/ui/views/windows/optimizer.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ for row, k in enumerate(sorted(self.optimizer_params.keys())):
+ v = self.optimizer_params[k]
+
+ if v["type"] == "bool":
+ wdg_name = "{}{}".format(k, "Cbx")
+ wdg = QtW.QCheckBox(parent=self.ui, text=v["title"], objectName=wdg_name)
+ wdg.setToolTip(v["tooltip"])
+ self.ui.optimizerLay.addWidget(wdg, row, 0, 1, 2)
+ else:
+ wdg_name = "{}{}".format(k, "Lbl")
+ lbl = QtW.QLabel(parent=self.ui, text=v["title"], objectName=wdg_name)
+ if v["type"] == "int":
+ wdg_name = "{}{}".format(k, "Sbx")
+ wdg = QtW.QSpinBox(parent=self.ui, objectName=wdg_name)
+ wdg.setMinimum(-999999)
+ wdg.setMaximum(999999)
+ else:
+ wdg_name = "{}{}".format(k, "Led")
+ wdg = QtW.QLineEdit(parent=self.ui, objectName=wdg_name)
+ self._connectScientificNotation(wdg, inf=True, neg_inf=True)
+ wdg.setToolTip(v["tooltip"])
+ lbl.setBuddy(wdg)
+ self.ui.optimizerLay.addWidget(lbl, row, 0, 1, 1)
+ self.ui.optimizerLay.addWidget(wdg, row, 1, 1, 1)
+
+ def __loadDefaults(self):
+ def f():
+ optimizer = self.ui.optimizerCmb.currentData()
+ if optimizer is not None:
+ for k, v in self.OPTIMIZER_DEFAULT_PARAMETERS[optimizer].items():
+ StateModel.instance().set_state(f"optimizer.{k}", v)
+
+ QtW.QApplication.instance().stateChanged.emit()
+ QtW.QApplication.instance().optimizerChanged.emit(optimizer)
+ return f
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.optimizerCmb.activated, self.__updateOptimizer(from_index=True))
+ self._connect(self.ui.loadDefaultsBtn.clicked, self.__loadDefaults())
+
+
+ callback = self.__updateOptimizer(from_index=False)
+ self._connect(QtW.QApplication.instance().optimizerChanged, callback)
+ self._connect(QtW.QApplication.instance().stateChanged, lambda: callback(StateModel.instance().get_state("optimizer.optimizer")))
+
+    ###Reactions###
+
+ def __updateOptimizer(self, from_index=False):
+ def f(idx):
+ self.parent.ui.optimizerCmb.blockSignals(True)
+ self.parent.ui.optimizerCmb.setCurrentIndex(idx)
+ self.parent.ui.optimizerCmb.blockSignals(False)
+
+ self.__updateOptimizerControls(self.ui.optimizerCmb.currentData())
+
+
+ def g(optimizer):
+ self.parent.ui.optimizerCmb.blockSignals(True)
+ self.parent.ui.optimizerCmb.setCurrentIndex(self.parent.ui.optimizerCmb.findData(optimizer))
+ self.parent.ui.optimizerCmb.blockSignals(False)
+
+ self.__updateOptimizerControls(optimizer)
+
+ return f if from_index else g
+
+ ###Utils###
+
+ def __updateOptimizerControls(self, optimizer):
+        # QGridLayout has no direct children, so we must retrieve the widgets via findChild().
+ for k, v in self.optimizer_params.items():
+ if k in self.OPTIMIZER_DEFAULT_PARAMETERS[optimizer]:
+ val = StateModel.instance().get_state(f"optimizer.{k}")
+ else:
+ val = None
+ if v["type"] == "bool":
+ wdg = self.ui.findChild(QtW.QCheckBox, f"{k}Cbx")
+ if wdg is not None:
+ wdg.setVisible(k in self.OPTIMIZER_DEFAULT_PARAMETERS[optimizer])
+ if val is not None:
+ wdg.setChecked(bool(val))
+ else:
+ wdg = self.ui.findChild(QtW.QLabel, f"{k}Lbl")
+ if wdg is not None:
+ wdg.setVisible(k in self.OPTIMIZER_DEFAULT_PARAMETERS[optimizer])
+ if v["type"] == "int":
+ wdg = self.ui.findChild(QtW.QSpinBox, f"{k}Sbx")
+ if wdg is not None:
+ wdg.setVisible(k in self.OPTIMIZER_DEFAULT_PARAMETERS[optimizer])
+ if val is not None:
+ wdg.setValue(int(val))
+ else:
+ wdg = self.ui.findChild(QtW.QLineEdit, f"{k}Led")
+ if wdg is not None:
+ wdg.setVisible(k in self.OPTIMIZER_DEFAULT_PARAMETERS[optimizer])
+ if val is not None:
+ wdg.setText(str(val))
diff --git a/modules/ui/controllers/windows/ProfileController.py b/modules/ui/controllers/windows/ProfileController.py
new file mode 100644
index 000000000..d716f867b
--- /dev/null
+++ b/modules/ui/controllers/windows/ProfileController.py
@@ -0,0 +1,39 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.StateModel import StateModel
+
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class ProfileController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/profile.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.dumpBtn.clicked, self.__dump())
+ self._connect(self.ui.startBtn.clicked, self.__toggleProfiling())
+
+ ###Reactions###
+
+ def __dump(self):
+ @Slot()
+ def f():
+ StateModel.instance().dump_stack()
+ return f
+
+ def __toggleProfiling(self):
+ @Slot()
+ def f():
+ StateModel.instance().toggle_profiler()
+ if StateModel.instance().is_profiling:
+ self.ui.statusLbl.setText(QCA.translate("profiling_window", "Profiling active..."))
+ self.ui.startBtn.setText(QCA.translate("profiling_window", "End Profiling"))
+ else:
+ self.ui.statusLbl.setText(QCA.translate("profiling_window", "Inactive"))
+ self.ui.startBtn.setText(QCA.translate("profiling_window", "Start Profiling"))
+
+        # TODO: this button exits the application if not run from Scalene. It would be nice to disable it when running from plain Python,
+        # but the library does not expose a function we could check with before calling scalene_profiler.start().
+ return f
diff --git a/modules/ui/controllers/windows/SampleController.py b/modules/ui/controllers/windows/SampleController.py
new file mode 100644
index 000000000..3a5cc863b
--- /dev/null
+++ b/modules/ui/controllers/windows/SampleController.py
@@ -0,0 +1,71 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.controllers.widgets.SampleParamsController import SampleParamsController
+from modules.ui.models.SamplingModel import SamplingModel
+from modules.ui.utils.WorkerPool import WorkerPool
+
+import PySide6.QtGui as QtGui
+from PIL.ImageQt import ImageQt
+from PySide6.QtCore import Slot
+
+
+class SampleController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/sample.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _setup(self):
+ self.samplingParams = SampleParamsController(self.loader, model_instance=SamplingModel.instance(), write_signal=self.ui.sampleBtn.clicked, parent=self.parent)
+ self.ui.paramsLay.addWidget(self.samplingParams.ui)
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.sampleBtn.clicked, self.__startSample())
+
+ self.__enableControls(True)()
+
+ ###Reactions###
+
+ def __startSample(self):
+ @Slot()
+ def f():
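+            # createNamed is assumed to return (None, name) when a worker with this name is already running, hence the guard below.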
+ worker, name = WorkerPool.instance().createNamed(self.__sample(), "sample_image", poolless=True,
+ inject_progress_callback=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableControls(False), result_fn=None,
+ finished_fn=self.__enableControls(True),
+ errored_fn=self.__enableControls(True), aborted_fn=self.__enableControls(True),
+ progress_fn=self.__updateStatus())
+ WorkerPool.instance().start(name)
+
+ return f
+
+ def __enableControls(self, enabled):
+ @Slot()
+ def f():
+ self.ui.sampleBtn.setEnabled(enabled)
+ if enabled:
+ self.ui.progressBar.setValue(0)
+
+ return f
+
+ def __updateStatus(self):
+ progress_fn = self._updateProgress(self.ui.progressBar)
+
+ @Slot(dict)
+ def f(data):
+ if "status" in data:
+ self.ui.statusLbl.setText(data["status"])
+
+ if "data" in data:
+ self.ui.previewLbl.setPixmap(QtGui.QPixmap.fromImage(ImageQt(data["data"])))
+
+ progress_fn(data)
+
+ return f
+
+ ###Utils###
+
+ def __sample(self):
+ def f(progress_fn=None):
+ SamplingModel.instance().sample(progress_fn=progress_fn)
+ return f
diff --git a/modules/ui/controllers/windows/SaveController.py b/modules/ui/controllers/windows/SaveController.py
new file mode 100644
index 000000000..a0b9b5ab8
--- /dev/null
+++ b/modules/ui/controllers/windows/SaveController.py
@@ -0,0 +1,33 @@
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.StateModel import StateModel
+
+import PySide6.QtGui as QtGui
+import PySide6.QtWidgets as QtW
+from PySide6.QtCore import Slot
+
+
+class SaveController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/save.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connect(self.ui.cancelBtn.clicked, lambda: self.ui.hide())
+ self._connect(self.ui.okBtn.clicked, self.__save())
+
+ def _connectInputValidation(self):
+ self.ui.configCmb.setValidator(QtGui.QRegularExpressionValidator(r"[a-zA-Z0-9_\-.][a-zA-Z0-9_\-. ]*", self.ui))
+
+ ###Reactions###
+
+ def __save(self):
+ @Slot()
+ def f():
+ name = self.ui.configCmb.currentText()
+ if name != "" and not name.startswith("#"):
+ StateModel.instance().save_to_file(name)
+
+ QtW.QApplication.instance().savedConfig.emit(name)
+ self.ui.hide()
+ return f
diff --git a/modules/ui/controllers/windows/VideoController.py b/modules/ui/controllers/windows/VideoController.py
new file mode 100644
index 000000000..f449ca21d
--- /dev/null
+++ b/modules/ui/controllers/windows/VideoController.py
@@ -0,0 +1,252 @@
+import os
+
+from modules.ui.controllers.BaseController import BaseController
+from modules.ui.models.VideoModel import VideoModel
+from modules.ui.utils.WorkerPool import WorkerPool
+
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Slot
+
+
+class VideoController(BaseController):
+ def __init__(self, loader, parent=None):
+ super().__init__(loader, "modules/ui/views/windows/video.ui", name=None, parent=parent)
+
+ ###FSM###
+
+ def _connectUIBehavior(self):
+ self._connectFileDialog(self.ui.linkListBtn, self.ui.linkListLed, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open Link list"),
+ filters=QCA.translate("filetype_filters",
+ "Text (*.txt)"))
+ self._connectFileDialog(self.ui.singleVideo1Btn, self.ui.singleVideo1Led, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open Video"),
+ filters=QCA.translate("filetype_filters",
+ "Video (*.m4v *.wmv *.mp4 *.avi *.webm)"))
+ self._connectFileDialog(self.ui.singleVideo2Btn, self.ui.singleVideo2Led, is_dir=False, save=False,
+ title=QCA.translate("dialog_window", "Open Video"),
+ filters=QCA.translate("filetype_filters",
+ "Video (*.m4v *.wmv *.mp4 *.avi *.webm)"))
+
+ self._connectFileDialog(self.ui.directory1Btn, self.ui.directory1Led, is_dir=True, save=False,
+ title=QCA.translate("dialog_window", "Open Video directory"))
+ self._connectFileDialog(self.ui.directory2Btn, self.ui.directory2Led, is_dir=True, save=False,
+ title=QCA.translate("dialog_window", "Open Video directory"))
+
+ self._connectFileDialog(self.ui.output1Btn, self.ui.output1Led, is_dir=True, save=True,
+ title=QCA.translate("dialog_window", "Save Video directory"))
+ self._connectFileDialog(self.ui.output2Btn, self.ui.output2Led, is_dir=True, save=True,
+ title=QCA.translate("dialog_window", "Save Video directory"))
+ self._connectFileDialog(self.ui.output3Btn, self.ui.output3Led, is_dir=True, save=True,
+ title=QCA.translate("dialog_window", "Save Video directory"))
+
+ self._connect(self.ui.infoBtn.clicked, lambda: self._openUrl("https://github.com/yt-dlp/yt-dlp?tab=readme-ov-file#usage-and-options"))
+
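+        # Maps VideoModel state keys to widget objectNames; by the project's suffix convention, Led = QLineEdit, Cbx = QCheckBox, Sbx = QSpinBox (Ted presumably a QTextEdit).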
+ state_ui_connections = {
+ "clips.single_video": "singleVideo1Led",
+ "clips.range_start": "timeRangeStart1Led",
+ "clips.range_end": "timeRangeStop1Led",
+ "clips.directory": "directory1Led",
+ "clips.output": "output1Led",
+ "clips.output_to_subdirectories": "outputSubdirectories1Cbx",
+ "clips.split_at_cuts": "splitCutsCbx",
+ "clips.max_length": "maxLengthSbx",
+ "clips.fps": "fpsSbx",
+ "clips.remove_borders": "removeBorders1Cbx",
+ "clips.crop_variation": "cropVariation1Sbx",
+
+ "images.single_video": "singleVideo2Led",
+ "images.range_start": "timeRangeStart2",
+ "images.range_end": "timeRangeStop2",
+ "images.directory": "directory2Led",
+ "images.output": "output2Led",
+ "images.output_to_subdirectories": "outputSubdirectories2Cbx",
+ "images.capture_rate": "imagesSecSbx",
+ "images.blur_removal": "blurRemovalSbx",
+ "images.remove_borders": "removeBorders2Cbx",
+ "images.crop_variation": "cropVariation2Sbx",
+
+ "download.single_link": "singleLinkLed",
+ "download.link_list": "linkListLed",
+ "download.output": "output3Led",
+ "download.additional_args": "additionalArgsTed",
+ }
+ self._connectStateUI(state_ui_connections, VideoModel.instance(), update_after_connect=True)
+
+
+ self._connect(self.ui.extractSingle1Btn.clicked, self.__startClipSingle())
+ self._connect(self.ui.extractDirectory1Btn.clicked, self.__startClipDirectory())
+ self._connect(self.ui.extractSingle2Btn.clicked, self.__startImageSingle())
+ self._connect(self.ui.extractDirectory2Btn.clicked, self.__startImageDirectory())
+ self._connect(self.ui.downloadLinkBtn.clicked, self.__startDownloadLink())
+ self._connect(self.ui.downloadListBtn.clicked, self.__startDownloadList())
+
+ self.__enableButtons(True)()
+
+
+ ###Reactions###
+
+ def __enableButtons(self, enabled):
+ @Slot()
+ def f():
+ self.ui.extractSingle1Btn.setEnabled(enabled)
+ self.ui.extractDirectory1Btn.setEnabled(enabled)
+ self.ui.extractSingle2Btn.setEnabled(enabled)
+ self.ui.extractDirectory2Btn.setEnabled(enabled)
+ self.ui.downloadLinkBtn.setEnabled(enabled)
+ self.ui.downloadListBtn.setEnabled(enabled)
+ return f
+
+ def __startClipSingle(self):
+ @Slot()
+ def f():
+ if self.ui.singleVideo1Led.text() != "":
+ if os.path.exists(self.ui.singleVideo1Led.text()):
+ if self.ui.output1Led.text() != "":
+ worker, name = WorkerPool.instance().createNamed(self.__extractClip(), "video_processing", batch_mode=False)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableButtons(False), result_fn=None, finished_fn=self.__enableButtons(True),
+ errored_fn=self.__enableButtons(True), aborted_fn=self.__enableButtons(True))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an output folder"))
+ else:
+ self._openAlert(QCA.translate("video_window", "Invalid File"),
+ QCA.translate("video_window", "The selected input file does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("video_window", "No File Selected"),
+ QCA.translate("video_window", "Please select an input file"))
+ return f
+
+ def __startClipDirectory(self):
+ @Slot()
+ def f():
+ if self.ui.directory1Led.text() != "":
+ if os.path.isdir(self.ui.directory1Led.text()):
+ if self.ui.output1Led.text() != "":
+ worker, name = WorkerPool.instance().createNamed(self.__extractClip(), "video_processing", batch_mode=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableButtons(False), result_fn=None, finished_fn=self.__enableButtons(True),
+ errored_fn=self.__enableButtons(True), aborted_fn=self.__enableButtons(True))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an output folder"))
+ else:
+ self._openAlert(QCA.translate("video_window", "Invalid Folder"),
+ QCA.translate("video_window", "The selected input folder does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an input folder"))
+ return f
+
+ def __startImageSingle(self):
+ @Slot()
+ def f():
+ if self.ui.singleVideo2Led.text() != "":
+ if os.path.exists(self.ui.singleVideo2Led.text()):
+ if self.ui.output2Led.text() != "":
+ worker, name = WorkerPool.instance().createNamed(self.__extractImage(), "video_processing", batch_mode=False)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableButtons(False), result_fn=None, finished_fn=self.__enableButtons(True),
+ errored_fn=self.__enableButtons(True), aborted_fn=self.__enableButtons(True))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an output folder"))
+ else:
+ self._openAlert(QCA.translate("video_window", "Invalid File"),
+ QCA.translate("video_window", "The selected input file does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("video_window", "No File Selected"),
+ QCA.translate("video_window", "Please select an input file"))
+ return f
+
+ def __startImageDirectory(self):
+ @Slot()
+ def f():
+ if self.ui.directory2Led.text() != "":
+ if os.path.isdir(self.ui.directory2Led.text()):
+ if self.ui.output2Led.text() != "":
+ worker, name = WorkerPool.instance().createNamed(self.__extractImage(), "video_processing", batch_mode=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableButtons(False), result_fn=None, finished_fn=self.__enableButtons(True),
+ errored_fn=self.__enableButtons(True), aborted_fn=self.__enableButtons(True))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an output folder"))
+ else:
+ self._openAlert(QCA.translate("video_window", "Invalid Folder"),
+ QCA.translate("video_window", "The selected input folder does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an input folder"))
+ return f
+
+ def __startDownloadLink(self):
+ @Slot()
+ def f():
+ if self.ui.singleLinkLed.text() != "":
+ if self.ui.output3Led.text() != "":
+ worker, name = WorkerPool.instance().createNamed(self.__download(), "video_processing", batch_mode=False)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableButtons(False), result_fn=None, finished_fn=self.__enableButtons(True),
+ errored_fn=self.__enableButtons(True), aborted_fn=self.__enableButtons(True))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an output folder"))
+ else:
+ self._openAlert(QCA.translate("video_window", "No URL Provided"),
+ QCA.translate("video_window", "Please insert a valid URL"))
+ return f
+
+ def __startDownloadList(self):
+ @Slot()
+ def f():
+ if self.ui.linkListLed.text() != "":
+ if os.path.exists(self.ui.linkListLed.text()):
+ if self.ui.output3Led.text() != "":
+ worker, name = WorkerPool.instance().createNamed(self.__download(), "video_processing", batch_mode=True)
+ if worker is not None:
+ worker.connectCallbacks(init_fn=self.__enableButtons(False), result_fn=None, finished_fn=self.__enableButtons(True),
+ errored_fn=self.__enableButtons(True), aborted_fn=self.__enableButtons(True))
+ WorkerPool.instance().start(name)
+ else:
+ self._openAlert(QCA.translate("video_window", "No Folder Selected"),
+ QCA.translate("video_window", "Please select an output folder"))
+ else:
+ self._openAlert(QCA.translate("video_window", "Invalid File"),
+ QCA.translate("video_window", "The selected input file does not exist"),
+ type="critical")
+ else:
+ self._openAlert(QCA.translate("video_window", "No File Selected"),
+ QCA.translate("video_window", "Please select an input file"))
+ return f
+
+ ###Utils###
+
+ def __extractClip(self):
+ def f(batch_mode):
+ return VideoModel.instance().extract_clips_multi(batch_mode)
+
+ return f
+
+ def __extractImage(self):
+ def f(batch_mode):
+ return VideoModel.instance().extract_images_multi(batch_mode)
+
+ return f
+
+
+ def __download(self):
+ def f(batch_mode):
+ return VideoModel.instance().download_multi(batch_mode)
+
+ return f
diff --git a/modules/ui/models/BulkCaptionModel.py b/modules/ui/models/BulkCaptionModel.py
new file mode 100644
index 000000000..b9372609c
--- /dev/null
+++ b/modules/ui/models/BulkCaptionModel.py
@@ -0,0 +1,83 @@
+import functools
+import re
+from multiprocessing import Pool
+from pathlib import Path
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util.enum.BulkEditMode import BulkEditMode
+
+
+def _edit_text(config, read_only, file):
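+    """Apply the configured add/remove/replace/regex edits to one caption file; returns (content, changed)."""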
+    # Initialize before the try block so the except handler below can return safely even if open() fails.
+    content = ""
+    changed = False
+    try:
+ with open(file, "r+", encoding="utf-8") as f:
+ content = original_content = f.read()
+ if config["add_text"] != "":
+ content = f"{config['add_text']} {content}" if config["add_mode"] == BulkEditMode.PREPEND else f"{content} {config['add_text']}"
+ if config["remove_text"] != "":
+ content = content.replace(config["remove_text"], "")
+ if config["replace_text"] != "" and config["replace_with"] != "":
+ content = content.replace(config["replace_text"], config["replace_with"])
+ if config["regex_pattern"] != "" and config["regex_replace"] != "":
+ regex = re.compile(config["regex_pattern"])
+ content = regex.sub(config["regex_replace"], content)
+
+ changed = content != original_content
+
+ if not read_only and changed:
+ f.seek(0)
+ f.truncate()
+ f.write(content)
+ except (OSError, re.error):
+ return content, False
+
+ return content, changed
+
+
+class BulkCaptionModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__({
+ "directory": "",
+ "add_text": "",
+ "add_mode": BulkEditMode.PREPEND,
+ "remove_text": "",
+ "replace_text": "",
+ "replace_with": "",
+ "regex_pattern": "",
+ "regex_replace": "",
+ })
+
+ self.pool = None
+
+ def terminate_pool(self):
+ if self.pool is not None:
+ self.pool.terminate()
+ self.pool.join()
+ self.pool = None
+
+
+ def bulk_edit(self, read_only=False, preview_n=None, progress_fn=None):
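+        """Apply the configured edits to every .txt caption in the directory; with read_only=True, only preview (up to preview_n files) without writing."""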
+ base_path = Path(self.get_state("directory"))
+ files = list(base_path.glob("*.txt"))
+
+ if self.pool is None:
+ self.pool = Pool()
+
+ if preview_n is not None and read_only:
+ files = files[:preview_n]
+
+ with self.critical_region_read():
+ result = self.pool.map(functools.partial(_edit_text, self.config, read_only), files)
+
+ total = len(result)
+ processed = len([r for r in result if r[1]])
+ skipped = total - processed
+
+ if preview_n is not None:
+ result = result[:preview_n]
+
+ if progress_fn is not None:
+ if read_only:
+ progress_fn({"status": f"Previewing {total} files: {processed} will be processed, {skipped} will be skipped", "data": "\n\n".join([r[0] for r in result])})
+ else:
+ progress_fn({"status": f"Edited {total} files: {processed} processed, {skipped} skipped", "data": "\n\n".join([r[0] for r in result])})
diff --git a/modules/ui/models/BulkImageModel.py b/modules/ui/models/BulkImageModel.py
new file mode 100644
index 000000000..c124b322b
--- /dev/null
+++ b/modules/ui/models/BulkImageModel.py
@@ -0,0 +1,534 @@
+import functools
+import os
+import random
+import threading
+import traceback
+import uuid
+from multiprocessing import Pool
+from pathlib import Path
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util import image_util, path_util
+from modules.util.enum.ImageMegapixels import ImageMegapixels
+from modules.util.enum.ImageOperations import ImageOperations
+from modules.util.enum.ImageOptimization import ImageOptimization
+
+import imagesize
+import oxipng
+from PIL import Image, ImageColor
+
+# multiprocessing.Pool requires picklable functions, which must be defined outside classes.
+# https://pypi.org/project/pathos/ uses dill as its serialization engine, allowing more flexibility, in case we want to switch to something more maintainable.
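+# functools.partial over module-level functions stays picklable, which is how per-call options (colors, sizes, formats) are passed to the worker processes below.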
+
+def _verify_image(file):
+ """
+    Verify a single image file for corruption. Returns False if the file fails to load.
+
+ Note: The image is opened twice because Pillow's .verify() invalidates the image object.
+ """
+ file_path = Path(file)
+ valid = False
+ try:
+ with Image.open(file_path) as img:
+ img.verify()
+ with Image.open(file_path) as img:
+ img.load()
+ if hasattr(img, "getpixel"):
+ img.getpixel((0, 0))
+ valid = True
+ except Exception:
+ valid = False
+
+ return valid
+
+def _process_alpha_image(bg_color_tuple, file):
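+    """Flatten an RGBA/LA image onto a solid background color in place; returns True if the file was rewritten."""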
+ with Image.open(file) as img:
+ if img.mode not in ("RGBA", "LA"):
+ return False
+
+ background = Image.new("RGB", img.size, bg_color_tuple)
+ background.paste(img, (0, 0), img)
+ background.save(str(file))
+ return True
+
+def _resize_large_image(target_pixels, file):
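+    """Downscale an image in place if it exceeds target_pixels; returns True if it was resized."""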
+ width, height = imagesize.get(file)
+ if width * height <= target_pixels:
+ return False
+
+ new_width, new_height = BulkImageModel.calculate_dimensions_for_megapixels(
+ width, height, target_pixels
+ )
+ reduction_factor = width / new_width
+
+ with Image.open(file) as img:
+ resample_filter = (
+ Image.Resampling.LANCZOS
+ if hasattr(Image, "Resampling")
+ else Image.LANCZOS
+ )
+
+ if reduction_factor >= 3 and hasattr(image_util, "dpid_resize"):
+ resized_img = image_util.dpid_resize(img, (new_width, new_height))
+ else:
+ resized_img = img.resize((new_width, new_height), resample=resample_filter)
+
+ save_kwargs = {}
+ if file.suffix.lower() in [".jpg", ".jpeg"]:
+ save_kwargs["quality"] = 95
+ if "icc_profile" in img.info:
+ save_kwargs["icc_profile"] = img.info["icc_profile"]
+ if "exif" in img.info:
+ save_kwargs["exif"] = img.info["exif"]
+ elif file.suffix.lower() == ".png":
+ save_kwargs["compress_level"] = 4
+
+ if resized_img.mode == "P":
+ resized_img = resized_img.convert("RGB")
+
+ resized_img.save(str(file), **save_kwargs)
+ return True
+
+
+def _optimize_png(file):
+ original_size = file.stat().st_size
+ oxipng.optimize(file, level=5, fix_errors=True)
+ new_size = file.stat().st_size
+ bytes_saved = original_size - new_size
+ if bytes_saved > 0:
+ return True, bytes_saved
+ else:
+ return False, 0
+
+def _is_lossless_check(file, img, check_lossless):
+    # img is currently unused; kept so the call site's signature stays unchanged.
+    return check_lossless and file.suffix.lower() in {".jpg", ".jpeg"}
+
+def _convert_image(format_options, file):
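+    """Convert one image to the target format, keeping the converted file only if it is smaller than the original."""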
+ # Use local variables for performance and clarity
+ format_ext = format_options["format_ext"]
+ pil_format = format_options["pil_format"]
+ lossless_extensions = format_options.get("lossless_extensions", set())
+ quality = format_options.get("quality", 90)
+ save_kwargs_base = format_options.get("save_kwargs", {})
+
+ original_size = file.stat().st_size
+ new_path = file.with_suffix(format_ext)
+
+ with Image.open(file) as img:
+ is_lossless = file.suffix.lower() in lossless_extensions or _is_lossless_check(file, img, format_options["check_lossless"])
+
+ save_kwargs = save_kwargs_base.copy()
+ save_kwargs["quality"] = quality
+ if is_lossless:
+ save_kwargs["lossless"] = True
+
+ img.save(new_path, pil_format, **save_kwargs)
+
+ if not new_path.exists():
+ return "errors", (file.name, "Failed to save new file.")
+
+ new_size = new_path.stat().st_size
+ bytes_saved = original_size - new_size
+
+ if new_size < original_size:
+ file.unlink()
+ return True, bytes_saved
+ else:
+ new_path.unlink()
+ return False, 0
+
+
+class BulkImageModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__({
+ "directory": "",
+ "verify_images": False,
+ "sequential_rename": False,
+ "process_alpha": False,
+ "resize_large_images": False,
+ "optimization_type": ImageOptimization.NONE,
+ "resize_megapixels": ImageMegapixels.COMPUTE_PROOF_MEGAPIXEL_THRESHOLD,
+ "resize_custom_megapixels": 4,
+ "alpha_bg_color": "#ffffff",
+ })
+
+ self.pool = None
+ self.abort_flag = threading.Event()
+ self.progress_fn = None
+
+ def process_files(self, progress_fn=None):
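+        """Run all enabled bulk operations over the files in the configured directory."""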
+ with self.critical_region_read():
+ directory = self.get_state("directory")
+ self.progress_fn = progress_fn
+
+ if self.pool is None:
+ self.pool = Pool()
+
+ if os.path.isdir(directory):
+ path = Path(directory)
+ files = [f for f in path.iterdir() if f.is_file()]
+ self.log("info", f"Found {len(files)} files in {directory}")
+
+ self.__run_operations(files)
+
+ def terminate_pool(self):
+ if self.pool is not None:
+ self.pool.terminate()
+ self.pool.join()
+ self.pool = None
+
+ def __run_operations(self, files):
+ operations = []
+ if self.get_state("verify_images"):
+ operations.append(ImageOperations.VERIFY_IMG)
+ if self.get_state("sequential_rename"):
+ operations.append(ImageOperations.SEQUENTIAL_RENAME)
+ if self.get_state("process_alpha"):
+ operations.append(ImageOperations.PROCESS_ALPHA)
+ if self.get_state("resize_large_images"):
+ operations.append(ImageOperations.RESIZE_LARGE_IMG)
+
+ opt = self.get_state("optimization_type")
+ if opt == ImageOptimization.PNG:
+ operations.append(ImageOperations.OPTIMIZE_PNG)
+ elif opt == ImageOptimization.WEBP:
+ operations.append(ImageOperations.CONVERT_WEBP)
+ elif opt == ImageOptimization.JXL:
+ operations.append(ImageOperations.CONVERT_JXL)
+
+ # The correctness of these operations relies on the insertion order, as SEQUENTIAL_RENAME, CONVERT_WEBP and CONVERT_JXL change the list of files.
+ # In the original implementation (and this one as well) conversion is performed last, so we need to keep track of changed files only for SEQUENTIAL_RENAME.
+
+ total_ops = len(operations)
+ i = 0
+ while not self.abort_flag.is_set() and len(operations) > 0:
+            op = operations.pop(0)
+ i += 1
+ if self.progress_fn is not None:
+ self.progress_fn({"status": op.pretty_print(), "value": i, "max_value": total_ops, "data": f"{op.pretty_print()}..."})
+ try:
+ if op == ImageOperations.VERIFY_IMG:
+ self.__verify_images(files)
+ elif op == ImageOperations.SEQUENTIAL_RENAME:
+ files = self.__rename_files_sequentially(files)
+ elif op == ImageOperations.PROCESS_ALPHA:
+ self.__process_alpha_images(files)
+ elif op == ImageOperations.RESIZE_LARGE_IMG:
+ self.__resize_large_images(files)
+ elif op == ImageOperations.OPTIMIZE_PNG:
+ self.__optimize_pngs(files)
+ elif op == ImageOperations.CONVERT_WEBP:
+ self.__convert_to_webp(files)
+ elif op == ImageOperations.CONVERT_JXL:
+ self.__convert_to_jpegxl(files)
+ except Exception:
+ if self.progress_fn is not None:
+ self.progress_fn({"status": f"Error during {op.pretty_print().lower()}"})
+ self.log("critical", traceback.format_exc())
+
+ if self.abort_flag.is_set():
+ if self.progress_fn is not None:
+ self.progress_fn({"status": "Processing aborted"})
+ else:
+ if self.progress_fn is not None:
+ self.progress_fn({"status": "Processing complete", "value": 0, "max_value": 0})
+
+ @staticmethod
+ def calculate_dimensions_for_megapixels(
+ original_width: int, original_height: int, target_pixels: int
+ ) -> tuple[int, int]:
+ """Calculates new dimensions to fit an image within a pixel budget."""
+ original_pixels = original_width * original_height
+ if original_pixels <= target_pixels:
+ return original_width, original_height
+
+ scale_factor = (target_pixels / original_pixels) ** 0.5
+ new_width = int(original_width * scale_factor)
+ new_height = int(original_height * scale_factor)
+ return new_width, new_height
+
+ @staticmethod
+ def _filter_is_image(f: Path) -> bool:
+ """Filter for supported image files."""
+ return path_util.is_supported_image_extension(f.suffix)
+
+ @staticmethod
+ def _filter_images_and_skip_masks(f: Path) -> bool:
+ """Filter for supported image files, excluding mask files."""
+ return path_util.is_supported_image_extension(f.suffix) and not f.stem.endswith("-masklabel")
+
+ @staticmethod
+ def _file_filter(f, skip_extensions):
+ return (
+ path_util.is_supported_image_extension(f.suffix)
+ and f.suffix.lower() not in skip_extensions
+ and "-" not in f.stem # Skip mask files
+ )
+
+ def __verify_images(self, files):
+ # We want to check only images.
+ files = [f for f in files if self._filter_is_image(f)]
+
+ # TODO: maybe this should also verify that there are no ambiguous images (e.g., "1.png" and "1.jpeg" in the same folder)
+
+ is_valid = self.pool.map(_verify_image, files)
+
+ total = len(is_valid)
+ valid = len([v for v in is_valid if v])
+ invalid = total - valid
+
+ # If there is no progress callback, there is no reason to compute the output. Note that _verify_image may still be reimplemented in the future to attempt a file fix, therefore we should still map files to it.
+ if self.progress_fn is not None:
+ for file, valid in zip(files, is_valid, strict=True):
+ if valid:
+ self.progress_fn({"data": f"✓ {file.name} is valid"})
+ else:
+ self.progress_fn({"data": f"✗ {file.name} is CORRUPTED"}) # TODO: the original implementation also passed the exception message.
+
+ self.progress_fn({"data": f"Checked {total} files: {valid} valid, {invalid} invalid"})
+
+ def __rename_files_sequentially(self, files):
+ outfiles = files
+ if len(files) > 0:
+ # TODO: This does not always work (after processing multiple files are deleted, and caption/mask become associated with the wrong image), but could NOT reproduce bug with:
+ # dataset with images + masks + captions
+ # dataset with some captions missing
+ # dataset with some masks missing
+ # dataset with mixed png, jpeg # TODO: do we also want to add an operation to convert all images to the same format (other than webp/jxl)?
+ # Probably it was an edge case like "img.png + img.jpeg + img.txt + img-masklabel.png", somehow desynchronizing all the other valid triplets?
+ # The other possible cause may be an incorrect exception handling.
+
+ # TODO Improvement: should unlink() be replaced by the OS' send to recycling bin? https://pypi.org/project/Send2Trash/
+
+
+ groups = {}
+
+ for f in files:
+ stem = f.stem
+
+ key = stem.removesuffix("-masklabel") if stem.endswith("-masklabel") else stem
+
+ if key not in groups:
+ groups[key] = {"image": None, "caption": None, "masks": []}
+
+ if stem.endswith("-masklabel"):
+ groups[key]["masks"].append(f) # TODO: Why can it be the case that we have multiple masks?
+ elif self._filter_is_image(f):
+ groups[key]["image"] = f
+ elif f.suffix.lower() == ".txt":
+ groups[key]["caption"] = f
+
+
+ image_groups = {stem: data for stem, data in groups.items() if data["image"] is not None}
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Found {len(image_groups)} images and their associated files to rename."})
+
+ # Process files only if there is at least one image, and they are not already sorted.
+ if len(image_groups) > 0:
+ if not any(
+ (not f.isdigit()) or (int(f) != i + 1)
+ for i, f in enumerate(sorted(image_groups.keys()))
+ ):
+ if self.progress_fn is not None:
+ self.progress_fn({"data": "Files are already named sequentially. No action needed."})
+ else:
+ # TODO: from what I understand, the original intended behavior was to undo EVERY renaming, in case of an OSError.
+ # I think a best-effort strategy is more reasonable: if a single sample fails, rollback only its files (image, caption and masks), and stop there.
+ # This is because if you get an error during renaming, either there are permission issues (on some files), or the directory is no longer writeable/mounted (and in that case rollbacking would fail, and we would be in an inconsistent state anyway).
+ # The only goal of the rollback mechanism is to attempt to guarantee that no (image, caption, mask) triplet association is lost due to partial renamings.
+ outfiles = []
+
+ renaming = []
+ for i, img in enumerate(sorted(image_groups), start=1):
+ tmp = image_groups[img]
+ tmp_name = str(uuid.uuid4().hex)
+ renaming.extend([(tmp["image"], tmp["image"].with_name(f"{tmp_name}{tmp['image'].suffix}"), tmp["image"].with_name(f"{i}{tmp['image'].suffix}"))])
+ if tmp["caption"] is not None:
+ renaming.extend([(tmp["caption"], tmp["caption"].with_name(f"{tmp_name}{tmp['caption'].suffix}"), tmp["caption"].with_name(f"{i}{tmp['caption'].suffix}"))])
+ for mask in tmp["masks"]:
+ renaming.extend([(mask, mask.with_name(f"{tmp_name}-masklabel{mask.suffix}"), mask.with_name(f"{i}-masklabel{mask.suffix}"))])
+
+ errored = False
+ # Rename all the samples (image, caption and mask) in order. To avoid overwriting samples which already have integer names, we need two passes.
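+                        # (e.g. renaming "2.png" -> "1.png" in one pass could overwrite an existing "1.png"; unique temp names avoid this.)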
+ # First pass: source to temp name.
+                        last_before_error = -1
+                        try:
+                            for j, (src, tmp_dest, _) in enumerate(renaming):
+                                if self.progress_fn is not None:
+                                    self.progress_fn({"data": f"Renaming {src.name} to {tmp_dest.name}..."})
+                                src.rename(tmp_dest)
+                                last_before_error = j
+
+                        except OSError:
+                            errored = True
+                            try:
+                                # Roll back only the renames that actually succeeded.
+                                for src2, dest2, _ in renaming[:last_before_error + 1]:
+                                    dest2.rename(src2)
+                            except OSError:
+                                if self.progress_fn is not None:
+                                    self.progress_fn({"status": "Critical failure during rename",
+                                                      "data": "OSError while attempting rollback. Is the directory still accessible?"})
+                                self.log("critical", traceback.format_exc())
+                            else:
+                                # Rollback succeeded: the original file names are intact.
+                                outfiles = files
+                                if self.progress_fn is not None:
+                                    self.progress_fn({"status": "Rename failed, successfully rolled back.", "data": f"Rename failed for file {src.name}"})
+
+                        if not errored:
+                            # Second pass: temp name to destination.
+                            last_before_error = -1
+                            try:
+                                for j, (_, tmp_dest, final_dest) in enumerate(renaming):
+                                    if self.progress_fn is not None:
+                                        self.progress_fn({"data": f"Renaming {tmp_dest.name} to {final_dest.name}..."})
+                                    tmp_dest.rename(final_dest)
+                                    outfiles.append(final_dest)
+                                    last_before_error = j
+
+                            except OSError:
+                                try:
+                                    # Undo this pass (final -> temp), then undo the first pass (temp -> source).
+                                    for _, src2, dest2 in renaming[:last_before_error + 1]:
+                                        dest2.rename(src2)
+                                    for src2, tmp2, _ in renaming:
+                                        tmp2.rename(src2)
+                                    outfiles = files
+                                except OSError:
+                                    if self.progress_fn is not None:
+                                        self.progress_fn({"status": "Critical failure during rename",
+                                                          "data": "OSError while attempting rollback. Is the directory still accessible?"})
+                                    self.log("critical", traceback.format_exc())
+                                else:
+                                    if self.progress_fn is not None:
+                                        self.progress_fn({"status": "Rename failed, successfully rolled back.",
+                                                          "data": f"Rename failed for file {tmp_dest.name}"})
+
+ return outfiles
+
+ def __process_alpha_images(self, files):
+ bg_color_str = self.get_state("alpha_bg_color")
+
+ files = [f for f in files if self._filter_images_and_skip_masks(f)]
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": "Processing transparent images (excluding mask files)..."})
+
+
+ try:
+ if bg_color_str.lower() in ("-1", "random"):
+ r, g, b = (random.randint(0, 255) for _ in range(3))
+ bg_color_tuple = (r, g, b)
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Using random background color: #{r:02x}{g:02x}{b:02x}"})
+ else:
+ color = ImageColor.getrgb(bg_color_str)
+ # Ensure we have a 3-channel RGB tuple for the background
+ bg_color_tuple = color[:3]
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Using background color: {bg_color_str} (RGB: {bg_color_tuple})"})
+ except (ValueError, TypeError) as e:
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Invalid color '{bg_color_str}': {e}. Using white instead"})
+
+ bg_color_tuple = (255, 255, 255) # Fallback to White in case of unknown color.
+
+ result = self.pool.map(functools.partial(_process_alpha_image, bg_color_tuple), files)
+ total = len(result)
+ processed = len([r for r in result if r])
+ skipped = total - processed
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Transparency processing complete: {total} total, {processed} processed, {skipped} skipped"})
+
+ def __resize_large_images(self, files):
+ files = [f for f in files if self._filter_is_image(f)] # TODO: the original implementation skipped masks, but it makes no sense to rescale images while leaving their masks untouched.
+
+ mp = self.get_state("resize_megapixels")
+ if mp == ImageMegapixels.CUSTOM:
+ mp = ImageMegapixels.ONE_MEGAPIXEL.value * self.get_state("resize_custom_megapixels")
+ else:
+ mp = mp.value
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Starting resizing of large images... Target: {mp / ImageMegapixels.ONE_MEGAPIXEL.value:.1f}MP"})
+
+ result = self.pool.map(functools.partial(_resize_large_image, mp), files)
+ total = len(result)
+ processed = len([r for r in result if r])
+ skipped = total - processed
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Image resizing complete: {total} total, {processed} processed, {skipped} skipped"})
+
+
+ def __optimize_pngs(self, files):
+ files = [f for f in files if f.suffix.lower() == ".png"]
+ result = self.pool.map(_optimize_png, files)
+ total = len(result)
+ processed = len([r for r in result if r[0]])
+ skipped = total - processed
+
+ bytes_saved = sum([r[1] for r in result])
+ avg_bytes_saved = bytes_saved / total if total > 0 else 0
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Completed optimization: {processed} PNGs optimized, {skipped} skipped. Saved {bytes_saved} bytes ({avg_bytes_saved} average bytes per file)"})
+
+ def __convert_to_webp(self, files):
+ """Convert images to WebP format using the generic converter."""
+ format_options = {
+ "format_ext": ".webp",
+ "pil_format": "WEBP",
+ "lossless_extensions": {".png", ".tiff", ".tif", ".bmp"},
+ "check_lossless": False,
+ "quality": 90,
+ }
+ skip_extensions = {".webp", ".jxl", ".avif"}
+ self.__convert_image_format(
+ files,
+ "WebP",
+ skip_extensions,
+ format_options,
+ )
+
+ def __convert_to_jpegxl(self, files):
+ """Convert images to JPEG XL format using the generic converter."""
+ format_options = {
+ "format_ext": ".jxl",
+ "pil_format": "JXL",
+ "lossless_extensions": set(),
+ "check_lossless": True,
+ "quality": 90,
+ }
+ skip_extensions = {".jxl"}
+ self.__convert_image_format(
+ files,
+ "JPEG XL",
+ skip_extensions,
+ format_options
+ )
+
+ def __convert_image_format(
+ self,
+ files: list[Path],
+ target_format: str,
+ skip_extensions: set,
+ format_options: dict
+ ) -> None:
+ """Generic image conversion function for multiple formats."""
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Starting conversion to {target_format} format..."})
+
+ files = [f for f in files if self._file_filter(f, skip_extensions)]
+
+ result = self.pool.map(functools.partial(_convert_image, format_options), files)
+ total = len(result)
+ processed = len([r for r in result if r[0]])
+ skipped = total - processed
+
+ bytes_saved = sum(r[1] for r in result)
+ avg_bytes_saved = bytes_saved / total if total > 0 else 0
+
+ if self.progress_fn is not None:
+ self.progress_fn({"data": f"Completed conversion to {target_format}: {processed} images converted, {skipped} skipped. Saved {bytes_saved} bytes ({avg_bytes_saved:.1f} average bytes per file)"})
diff --git a/modules/ui/models/CaptionModel.py b/modules/ui/models/CaptionModel.py
new file mode 100644
index 000000000..63c090ee5
--- /dev/null
+++ b/modules/ui/models/CaptionModel.py
@@ -0,0 +1,71 @@
+from modules.module.Blip2Model import Blip2Model
+from modules.module.BlipModel import BlipModel
+from modules.module.WDModel import WDModel
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util.enum.GenerateCaptionsModel import GenerateCaptionsAction, GenerateCaptionsModel
+from modules.util.torch_util import default_device, torch_gc
+
+import torch
+
+
+class CaptionModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__({
+ "model": GenerateCaptionsModel.BLIP,
+ "path": "",
+ "caption": "",
+ "prefix": "",
+ "postfix": "",
+ "mode": GenerateCaptionsAction.REPLACE,
+ "include_subdirectories": False,
+ })
+
+ self.captioning_model = None
+
+ def create_captions(self, progress_fn=None):
+ with self.critical_region_read():
+ self.__load_captioning_model(self.get_state("model"))
+
+ self.captioning_model.caption_folder(
+ sample_dir=self.get_state("path"),
+ initial_caption=self.get_state("caption"),
+ caption_prefix=self.get_state("prefix"),
+ caption_postfix=self.get_state("postfix"),
+ mode=str(self.get_state("mode")).lower(),
+ include_subdirectories=self.get_state("include_subdirectories"),
+ progress_callback=self.__wrap_progress(progress_fn),
+ )
+
+ def __load_captioning_model(self, model):
+ # Keep the current model when it is already of the requested type; otherwise release it and load the new one.
+ if model == GenerateCaptionsModel.BLIP:
+ if self.captioning_model is None or not isinstance(self.captioning_model, BlipModel):
+ self.log("info", "Loading Blip model, this may take a while")
+ self.release_model()
+ self.captioning_model = BlipModel(default_device, torch.float16)
+ elif model == GenerateCaptionsModel.BLIP2:
+ if self.captioning_model is None or not isinstance(self.captioning_model, Blip2Model):
+ self.log("info", "Loading Blip2 model, this may take a while")
+ self.release_model()
+ self.captioning_model = Blip2Model(default_device, torch.float16)
+ elif model == GenerateCaptionsModel.WD14_VIT_2:
+ if self.captioning_model is None or not isinstance(self.captioning_model, WDModel):
+ self.log("info", "Loading WD14_VIT_v2 model, this may take a while")
+ self.release_model()
+ self.captioning_model = WDModel(default_device, torch.float16)
+
+ def __wrap_progress(self, fn):
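+ # Adapt the (value, max_value) callback used by caption_folder to the dict-based progress_fn.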
+ def f(value, max_value):
+ if fn is not None:
+ fn({"value": value, "max_value": max_value})
+ return f
+
+ def release_model(self):
+ """Release all models from VRAM"""
+ freed = False
+ if self.captioning_model is not None:
+ self.captioning_model = None
+ freed = True
+ if freed:
+ torch_gc()
diff --git a/modules/ui/models/ConceptModel.py b/modules/ui/models/ConceptModel.py
new file mode 100644
index 000000000..9c26cac55
--- /dev/null
+++ b/modules/ui/models/ConceptModel.py
@@ -0,0 +1,548 @@
+import copy
+import fractions
+import json
+import math
+import os
+import pathlib
+import platform
+import random
+import threading
+import time
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.ui.models.StateModel import StateModel
+from modules.util import concept_stats, path_util
+from modules.util.config.ConceptConfig import ConceptConfig
+from modules.util.enum.ConceptType import ConceptType
+from modules.util.image_util import load_image
+from modules.util.path_util import write_json_atomic
+
+from mgds.LoadingPipeline import LoadingPipeline
+from mgds.OutputPipelineModule import OutputPipelineModule
+from mgds.PipelineModule import PipelineModule
+from mgds.pipelineModules.CapitalizeTags import CapitalizeTags
+from mgds.pipelineModules.DropTags import DropTags
+from mgds.pipelineModules.RandomBrightness import RandomBrightness
+from mgds.pipelineModules.RandomCircularMaskShrink import (
+ RandomCircularMaskShrink,
+)
+from mgds.pipelineModules.RandomContrast import RandomContrast
+from mgds.pipelineModules.RandomFlip import RandomFlip
+from mgds.pipelineModules.RandomHue import RandomHue
+from mgds.pipelineModules.RandomMaskRotateCrop import RandomMaskRotateCrop
+from mgds.pipelineModules.RandomRotate import RandomRotate
+from mgds.pipelineModules.RandomSaturation import RandomSaturation
+from mgds.pipelineModules.ShuffleTags import ShuffleTags
+from mgds.pipelineModuleTypes.RandomAccessPipelineModule import (
+ RandomAccessPipelineModule,
+)
+
+import torch
+from torchvision.transforms import functional
+
+import huggingface_hub
+from PIL import Image
+
+
+class InputPipelineModule(
+ PipelineModule,
+ RandomAccessPipelineModule,
+):
+ def __init__(self, data: dict):
+ super().__init__()
+ self.data = data
+
+ def length(self) -> int:
+ return 1
+
+ def get_inputs(self) -> list[str]:
+ return []
+
+ def get_outputs(self) -> list[str]:
+ return list(self.data.keys())
+
+ def get_item(self, variation: int, index: int, requested_name: str = None) -> dict:
+ return self.data
+
+
+class ConceptModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__([])
+ self.cancel_scan_flag = threading.Event()
+
+ def __len__(self):
+ with self.critical_region_read():
+ return len(self.config)
+
+
+ def get_random_seed(self):
+ return ConceptConfig.default_values().seed
+
+
+ def get_concept_name(self, idx):
+ with self.critical_region_read():
+ name = self.config[idx].name
+ path = self.config[idx].path
+
+ if name is not None and name != "":
+ return name
+ elif path is not None and path != "":
+ return os.path.basename(path)
+ else:
+ return ""
+
+ def toggle_concepts(self):
+ some_enabled = self.some_concepts_enabled()
+
+ with self.critical_region_write():
+ for c in self.config:
+ c.enabled = not some_enabled
+
+ def get_filtered_concepts(self, query="", type=ConceptType.ALL, show_disabled=True):
+ with self.critical_region_read():
+ filtered_concepts = [(idx, c) for idx, c in enumerate(self.config)
+ if (show_disabled or c.enabled) and
+ (type == ConceptType.ALL or c.type == type) and
+ (query == "" or query.strip() in c.name)]
+
+ return filtered_concepts
+
+
+ def some_concepts_enabled(self):
+ out = False
+ with self.critical_region_read():
+ for c in self.config:
+ out |= c.enabled
+ return out
+
+ def create_new_concept(self):
+ with self.critical_region_write():
+ con_cfg = ConceptConfig.default_values()
+ self.config.append(con_cfg)
+
+ def clone_concept(self, idx):
+ with self.critical_region_write():
+ new_element = copy.deepcopy(self.config[idx])
+ self.config.append(new_element)
+
+ def delete_concept(self, idx):
+ with self.critical_region_write():
+ self.config.pop(idx)
+
+ def save_config(self, path="training_concepts"):
+ if not os.path.exists(path):
+ os.mkdir(path)
+
+ config_path = StateModel.instance().get_state("concept_file_name")
+ with self.critical_region_read():
+ write_json_atomic(config_path, [element.to_dict() for element in self.config])
+
+ def load_config(self, filename, path="training_concepts"):
+ if not os.path.exists(path):
+ os.mkdir(path)
+
+ if filename == "":
+ filename = "concepts"
+
+ config_file = path_util.canonical_join(path, f"{filename}.json")
+ StateModel.instance().set_state("concept_file_name", config_file)
+
+ with self.critical_region_write():
+ self.config = []
+
+ if os.path.exists(config_file):
+ with open(config_file, "r") as f:
+ loaded_config_json = json.load(f)
+ for element_json in loaded_config_json:
+ element = ConceptConfig.default_values().from_dict(element_json)
+ self.config.append(element)
+
+ @staticmethod
+ def get_concept_path(path):
+ if os.path.isdir(path):
+ return path
+ try:
+ # don't download, only check if available locally:
+ return huggingface_hub.snapshot_download(repo_id=path, repo_type="dataset", local_files_only=True)
+ except Exception:
+ return None
+
+ def get_preview_icon(self, idx):
+ preview_path = "resources/icons/icon.png"
+
+ with self.critical_region_read():
+ glob_pattern = "**/*.*" if self.get_state(f"{idx}.include_subdirectories") else "*.*"
+ concept_path = self.get_concept_path(self.get_state(f"{idx}.path"))
+
+ if concept_path:
+ for path in pathlib.Path(concept_path).glob(glob_pattern):
+ extension = os.path.splitext(path)[1]
+ if path.is_file() and path_util.is_supported_image_extension(extension) \
+ and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"):
+ preview_path = path_util.canonical_join(concept_path, path)
+ break
+
+ image = load_image(preview_path, convert_mode="RGBA")
+ size = min(image.width, image.height)
+ image = image.crop((
+ (image.width - size) // 2,
+ (image.height - size) // 2,
+ (image.width - size) // 2 + size,
+ (image.height - size) // 2 + size,
+ ))
+ image = image.resize((150, 150), Image.Resampling.BILINEAR)
+ return image
+
+ def download_dataset(self, idx):
+ # Exception handled by WorkerPool.
+ huggingface_hub.login(token=StateModel.instance().get_state("secrets.huggingface_token"), new_session=False)
+ huggingface_hub.snapshot_download(repo_id=self.get_state(f"{idx}.path"), repo_type="dataset")
+
+ def get_preview_prompt(self, filename, show_augmentations):
+ empty_msg = "[Empty prompt]"
+ try:
+ with open(filename, "r") as f:
+ if show_augmentations:
+ lines = [line.strip() for line in f if line.strip()]
+ return random.choice(lines) if lines else empty_msg
+ content = f.read().strip()
+ return content if content else empty_msg
+ except FileNotFoundError:
+ return "File not found, please check the path"
+ except IsADirectoryError:
+ return "[Provided path is a directory, please correct the caption path]"
+ except PermissionError:
+ if platform.system() == "Windows":
+ return "[Permission denied, please check the file permissions or Windows Defender settings]"
+ else:
+ return "[Permission denied, please check the file permissions]"
+ except UnicodeDecodeError:
+ return "[Invalid file encoding. This should not happen, please report this issue]"
+
+ def get_concept_stats(self, idx, advanced_checks, wait_time=60):
+ path, include_subdirectories, concept = self.bulk_read(f"{idx}.path", f"{idx}.include_subdirectories", str(idx))
+
+ # get_concept_path also resolves locally cached Hugging Face datasets, so the path is not required to be a plain directory here.
+ concept_path = self.get_concept_path(path) if path else None
+ if not concept_path:
+ self.log("error", f"Unable to get statistics for invalid concept path: {path}")
+ return
+
+ start_time = time.perf_counter()
+ self.cancel_scan_flag.clear()
+ subfolders = [concept_path]
+
+ stats_dict = concept_stats.init_concept_stats(advanced_checks)
+ for path in subfolders:
+ if self.cancel_scan_flag.is_set() or time.perf_counter() - start_time > wait_time:
+ break
+ stats_dict = concept_stats.folder_scan(path, stats_dict, advanced_checks, concept, start_time, wait_time, self.cancel_scan_flag)
+ if include_subdirectories and not self.cancel_scan_flag.is_set(): # Add all subfolders of the current directory to the scan queue.
+ subfolders.extend([f for f in os.scandir(path) if f.is_dir()])
+
+ self.set_state(f"{idx}.concept_stats", stats_dict)
+ self.cancel_scan_flag.clear()
+
+ def pretty_print_stats(self, idx):
+ concept_stats = self.get_state(f"{idx}.concept_stats")
+ formatted_stats = {}
+
+ if not concept_stats:
+ for k in ["file_size", "processing_time", "dir_count", "image_count", "image_count_mask", "image_count_caption",
+ "video_count", "video_count_caption", "mask_count", "mask_count_unpaired", "caption_count",
+ "unpaired_captions", "max_pixels", "avg_pixels", "min_pixels", "length_max", "length_avg",
+ "length_min", "fps_max", "fps_avg", "fps_min", "caption_max", "caption_avg", "caption_min", "small_bucket"
+ ]:
+ formatted_stats[k] = "-"
+ formatted_stats["aspect_buckets"] = {}
+ return formatted_stats
+
+ # File size.
+ formatted_stats["file_size"] = str(int(concept_stats["file_size"] / 1048576)) + " MB"
+ formatted_stats["processing_time"] = str(round(concept_stats["processing_time"], 2)) + " s"
+
+ # Directory count.
+ formatted_stats["dir_count"] = concept_stats["directory_count"]
+
+ # Image count.
+ formatted_stats["image_count"] = concept_stats["image_count"]
+ formatted_stats["image_count_mask"] = concept_stats["image_with_mask_count"]
+ formatted_stats["image_count_caption"] = concept_stats["image_with_caption_count"]
+
+ # Video count.
+ formatted_stats["video_count"] = concept_stats["video_count"]
+ formatted_stats["video_count_caption"] = concept_stats["video_with_caption_count"]
+
+ # Mask count.
+ formatted_stats["mask_count"] = concept_stats["mask_count"]
+ formatted_stats["mask_count_unpaired"] = concept_stats["unpaired_masks"]
+
+ # Caption count.
+ if "subcaption_count" in concept_stats and concept_stats["subcaption_count"] > 0:
+ formatted_stats["caption_count"] = f'{concept_stats["caption_count"]} ({concept_stats["subcaption_count"]})'
+ else:
+ formatted_stats["caption_count"] = concept_stats["caption_count"]
+ formatted_stats["unpaired_captions"] = concept_stats["unpaired_captions"]
+
+ # Resolution info.
+ max_pixels = concept_stats["max_pixels"]
+ avg_pixels = concept_stats["avg_pixels"]
+ min_pixels = concept_stats["min_pixels"]
+
+ if any(isinstance(x, str) for x in [max_pixels, avg_pixels, min_pixels]) or concept_stats["image_count"] == 0: # will be str if adv stats were not taken
+ formatted_stats["max_pixels"] = "-"
+ formatted_stats["avg_pixels"] = "-"
+ formatted_stats["min_pixels"] = "-"
+ else:
+ # formatted as (#pixels/1000000) MP, width x height, \n filename
+ formatted_stats["max_pixels"] = f'{str(round(max_pixels[0] / 1000000, 2))} MP, {max_pixels[2]}\n{max_pixels[1]}'
+ formatted_stats["avg_pixels"] = f'{str(round(avg_pixels / 1000000, 2))} MP, ~{int(math.sqrt(avg_pixels))}w x {int(math.sqrt(avg_pixels))}h'
+ formatted_stats["min_pixels"] = f'{str(round(min_pixels[0] / 1000000, 2))} MP, {min_pixels[2]}\n{min_pixels[1]}'
+
+ # Video length and fps info.
+ max_length = concept_stats["max_length"]
+ avg_length = concept_stats["avg_length"]
+ min_length = concept_stats["min_length"]
+ max_fps = concept_stats["max_fps"]
+ avg_fps = concept_stats["avg_fps"]
+ min_fps = concept_stats["min_fps"]
+
+ if any(isinstance(x, str) for x in [max_length, avg_length, min_length]) or concept_stats["video_count"] == 0: # will be str if adv stats were not taken
+ formatted_stats["length_max"] = "-"
+ formatted_stats["length_avg"] = "-"
+ formatted_stats["length_min"] = "-"
+ formatted_stats["fps_max"] = "-"
+ formatted_stats["fps_avg"] = "-"
+ formatted_stats["fps_min"] = "-"
+ else:
+ # formatted as (#frames) frames \n filename
+ formatted_stats["length_max"] = f'{int(max_length[0])} frames\n{max_length[1]}'
+ formatted_stats["length_avg"] = f'{int(avg_length)} frames'
+ formatted_stats["length_min"] = f'{int(min_length[0])} frames\n{min_length[1]}'
+ # formatted as (#fps) fps \n filename
+ formatted_stats["fps_max"] = f'{int(max_fps[0])} fps\n{max_fps[1]}'
+ formatted_stats["fps_avg"] = f'{int(avg_fps)} fps'
+ formatted_stats["fps_min"] = f'{int(min_fps[0])} fps\n{min_fps[1]}'
+
+ # Caption info.
+ max_caption_length = concept_stats["max_caption_length"]
+ avg_caption_length = concept_stats["avg_caption_length"]
+ min_caption_length = concept_stats["min_caption_length"]
+
+ if any(isinstance(x, str) for x in [max_caption_length, avg_caption_length, min_caption_length]) or concept_stats["caption_count"] == 0: # will be str if adv stats were not taken
+ formatted_stats["caption_max"] = "-"
+ formatted_stats["caption_avg"] = "-"
+ formatted_stats["caption_min"] = "-"
+ else:
+ # formatted as (#chars) chars, (#words) words, \n filename
+ formatted_stats["caption_max"] = f'{max_caption_length[0]} chars, {max_caption_length[2]} words\n{max_caption_length[1]}'
+ formatted_stats["caption_avg"] = f'{int(avg_caption_length[0])} chars, {int(avg_caption_length[1])} words'
+ formatted_stats["caption_min"] = f'{min_caption_length[0]} chars, {min_caption_length[2]} words\n{min_caption_length[1]}'
+
+ # Aspect bucketing.
+ aspect_buckets = concept_stats["aspect_buckets"]
+ formatted_stats["aspect_buckets"] = aspect_buckets
+
+ if len(aspect_buckets) != 0 and max(aspect_buckets.values()) > 0: # Check that aspect_bucket data exists and is not all zero.
+ min_val = min(val for val in aspect_buckets.values() if val > 0) # Smallest nonzero bucket.
+ # Second smallest nonzero bucket; if all images share the same aspect there is none, so fall back to min_val.
+ min_val2 = min((val for val in aspect_buckets.values() if val > 0 and val != min_val), default=min_val)
+ min_aspect_buckets = {key: val for key, val in aspect_buckets.items() if val in (min_val, min_val2)}
+ min_bucket_str = ""
+ for key, val in min_aspect_buckets.items():
+ min_bucket_str += f'aspect {self.decimal_to_aspect_ratio(key)} : {val} img\n'
+ min_bucket_str = min_bucket_str.strip()
+
+ formatted_stats["small_bucket"] = min_bucket_str
+ else:
+ formatted_stats["small_bucket"] = "-"
+
+
+ return formatted_stats
+
+ def decimal_to_aspect_ratio(self, value):
+ # Find the closest fraction to the decimal aspect value and render it in "a:b" form,
+ # e.g. 0.75 becomes Fraction(3, 4) and is rendered as "4:3".
+ aspect_fraction = fractions.Fraction(value).limit_denominator(16)
+ aspect_string = f'{aspect_fraction.denominator}:{aspect_fraction.numerator}'
+ return aspect_string
+
+ def get_image(self, idx, image_id, show_augmentations=False):
+ with self.critical_region_read():
+ preview_image_path = "resources/icons/icon.png"
+ file_index = -1
+ glob_pattern = "**/*.*" if self.get_state(f"{idx}.include_subdirectories") else "*.*"
+
+ concept_path = self.get_concept_path(self.get_state(f"{idx}.path"))
+ if concept_path:
+ for path in pathlib.Path(concept_path).glob(glob_pattern):
+ extension = os.path.splitext(path)[1]
+ if path.is_file() and path_util.is_supported_image_extension(extension) \
+ and not path.name.endswith("-masklabel.png") and not path.name.endswith("-condlabel.png"):
+ preview_image_path = path_util.canonical_join(concept_path, path)
+ file_index += 1
+ if file_index == image_id:
+ break
+
+ image = load_image(preview_image_path, 'RGB')
+ image_tensor = functional.to_tensor(image)
+
+ splitext = os.path.splitext(preview_image_path)
+ preview_mask_path = path_util.canonical_join(splitext[0] + "-masklabel.png")
+ if not os.path.isfile(preview_mask_path):
+ preview_mask_path = None
+
+ if preview_mask_path:
+ mask = Image.open(preview_mask_path).convert("L")
+ mask_tensor = functional.to_tensor(mask)
+ else:
+ mask_tensor = torch.ones((1, image_tensor.shape[1], image_tensor.shape[2]))
+
+ source = self.get_state(f"{idx}.text.prompt_source")
+ preview_p = pathlib.Path(preview_image_path)
+ if source == "filename":
+ prompt_output = preview_p.stem or "[Empty prompt]"
+ else:
+ file_map = {
+ "sample": preview_p.with_suffix(".txt"),
+ "concept": pathlib.Path(self.get_state(f"{idx}.text.prompt_path")) if self.get_state(f"{idx}.text.prompt_path") else None,
+ }
+ file_path = file_map.get(source)
+ prompt_output = self.get_preview_prompt(str(file_path), show_augmentations) if file_path else "[Empty prompt]"
+
+ modules = []
+ if show_augmentations:
+ input_module = InputPipelineModule({
+ 'true': True,
+ 'image': image_tensor,
+ 'mask': mask_tensor,
+ 'enable_random_flip': self.get_state(f"{idx}.image.enable_random_flip"),
+ 'enable_fixed_flip': self.get_state(f"{idx}.image.enable_fixed_flip"),
+ 'enable_random_rotate': self.get_state(f"{idx}.image.enable_random_rotate"),
+ 'enable_fixed_rotate': self.get_state(f"{idx}.image.enable_fixed_rotate"),
+ 'random_rotate_max_angle': self.get_state(f"{idx}.image.random_rotate_max_angle"),
+ 'enable_random_brightness': self.get_state(f"{idx}.image.enable_random_brightness"),
+ 'enable_fixed_brightness': self.get_state(f"{idx}.image.enable_fixed_brightness"),
+ 'random_brightness_max_strength': self.get_state(f"{idx}.image.random_brightness_max_strength"),
+ 'enable_random_contrast': self.get_state(f"{idx}.image.enable_random_contrast"),
+ 'enable_fixed_contrast': self.get_state(f"{idx}.image.enable_fixed_contrast"),
+ 'random_contrast_max_strength': self.get_state(f"{idx}.image.random_contrast_max_strength"),
+ 'enable_random_saturation': self.get_state(f"{idx}.image.enable_random_saturation"),
+ 'enable_fixed_saturation': self.get_state(f"{idx}.image.enable_fixed_saturation"),
+ 'random_saturation_max_strength': self.get_state(f"{idx}.image.random_saturation_max_strength"),
+ 'enable_random_hue': self.get_state(f"{idx}.image.enable_random_hue"),
+ 'enable_fixed_hue': self.get_state(f"{idx}.image.enable_fixed_hue"),
+ 'random_hue_max_strength': self.get_state(f"{idx}.image.random_hue_max_strength"),
+ 'enable_random_circular_mask_shrink': self.get_state(f"{idx}.image.enable_random_circular_mask_shrink"),
+ 'enable_random_mask_rotate_crop': self.get_state(f"{idx}.image.enable_random_mask_rotate_crop"),
+
+ 'prompt': prompt_output,
+ 'tag_dropout_enable': self.get_state(f"{idx}.text.tag_dropout_enable"),
+ 'tag_dropout_probability': self.get_state(f"{idx}.text.tag_dropout_probability"),
+ 'tag_dropout_mode': self.get_state(f"{idx}.text.tag_dropout_mode"),
+ 'tag_dropout_special_tags': self.get_state(f"{idx}.text.tag_dropout_special_tags"),
+ 'tag_dropout_special_tags_mode': self.get_state(f"{idx}.text.tag_dropout_special_tags_mode"),
+ 'tag_delimiter': self.get_state(f"{idx}.text.tag_delimiter"),
+ 'keep_tags_count': self.get_state(f"{idx}.text.keep_tags_count"),
+ 'tag_dropout_special_tags_regex': self.get_state(f"{idx}.text.tag_dropout_special_tags_regex"),
+ 'caps_randomize_enable': self.get_state(f"{idx}.text.caps_randomize_enable"),
+ 'caps_randomize_probability': self.get_state(f"{idx}.text.caps_randomize_probability"),
+ 'caps_randomize_mode': self.get_state(f"{idx}.text.caps_randomize_mode"),
+ 'caps_randomize_lowercase': self.get_state(f"{idx}.text.caps_randomize_lowercase"),
+ 'enable_tag_shuffling': self.get_state(f"{idx}.text.enable_tag_shuffling"),
+ })
+
+ circular_mask_shrink = RandomCircularMaskShrink(mask_name='mask', shrink_probability=1.0,
+ shrink_factor_min=0.2, shrink_factor_max=1.0,
+ enabled_in_name='enable_random_circular_mask_shrink')
+ random_mask_rotate_crop = RandomMaskRotateCrop(mask_name='mask', additional_names=['image'], min_size=512,
+ min_padding_percent=10, max_padding_percent=30,
+ max_rotate_angle=20,
+ enabled_in_name='enable_random_mask_rotate_crop')
+ random_flip = RandomFlip(names=['image', 'mask'], enabled_in_name='enable_random_flip',
+ fixed_enabled_in_name='enable_fixed_flip')
+ random_rotate = RandomRotate(names=['image', 'mask'], enabled_in_name='enable_random_rotate',
+ fixed_enabled_in_name='enable_fixed_rotate',
+ max_angle_in_name='random_rotate_max_angle')
+ random_brightness = RandomBrightness(names=['image'], enabled_in_name='enable_random_brightness',
+ fixed_enabled_in_name='enable_fixed_brightness',
+ max_strength_in_name='random_brightness_max_strength')
+ random_contrast = RandomContrast(names=['image'], enabled_in_name='enable_random_contrast',
+ fixed_enabled_in_name='enable_fixed_contrast',
+ max_strength_in_name='random_contrast_max_strength')
+ random_saturation = RandomSaturation(names=['image'], enabled_in_name='enable_random_saturation',
+ fixed_enabled_in_name='enable_fixed_saturation',
+ max_strength_in_name='random_saturation_max_strength')
+ random_hue = RandomHue(names=['image'], enabled_in_name='enable_random_hue',
+ fixed_enabled_in_name='enable_fixed_hue',
+ max_strength_in_name='random_hue_max_strength')
+ drop_tags = DropTags(text_in_name='prompt', enabled_in_name='tag_dropout_enable',
+ probability_in_name='tag_dropout_probability', dropout_mode_in_name='tag_dropout_mode',
+ special_tags_in_name='tag_dropout_special_tags',
+ special_tag_mode_in_name='tag_dropout_special_tags_mode',
+ delimiter_in_name='tag_delimiter',
+ keep_tags_count_in_name='keep_tags_count', text_out_name='prompt',
+ regex_enabled_in_name='tag_dropout_special_tags_regex')
+ caps_randomize = CapitalizeTags(text_in_name='prompt', enabled_in_name='caps_randomize_enable',
+ probability_in_name='caps_randomize_probability',
+ capitalize_mode_in_name='caps_randomize_mode',
+ delimiter_in_name='tag_delimiter',
+ convert_lowercase_in_name='caps_randomize_lowercase',
+ text_out_name='prompt')
+ shuffle_tags = ShuffleTags(text_in_name='prompt', enabled_in_name='enable_tag_shuffling',
+ delimiter_in_name='tag_delimiter', keep_tags_count_in_name='keep_tags_count',
+ text_out_name='prompt')
+ output_module = OutputPipelineModule(['image', 'mask', 'prompt'])
+
+ modules = [
+ input_module,
+ circular_mask_shrink,
+ random_mask_rotate_crop,
+ random_flip,
+ random_rotate,
+ random_brightness,
+ random_contrast,
+ random_saturation,
+ random_hue,
+ drop_tags,
+ caps_randomize,
+ shuffle_tags,
+ output_module,
+ ]
+
+ pipeline = LoadingPipeline(
+ device=torch.device('cpu'),
+ modules=modules,
+ batch_size=1,
+ seed=random.randint(0, 2 ** 30),
+ state=None,
+ initial_epoch=0,
+ initial_index=0,
+ )
+
+ data = next(pipeline)
+ image_tensor = data['image']
+ mask_tensor = data['mask']
+ prompt_output = data['prompt']
+
+ filename_output = os.path.basename(preview_image_path)
+
+ mask_tensor = torch.clamp(mask_tensor, 0.3, 1)
+ image_tensor = image_tensor * mask_tensor
+
+ image = functional.to_pil_image(image_tensor)
+
+ image.thumbnail((300, 300))
+
+ return image, filename_output, prompt_output
diff --git a/modules/ui/models/ConvertModel.py b/modules/ui/models/ConvertModel.py
new file mode 100644
index 000000000..133f4f299
--- /dev/null
+++ b/modules/ui/models/ConvertModel.py
@@ -0,0 +1,64 @@
+import traceback
+from uuid import uuid4
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util import create
+from modules.util.args.ConvertModelArgs import ConvertModelArgs
+from modules.util.enum.TrainingMethod import TrainingMethod
+from modules.util.ModelNames import EmbeddingName, ModelNames
+from modules.util.torch_util import torch_gc
+
+
+class ConvertModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__(ConvertModelArgs.default_values())
+
+ def convert_model(self):
+ cfg = self.bulk_read("model_type", "training_method", "input_name",
+ "output_model_destination", "output_model_format", "output_dtype",
+ as_dict=True)
+
+ try:
+ model_loader = create.create_model_loader(
+ model_type=cfg["model_type"],
+ training_method=cfg["training_method"]
+ )
+ model_saver = create.create_model_saver(
+ model_type=cfg["model_type"],
+ training_method=cfg["training_method"]
+ )
+
+ print("Loading model " + cfg["input_name"])
+ if cfg["training_method"] in [TrainingMethod.FINE_TUNE]:
+ model = model_loader.load(
+ model_type=cfg["model_type"],
+ model_names=ModelNames(
+ base_model=cfg["input_name"],
+ ),
+ weight_dtypes=self.config.weight_dtypes(),
+ )
+ elif cfg["training_method"] in [TrainingMethod.LORA, TrainingMethod.EMBEDDING]:
+ model = model_loader.load(
+ model_type=cfg["model_type"],
+ model_names=ModelNames(
+ lora=cfg["input_name"],
+ embedding=EmbeddingName(str(uuid4()), cfg["input_name"]),
+ ),
+ weight_dtypes=self.config.weight_dtypes(),
+ )
+ else:
+ raise Exception("could not load model: " + cfg["input_name"])
+
+ self.log("info", "Saving model " + cfg["output_model_destination"])
+ model_saver.save(
+ model=model,
+ model_type=cfg["model_type"],
+ output_model_format=cfg["output_model_format"],
+ output_model_destination=cfg["output_model_destination"],
+ dtype=cfg["output_dtype"].torch_dtype(),
+ )
+ self.log("info", "Model converted")
+ except Exception:
+ self.log("critical", traceback.format_exc())
+ finally:
+ torch_gc()
diff --git a/modules/ui/models/DatasetModel.py b/modules/ui/models/DatasetModel.py
new file mode 100644
index 000000000..f827c06f9
--- /dev/null
+++ b/modules/ui/models/DatasetModel.py
@@ -0,0 +1,241 @@
+import os
+import re
+from pathlib import Path
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util.config.BaseConfig import BaseConfig
+from modules.util.enum.CaptionFilter import CaptionFilter
+from modules.util.enum.FileFilter import FileFilter
+
+from PIL import Image
+
+
+class DatasetConfig(BaseConfig):
+ path: str
+ valid: bool # Is a valid dataset path?
+ include_subdirectories: bool
+ file_filter: str
+ file_filter_mode: FileFilter
+ caption_filter: str
+ caption_filter_mode: CaptionFilter
+ files: list
+ filter_mask_exists: bool
+ filter_caption_exists: bool
+
+ @staticmethod
+ def default_values():
+ data = []
+
+ # name, default value, data type, nullable
+ data.append(("path", None, str, True))
+ data.append(("valid", False, bool, False))
+ data.append(("include_subdirectories", False, bool, False))
+ data.append(("file_filter", "", str, False))
+ data.append(("file_filter_mode", FileFilter.FILE, FileFilter, False))
+ data.append(("caption_filter", "", str, False))
+ data.append(("caption_filter_mode", CaptionFilter.MATCHES, CaptionFilter, False))
+ data.append(("files", [], list, False))
+ data.append(("filter_mask_exists", False, bool, False))
+ data.append(("filter_caption_exists", False, bool, False))
+
+ return DatasetConfig(data)
+
+
+
+class DatasetModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__(DatasetConfig.default_values())
+
+ def scan(self):
+ path, include_subdirs = self.bulk_read("path", "include_subdirectories")
+
+ if path is not None:
+ root = Path(path)
+
+ root_str = str(root.resolve())
+ root_len = len(root_str) + 1 # ".../dir" + "/"
+ stack = [root_str]
+ results = []
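+ # Iterative scandir-based traversal; avoids recursion and the heavier Path.rglob machinery.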
+
+ while stack:
+ top = stack.pop()
+ with os.scandir(top) as it:
+ for entry in it:
+ if entry.is_dir(follow_symlinks=False):
+ if include_subdirs:
+ stack.append(entry.path)
+ continue
+ name = entry.name
+ if self.__is_supported(name):
+ # strip root and back-slashes only once
+ results.append(entry.path[root_len:].replace("\\", "/"))
+ self.set_state("files", sorted(results, key=lambda x: self.natural_sort_key(x)))
+
+ def get_filtered_files(self):
+ (path, filtered, file_filter, file_filter_mode,
+ caption_filter, caption_filter_mode, filter_mask_exists, filter_caption_exists) = self.bulk_read("path", "files", "file_filter",
+ "file_filter_mode", "caption_filter", "caption_filter_mode",
+ "filter_mask_exists", "filter_caption_exists")
+ file_filter = file_filter.strip()
+ caption_filter = caption_filter.strip()
+
+ filtered = [self.get_caption_mask_exist(path, file) for file in filtered]
+
+ if filter_caption_exists:
+ filtered = [f for f in filtered if f[1]]
+
+ if filter_mask_exists:
+ filtered = [f for f in filtered if f[2]]
+
+ if file_filter == "" and caption_filter == "":
+ return filtered
+
+ if file_filter != "":
+ try:
+ pattern = re.compile(re.escape(file_filter), re.IGNORECASE)
+
+ if file_filter_mode == FileFilter.FILE:
+ filtered = [
+ f for f in filtered if pattern.search(Path(f[0]).name)
+ ]
+ elif file_filter_mode == FileFilter.PATH:
+ # TODO: carried over from the original implementation; this behaves the same as FileFilter.BOTH,
+ # because matching the full relative path already covers matching the bare filename.
+ filtered = [
+ f for f in filtered if pattern.search(f[0]) # f[0] is the relative path string
+ ]
+ else: # Both
+ filtered = [
+ f
+ for f in filtered
+ if pattern.search(f[0]) or pattern.search(Path(f[0]).name)
+ ]
+ except re.error:
+ pass
+
+ if caption_filter != "" and path is not None:
+ try:
+ caption_files = []
+ for file in filtered: # Iterate over strings
+ file_path_str = file[0]
+ full_path = Path(path) / file_path_str # file_path_str is relative
+ caption_path = full_path.with_suffix(".txt")
+
+ if not caption_path.exists():
+ continue
+ try:
+ caption_content = caption_path.read_text(
+ encoding="utf-8"
+ ).strip()
+ match = False
+ if caption_filter_mode == CaptionFilter.CONTAINS:
+ if caption_filter.lower() in caption_content.lower():
+ match = True
+ elif caption_filter_mode == CaptionFilter.MATCHES:
+ if caption_filter.lower() == caption_content.lower():
+ match = True
+ elif caption_filter_mode == CaptionFilter.EXCLUDES:
+ if caption_filter.lower() not in caption_content.lower():
+ match = True
+ elif caption_filter_mode == CaptionFilter.REGEX:
+ pat = re.compile(caption_filter, re.IGNORECASE)
+ if pat.search(caption_content):
+ match = True
+ if match:
+ caption_files.append(file)
+ except Exception:
+ continue
+ filtered = list(caption_files)
+ except Exception as e:
+ self.log("error", f"Error applying caption filter: {e}")
+
+ return filtered
+
+ @staticmethod
+ def natural_sort_key(s):
+ """Sort strings with embedded numbers in natural order."""
+
+ # Split the input string into text and numeric parts
+ def convert(text):
+ return int(text) if text.isdigit() else text.lower()
+
+ return [convert(c) for c in re.split(r"(\d+)", s)]
+
+ def get_sample(self, path):
+ image = None
+ caption = None
+ mask = None
+
+ basepath = self.get_state("path")
+ image_path = Path(basepath) / path
+ mask_path = image_path.with_name(f"{image_path.stem}-masklabel.png")
+ caption_path = image_path.with_suffix(".txt")
+
+ if os.path.exists(image_path):
+ image = Image.open(image_path).convert("RGB")
+
+ if os.path.exists(mask_path):
+ mask = Image.open(mask_path).convert("L")
+
+ if os.path.exists(caption_path):
+ caption = caption_path.read_text(encoding="utf-8").strip()
+
+ return image, mask, caption
+
+
+ def get_mask_path(self, path):
+ basepath = self.get_state("path")
+ image_path = Path(basepath) / path
+ mask_path = image_path.with_name(f"{image_path.stem}-masklabel.png")
+
+ return mask_path, os.path.exists(mask_path)
+
+
+ def save_caption(self, path, caption):
+ basepath = self.get_state("path")
+ image_path = Path(basepath) / path
+ caption_path = image_path.with_suffix(".txt")
+ caption_path.write_text(caption.strip(), encoding="utf-8")
+
+ def delete_caption(self, path):
+ basepath = self.get_state("path")
+ image_path = Path(basepath) / path
+ caption_path = image_path.with_suffix(".txt")
+ caption_path.unlink(missing_ok=True)
+
+ def delete_sample(self, path):
+ basepath = self.get_state("path")
+ image_path = Path(basepath) / path
+ mask_path = image_path.with_name(f"{image_path.stem}-masklabel.png")
+ caption_path = image_path.with_suffix(".txt")
+
+ image_path.unlink(missing_ok=True)
+ mask_path.unlink(missing_ok=True)
+ caption_path.unlink(missing_ok=True)
+
+ with self.critical_region_write():
+ if path in self.config.files:
+ self.config.files.remove(path)
+
+ @staticmethod
+ def get_caption_mask_exist(path, file):
+ full_path = Path(path) / file
+ caption_path = full_path.with_suffix(".txt")
+ mask_path = full_path.with_name(f"{full_path.stem}-masklabel.png")
+
+ return str(file), caption_path.exists(), mask_path.exists()
+
+
+ def __is_supported(self, filename):
+ """
+ 6-10× faster than the original:
+ * No Path() construction
+ * No lower() for every file (only for the slice that matters)
+ * One hash-lookup, one endswith, no branches
+ """
+ dot = filename.rfind('.')
+ if dot == -1:
+ return False
+
+ # Check if the stem (filename without extension) ends with the mask suffix
+ if filename[:dot].endswith("-masklabel"):
+ return False
+
+ ext = filename[dot:].lower() # slice, not copy of whole string
+ return ext in {'.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.webp', '.jxl'}
diff --git a/modules/ui/models/MaskHistoryModel.py b/modules/ui/models/MaskHistoryModel.py
new file mode 100644
index 000000000..a4d9b0e55
--- /dev/null
+++ b/modules/ui/models/MaskHistoryModel.py
@@ -0,0 +1,111 @@
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util.config.BaseConfig import BaseConfig
+
+import cv2
+import numpy as np
+
+
+class MaskHistoryConfig(BaseConfig):
+ buffer: list
+ ptr: int
+ current_mask: np.ndarray
+ original_mask: np.ndarray
+ width: int
+ height: int
+
+
+ @staticmethod
+ def default_values():
+ data = []
+
+ # name, default value, data type, nullable
+ data.append(("buffer", [], list, False))
+ data.append(("ptr", 0, int, False))
+ data.append(("original_mask", None, np.ndarray, True))
+ data.append(("current_mask", None, np.ndarray, True))
+ data.append(("width", 0, int, False))
+ data.append(("height", 0, int, False))
+
+ return MaskHistoryConfig(data)
+
+class MaskHistoryModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__(MaskHistoryConfig.default_values())
+ self.draw = None
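+
+ # Typical flow (illustrative): load_mask(mask); paint_stroke(...); commit(); then undo()/redo() to step through history.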
+
+ def load_mask(self, mask):
+ with self.critical_region_write():
+ self.config.buffer = []
+ self.config.ptr = 0
+ self.config.original_mask = mask
+ self.config.current_mask = mask.copy()
+ self.config.height, self.config.width = mask.shape # np arrays are (rows, cols), i.e. (height, width)
+ packed, _, _ = self.__pack(self.config.current_mask)
+ self.config.buffer.append(packed)
+
+ def __pack(self, mask):
+ # Encode the boolean mask into np.uint8 via np.packbits (flattened, zero-padded at the end),
+ # e.g. a 100x100 mask packs into 1250 bytes instead of 10000. (np.bool was removed in NumPy 1.24; use the builtin bool.)
+ h, w = mask.shape
+ return np.packbits(mask.astype(bool)), h, w
+
+ def __unpack(self, mask, h, w):
+ # Unpack np.uint8 back into booleans, drop the zero-padding at the end, and restore the original shape.
+ return np.unpackbits(mask)[:h * w].reshape((h, w)).astype(np.uint8).copy()
+
+ def undo(self):
+ with self.critical_region_write():
+ if self.config.ptr > 0:
+ self.config.ptr -= 1
+ self.config.current_mask = self.__unpack(self.config.buffer[self.config.ptr], self.config.height, self.config.width)
+
+ def redo(self):
+ with self.critical_region_write():
+ if self.config.ptr < len(self.config.buffer) - 1:
+ self.config.ptr += 1
+ self.config.current_mask = self.__unpack(self.config.buffer[self.config.ptr], self.config.height, self.config.width)
+
+ def commit(self):
+ with self.critical_region_write():
+ if self.config.current_mask is not None:
+ if self.config.ptr < len(self.config.buffer) - 1:
+ self.config.buffer = self.config.buffer[:self.config.ptr + 1] # Invalidate the future before adding a new state.
+
+ packed, _, _ = self.__pack(self.config.current_mask)
+ self.config.buffer.append(packed)
+ self.config.ptr = len(self.config.buffer) - 1
+
+ def clear_history(self):
+ with self.critical_region_write():
+ if self.config.original_mask is not None:
+ self.config.buffer = []
+ self.config.ptr = 0
+ self.config.current_mask = self.config.original_mask.copy()
+ packed, _, _ = self.__pack(self.config.current_mask)
+ self.config.buffer.append(packed)
+
+ def delete_mask(self):
+ with self.critical_region_write():
+ self.config.current_mask = np.ones_like(self.config.current_mask) * 255
+
+
+ def fill(self, x, y, color):
+ with self.critical_region_write():
+ if self.config.current_mask is not None:
+ cv2.floodFill(self.config.current_mask, None, (x, y), color)
+
+ def paint_stroke(self, x0, y0, x1, y1, radius, color, commit=False):
+ with self.critical_region_write():
+ if self.config.current_mask is not None:
+ # Draw line between points
+ line_width = 2 * radius + 1
+ cv2.line(self.config.current_mask, (x0, y0), (x1, y1), color, line_width)
+
+ # Draw circle at start point
+ cv2.circle(self.config.current_mask, (x0, y0), radius, color, -1)
+
+ # Draw circle at end point if different from start
+ if (x0, y0) != (x1, y1):
+ cv2.circle(self.config.current_mask, (x1, y1), radius, color, -1)
+
+ if commit:
+ self.commit()
diff --git a/modules/ui/models/MaskModel.py b/modules/ui/models/MaskModel.py
new file mode 100644
index 000000000..046e35889
--- /dev/null
+++ b/modules/ui/models/MaskModel.py
@@ -0,0 +1,78 @@
+from modules.module.ClipSegModel import ClipSegModel
+from modules.module.MaskByColor import MaskByColor
+from modules.module.RembgHumanModel import RembgHumanModel
+from modules.module.RembgModel import RembgModel
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util.enum.GenerateMasksModel import GenerateMasksAction, GenerateMasksModel
+from modules.util.torch_util import default_device, torch_gc
+
+import torch
+
+
+class MaskModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__({
+ "model": GenerateMasksModel.CLIPSEG,
+ "path": "",
+ "prompt": "",
+ "mode": GenerateMasksAction.REPLACE,
+ "alpha": 1.0,
+ "threshold": 0.3,
+ "smooth": 5,
+ "expand": 10,
+ "include_subdirectories": False,
+ })
+
+ self.masking_model = None
+
+ def create_masks(self, progress_fn=None):
+ with self.critical_region_read():
+ self.__load_masking_model(self.get_state("model"))
+
+ self.masking_model.mask_folder(
+ sample_dir=self.get_state("path"),
+ prompts=[self.get_state("prompt")],
+ mode=str(self.get_state("mode")).lower(),
+ alpha=float(self.get_state("alpha")),
+ threshold=float(self.get_state("threshold")),
+ smooth_pixels=int(self.get_state("smooth")),
+ expand_pixels=int(self.get_state("expand")),
+ include_subdirectories=self.get_state("include_subdirectories"),
+ progress_callback=self.__wrap_progress(progress_fn),
+ )
+
+ def __wrap_progress(self, fn):
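+ # Adapt the (value, max_value) callback used by mask_folder to the dict-based progress_fn.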
+ def f(value, max_value):
+ if fn is not None:
+ fn({"value": value, "max_value": max_value})
+ return f
+
+ def __load_masking_model(self, model):
+ if model == GenerateMasksModel.CLIPSEG:
+ if self.masking_model is None or not isinstance(self.masking_model, ClipSegModel):
+ self.log("info", "Loading ClipSeg model, this may take a while")
+ self.release_model()
+ self.masking_model = ClipSegModel(default_device, torch.float32)
+ elif model == GenerateMasksModel.REMBG:
+ if self.masking_model is None or not isinstance(self.masking_model, RembgModel):
+ self.log("info", "Loading Rembg model, this may take a while")
+ self.release_model()
+ self.masking_model = RembgModel(default_device, torch.float32)
+ elif model == GenerateMasksModel.REMBG_HUMAN:
+ if self.masking_model is None or not isinstance(self.masking_model, RembgHumanModel):
+ self.log("info", "Loading Rembg-Human model, this may take a while")
+ self.release_model()
+ self.masking_model = RembgHumanModel(default_device, torch.float32)
+ elif model == GenerateMasksModel.COLOR:
+ if self.masking_model is None or not isinstance(self.masking_model, MaskByColor):
+ self.release_model()
+ self.masking_model = MaskByColor(default_device, torch.float32)
+
+ def release_model(self):
+ """Release all models from VRAM"""
+ freed = False
+ if self.masking_model is not None:
+ self.masking_model = None
+ freed = True
+ if freed:
+ torch_gc()
diff --git a/modules/ui/models/SampleModel.py b/modules/ui/models/SampleModel.py
new file mode 100644
index 000000000..52afd62ab
--- /dev/null
+++ b/modules/ui/models/SampleModel.py
@@ -0,0 +1,77 @@
+import copy
+import json
+import os
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.ui.models.StateModel import StateModel
+from modules.util import path_util
+from modules.util.config.SampleConfig import SampleConfig
+from modules.util.path_util import write_json_atomic
+
+
+class SampleModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__([])
+
+ def __len__(self):
+ with self.critical_region_read():
+ return len(self.config)
+
+ def get_default_sample(self):
+ return SampleConfig.default_values().to_dict()
+
+ def create_new_sample(self):
+ with self.critical_region_write():
+ smp_cfg = SampleConfig.default_values()
+ self.config.append(smp_cfg)
+
+ def clone_sample(self, idx):
+ with self.critical_region_write():
+ new_element = copy.deepcopy(self.config[idx])
+ self.config.append(new_element)
+
+ def delete_sample(self, idx):
+ with self.critical_region_write():
+ self.config.pop(idx)
+
+ def toggle_samples(self):
+ some_enabled = self.some_samples_enabled()
+
+ with self.critical_region_write():
+ for smp in self.config:
+ smp.enabled = not some_enabled
+
+ def some_samples_enabled(self):
+ with self.critical_region_read():
+ out = False
+ for smp in self.config:
+ out |= smp.enabled
+ return out
+
+ def save_config(self, path="training_samples"):
+ if not os.path.exists(path):
+ os.mkdir(path)
+
+ config_path = StateModel.instance().get_state("sample_definition_file_name")
+ with self.critical_region_read():
+ write_json_atomic(config_path, [element.to_dict() for element in self.config])
+
+ def load_config(self, filename, path="training_samples"):
+ if not os.path.exists(path):
+ os.mkdir(path)
+
+ if filename == "":
+ filename = "samples"
+
+ config_file = path_util.canonical_join(path, f"{filename}.json")
+ StateModel.instance().set_state("sample_definition_file_name", config_file)
+
+ with self.critical_region_write():
+ self.config = []
+
+ if os.path.exists(config_file):
+ with open(config_file, "r") as f:
+ loaded_config_json = json.load(f)
+ for element_json in loaded_config_json:
+ element = SampleConfig.default_values().from_dict(element_json)
+ self.config.append(element)
diff --git a/modules/ui/models/SamplingModel.py b/modules/ui/models/SamplingModel.py
new file mode 100644
index 000000000..042ce0143
--- /dev/null
+++ b/modules/ui/models/SamplingModel.py
@@ -0,0 +1,137 @@
+import copy
+import os
+
+from modules.model.BaseModel import BaseModel
+from modules.modelSampler.BaseModelSampler import ModelSamplerOutput
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.ui.models.StateModel import StateModel
+from modules.ui.models.TrainingModel import TrainingModel
+from modules.util import create
+from modules.util.config.SampleConfig import SampleConfig
+from modules.util.config.TrainConfig import TrainConfig
+from modules.util.enum.EMAMode import EMAMode
+from modules.util.enum.FileType import FileType
+from modules.util.enum.TrainingMethod import TrainingMethod
+from modules.util.time_util import get_string_timestamp
+
+import torch
+
+
+class SamplingModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__(SampleConfig.default_values())
+ self.model = None
+ self.progress_fn = None
+ self.model_sampler = None
+
+ def __update_preview(self, sampler_output: ModelSamplerOutput):
+ if sampler_output.file_type == FileType.IMAGE:
+ image = sampler_output.data
+ if self.progress_fn is not None:
+ self.progress_fn({"data": image})
+
+ def __update_progress(self, progress, max_progress):
+ if self.progress_fn is not None:
+ self.progress_fn({"value": progress, "max_value": max_progress})
+
+ def __load_model(self, train_config) -> BaseModel:
+ model_loader = create.create_model_loader(
+ model_type=train_config.model_type,
+ training_method=train_config.training_method,
+ )
+
+ model_setup = create.create_model_setup(
+ model_type=train_config.model_type,
+ train_device=torch.device(train_config.train_device),
+ temp_device=torch.device(train_config.temp_device),
+ training_method=train_config.training_method,
+ )
+
+ model_names = train_config.model_names()
+ if train_config.continue_last_backup:
+ last_backup_path = train_config.get_last_backup_path()
+
+ if last_backup_path:
+ if train_config.training_method == TrainingMethod.LORA:
+ model_names.lora = last_backup_path
+ elif train_config.training_method == TrainingMethod.EMBEDDING:
+ model_names.embedding.model_name = last_backup_path
+ else: # fine-tunes
+ model_names.base_model = last_backup_path
+
+ self.log("info", f"Loading from backup '{last_backup_path}'...")
+ else:
+ self.log("info", "No backup found, loading without backup...")
+
+ model = model_loader.load(
+ model_type=train_config.model_type,
+ model_names=model_names,
+ weight_dtypes=train_config.weight_dtypes(),
+ )
+ model.train_config = train_config
+
+ model_setup.setup_optimizations(model, train_config)
+ model_setup.setup_train_device(model, train_config)
+ model_setup.setup_model(model, train_config)
+ model.to(torch.device(train_config.temp_device))
+
+ return model
+
+ def __create_sampler(self, model, train_config):
+ return create.create_model_sampler(
+ train_device=torch.device(train_config.train_device),
+ temp_device=torch.device(train_config.temp_device),
+ model=model,
+ model_type=train_config.model_type,
+ training_method=train_config.training_method,
+ )
+
+ def sample(self, progress_fn=None):
+ self.progress_fn = progress_fn
+
+ with self.critical_region_read():
+ sample = copy.deepcopy(self.config)
+
+ if TrainingModel.instance().training_commands is not None:
+ TrainingModel.instance().training_callbacks.set_on_sample_custom(self.__update_preview)
+ TrainingModel.instance().training_callbacks.set_on_update_sample_custom_progress(self.__update_progress)
+
+ TrainingModel.instance().training_commands.sample_custom(sample)
+ else:
+ with StateModel.instance().critical_region_read():
+ train_config = TrainConfig.default_values().from_dict(StateModel.instance().config.to_dict())
+
+ train_config.optimizer.optimizer = None
+ train_config.ema = EMAMode.OFF
+
+ if self.model is None:
+ # lazy initialization
+ self.model = self.__load_model(train_config)
+ self.model_sampler = self.__create_sampler(self.model, train_config)
+
+ sample.from_train_config(train_config)
+
+ sample_dir = os.path.join(
+ train_config.workspace_dir,
+ "samples",
+ "custom",
+ )
+
+ progress = self.model.train_progress
+ sample_path = os.path.join(
+ sample_dir,
+ f"{get_string_timestamp()}-training-sample-{progress.filename_string()}"
+ )
+
+ self.model.eval()
+
+ self.model_sampler.sample(
+ sample_config=sample,
+ destination=sample_path,
+ image_format=train_config.sample_image_format,
+ video_format=train_config.sample_video_format,
+ audio_format=train_config.sample_audio_format,
+ on_sample=self.__update_preview,
+ on_update_progress=self.__update_progress,
+ )
+
+ # TODO: Should self.model be garbage collected?
diff --git a/modules/ui/models/SingletonConfigModel.py b/modules/ui/models/SingletonConfigModel.py
new file mode 100644
index 000000000..74031c9b3
--- /dev/null
+++ b/modules/ui/models/SingletonConfigModel.py
@@ -0,0 +1,209 @@
+import os
+import threading
+import traceback
+from contextlib import contextmanager
+
+from modules.util import path_util
+
+
+# Base class for config models. It provides a Singleton interface and four synchronization mechanisms:
+# - with self.critical_region_read(): allows any thread free access to the "self" instance as long as nobody is writing, otherwise waits for the writer to finish.
+# - with self.critical_region_write(): waits for all readers to finish, then blocks the instance "self" for writing.
+# - with self.critical_region(): blocks all threads trying to access the "self" instance in a generic way.
+# - with self.critical_region_global(): blocks ALL model instances derived from this class.
+# Ideally only the read/write variants should be used; the other two exist to cover a few limited use cases, as they cause significant performance degradation.
+# Logging, via the self.log() method, blocks all instances, albeit using a dedicated reentrant lock.
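+#
+# A minimal usage sketch (illustrative; "MyModel" and its config dict are hypothetical):
+#
+#   class MyModel(SingletonConfigModel):
+#       def __init__(self):
+#           super().__init__({"value": 0})
+#
+#   m = MyModel.instance()
+#   with m.critical_region_read():
+#       v = m.get_state("value")   # consistent read, even with concurrent writers
+#   m.set_state("value", v + 1)    # set_state acquires the write lock internally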
+class SingletonConfigModel:
+ _instance = None
+ _frozenConfig = None
+ _is_frozen = False
+
+ # The following are reentrant locks shared across all the subclasses.
+ _global_mutex = threading.RLock() # Generic.
+ _log_mutex = threading.RLock() # Specific for logging messages.
+
+ def __init__(self, config=None):
+ self.config = config
+
+ self._mutex = threading.RLock() # Local reentrant lock.
+
+ # Reentrant read-write lock implementation based on: https://gist.github.com/icezyclon/124df594496dee71ce8455a31b1dd29f
+ self._writer_id = None
+ self._writer_count = 0
+ self._readers = {}
+ self._condition = threading.Condition(threading.RLock())
+
+ def __acquire_read_lock(self):
+ id = threading.get_ident()
+ with self._condition:
+ self._readers[id] = self._readers.get(id, 0) + 1
+
+ def __release_read_lock(self):
+ id = threading.get_ident()
+ with self._condition:
+ if id not in self._readers:
+ raise RuntimeError(f"Read lock was released while not holding it by thread {id}")
+ if self._readers[id] == 1:
+ del self._readers[id]
+ else:
+ self._readers[id] -= 1
+
+ if not self._readers:
+ self._condition.notify()
+
+ def __acquire_write_lock(self):
+ id = threading.get_ident()
+
+ self._condition.acquire()
+ if self._writer_id == id:
+ self._writer_count += 1
+ return
+
+ times_reading = self._readers.pop(id, 0)
+ while len(self._readers) > 0:
+ self._condition.wait()
+ self._writer_id = id
+ self._writer_count += 1
+ if times_reading:
+ self._readers[id] = times_reading
+
+ def __release_write_lock(self):
+        if self._writer_id != threading.get_ident():
+            raise RuntimeError(f"Thread {threading.get_ident()} released a write lock it does not hold")
+ self._writer_count -= 1
+ if self._writer_count == 0:
+ self._writer_id = None
+ self._condition.notify()
+ self._condition.release()
+
+    @contextmanager
+    def critical_region_read(self):
+        # Acquire outside the try block, so a failed acquisition does not trigger a spurious release.
+        self.__acquire_read_lock()
+        try:
+            yield
+        finally:
+            self.__release_read_lock()
+
+    @contextmanager
+    def critical_region_write(self):
+        self.__acquire_write_lock()
+        try:
+            yield
+        finally:
+            self.__release_write_lock()
+
+    @contextmanager
+    def critical_region_global(self):
+        self._global_mutex.acquire()
+        try:
+            yield
+        finally:
+            self._global_mutex.release()
+
+    @contextmanager
+    def critical_region(self):
+        self._mutex.acquire()
+        try:
+            yield
+        finally:
+            self._mutex.release()
+
+ @classmethod
+ def instance(cls):
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance
+
+    def log(self, severity, message):
+        with self._log_mutex:  # context manager releases the lock even if printing raises
+            print(f"{severity}: {message}")  # TODO: use logging.logger, logging on file, other approach?
+            # Proposed severities: "critical", "error", "warning", "debug", "info"
+            # Maybe some of them default to files, others to console, or both?
+
+    # Read a list of config variables at once, in a thread-safe fashion.
+    # Important: this method should be used at the beginning of long computations, to fetch a coherent collection
+    # of values, regardless of any concurrent writes that happen while the computation is running.
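+    # Example (illustrative; the keys shown are used by StateModel elsewhere in this PR):
+    #     ws, port = self.bulk_read("workspace_dir", "tensorboard_port")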
+ def bulk_read(self, *paths, as_dict=False):
+ with self.critical_region_read():
+ if as_dict:
+ out = {path: self.get_state(path) for path in paths}
+ else:
+ out = [self.get_state(path) for path in paths]
+ return out
+
+ # Write a list of config variables at once, in a thread-safe fashion.
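+    # Example (illustrative): self.bulk_write({"tensorboard_port": 6006, "tensorboard_expose": False})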
+ def bulk_write(self, kv_pairs):
+ with self.critical_region_write():
+ for k, v in kv_pairs.items():
+ self.set_state(k, v)
+
+ # Read a single config variable in a thread-safe fashion.
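+    # Dotted paths traverse object attributes, dict keys and list indices alike,
+    # e.g. (illustrative): self.get_state("additional_embeddings.0.train")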
+ def get_state(self, path):
+ if self.config is not None:
+ try:
+ with self.critical_region_read():
+ ref = self.config
+ if path == "":
+ return ref
+
+ for key in str(path).split("."):
+ if isinstance(ref, list):
+ ref = ref[int(key)]
+ elif isinstance(ref, dict) and key in ref:
+ ref = ref[key]
+ elif hasattr(ref, key):
+ ref = getattr(ref, key)
+ else:
+ self.log("debug", f"Key {key} not found in config")
+ return None
+ return ref
+
+ except Exception:
+ self.log("critical", traceback.format_exc())
+ return None
+
+ # Write a single config variable in a thread-safe fashion.
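+    # Example (illustrative): self.set_state("additional_embeddings.0.train", True)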
+    def set_state(self, path, value):
+        if self.config is not None:
+            with self.critical_region_write():
+                ref = self.config
+                keys = str(path).split(".")
+                for ptr in keys[:-1]:
+                    if isinstance(ref, list):
+                        ref = ref[int(ptr)]
+                    elif isinstance(ref, dict):
+                        ref = ref[ptr]
+                    elif hasattr(ref, ptr):
+                        ref = getattr(ref, ptr)
+                last = keys[-1]
+                # Preserve the element type: if the current value is a float, the new value is coerced to float.
+                if isinstance(ref, list):
+                    if isinstance(ref[int(last)], float):
+                        ref[int(last)] = float(value)
+                    else:
+                        ref[int(last)] = value
+                elif isinstance(ref, dict):
+                    if last and isinstance(ref.get(last), float):
+                        ref[last] = float(value)
+                    else:
+                        ref[last] = value
+                elif hasattr(ref, last):
+                    if isinstance(getattr(ref, last), float):
+                        setattr(ref, last, float(value))
+                    else:
+                        setattr(ref, last, value)
+                else:
+                    self.log("debug", f"Key {path} not found in config")
+
+ def load_available_config_names(self, dir="training_presets", include_default=True):
+ configs = [("", path_util.canonical_join(dir, "#.json"))] if include_default else []
+
+ if os.path.isdir(dir):
+ for path in os.listdir(dir):
+ if path != "#.json":
+ path = path_util.canonical_join(dir, path)
+ if path.endswith(".json") and os.path.isfile(path):
+ name = os.path.basename(path)
+ name = os.path.splitext(name)[0]
+ configs.append((name, path))
+ configs.sort()
+
+ return configs
diff --git a/modules/ui/models/StateModel.py b/modules/ui/models/StateModel.py
new file mode 100644
index 000000000..b36c24689
--- /dev/null
+++ b/modules/ui/models/StateModel.py
@@ -0,0 +1,194 @@
+import copy
+import faulthandler
+import json
+import os
+import subprocess
+import sys
+import traceback
+from pathlib import Path
+
+import scripts.generate_debug_report
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util import path_util
+from modules.util.config.SecretsConfig import SecretsConfig
+from modules.util.config.TrainConfig import TrainConfig, TrainEmbeddingConfig
+from modules.util.enum.CloudType import CloudType
+from modules.util.path_util import write_json_atomic
+
+from scalene import (
+ scalene_profiler, # TODO: importing Scalene sets the application locale to ANSI-C, while QT6 uses UTF-8 by default. We could change locale to C to suppress warnings, but this may cause problems with some features...
+)
+
+
+class StateModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__(TrainConfig.default_values())
+ self.is_profiling = False
+ self.tensorboard_subprocess = None
+ self.old_workspace = None
+ self.old_tensorboard_port = None
+ self.old_expose_tensorboard = None
+
+ def save_default(self):
+ with self.critical_region_read():
+ self.save_to_file("#")
+ self.__save_secrets("secrets.json")
+
+ def save_config(self, filename):
+ with open(filename, "w") as f, self.critical_region_read():
+ json.dump(self.config.to_pack_dict(secrets=False), f, indent=4)
+
+ def load_config(self, filename):
+ basename = os.path.basename(filename)
+ is_built_in_preset = basename.startswith("#") and basename != "#.json"
+
+ if os.path.exists(filename):
+ with open(filename, "r") as f:
+ loaded_dict = json.load(f)
+ default_config = TrainConfig.default_values()
+ if is_built_in_preset:
+ # always assume built-in configs are saved in the most recent version
+ loaded_dict["__version"] = default_config.config_version
+ loaded_config = default_config.from_dict(loaded_dict).to_unpacked_config()
+
+ if os.path.exists("secrets.json"):
+ with open("secrets.json", "r") as f:
+ secrets_dict = json.load(f)
+ loaded_config.secrets = SecretsConfig.default_values().from_dict(secrets_dict)
+
+ with self.critical_region_write():
+ self.config.from_dict(loaded_config.to_dict())
+
+
+ def save_to_file(self, name):
+ name = path_util.safe_filename(name)
+ path = path_util.canonical_join("training_presets", f"{name}.json")
+
+ with self.critical_region_read():
+ write_json_atomic(path, self.config.to_settings_dict(secrets=False))
+
+ return path
+
+ def __save_secrets(self, path):
+ write_json_atomic(path, self.config.secrets.to_dict())
+
+ return path
+
+ def set_scheduler_params(self, idx, key, value):
+ with self.critical_region_write():
+ if len(self.config.scheduler_params) == idx:
+ self.config.scheduler_params.append({"key": key, "value": value})
+ elif len(self.config.scheduler_params) > idx:
+ self.config.scheduler_params[idx] = {"key": key, "value": value}
+
+ def create_new_embedding(self):
+ with self.critical_region_write():
+ emb_cfg = TrainEmbeddingConfig.default_values()
+ self.config.additional_embeddings.append(emb_cfg)
+
+ def clone_embedding(self, idx):
+ with self.critical_region_write():
+ new_element = copy.deepcopy(self.config.additional_embeddings[idx])
+ new_element.uuid = self.get_random_uuid()
+
+ self.config.additional_embeddings.append(new_element)
+
+ def delete_embedding(self, idx):
+ with self.critical_region_write():
+ self.config.additional_embeddings.pop(idx)
+
+ def get_random_uuid(self):
+ return TrainEmbeddingConfig.default_values().uuid
+
+ def dump_stack(self):
+ with open('stacks.txt', 'w') as f:
+ faulthandler.dump_traceback(f)
+
+    def toggle_profiler(self):
+        if self.is_profiling:
+            scalene_profiler.stop()
+        else:
+            scalene_profiler.start()
+        self.is_profiling = not self.is_profiling  # keep the flag in sync, otherwise the profiler could never be stopped
+
+    def start_tensorboard(self):
+        ws, port, expose = self.bulk_read("workspace_dir", "tensorboard_port", "tensorboard_expose")
+
+        # Restart if ANY of the settings changed, not only if all of them did.
+        if self.old_tensorboard_port != port or self.old_workspace != ws or self.old_expose_tensorboard != expose:
+            if self.tensorboard_subprocess:
+                self.stop_tensorboard()
+
+            with self.critical_region_write():
+                self.old_tensorboard_port = port
+                self.old_workspace = ws
+                self.old_expose_tensorboard = expose
+
+ tensorboard_executable = os.path.join(os.path.dirname(sys.executable), "tensorboard")
+ tensorboard_log_dir = os.path.join(ws, "tensorboard")
+
+ os.makedirs(Path(tensorboard_log_dir).absolute(), exist_ok=True)
+
+ tensorboard_args = [
+ tensorboard_executable,
+ "--logdir",
+ tensorboard_log_dir,
+ "--port",
+ str(port),
+ "--samples_per_plugin=images=100,scalars=10000",
+ ]
+
+ if expose:
+ tensorboard_args.append("--bind_all")
+
+ try:
+ self.tensorboard_subprocess = subprocess.Popen(tensorboard_args)
+ except Exception:
+ self.tensorboard_subprocess = None
+
+ def stop_tensorboard(self):
+ if self.tensorboard_subprocess:
+ try:
+ self.tensorboard_subprocess.terminate()
+ self.tensorboard_subprocess.wait(timeout=5)
+ except subprocess.TimeoutExpired:
+ self.tensorboard_subprocess.kill()
+ except Exception:
+ pass
+ finally:
+ self.tensorboard_subprocess = None
+
+ def enable_embeddings(self):
+ add_embs = len(self.get_state("additional_embeddings"))
+ train_embs = {f"additional_embeddings.{idx}.train": True for idx in range(add_embs)}
+ self.bulk_write(train_embs)
+
+
+    def get_gpus(self):
+        gpus = []
+
+        cloud_type, key = self.bulk_read("cloud.type", "secrets.cloud.api_key")  # renamed from "type" to avoid shadowing the builtin
+
+        if cloud_type == CloudType.RUNPOD:
+            import runpod
+            runpod.api_key = key
+            gpus = runpod.get_gpus()
+
+        return gpus
+
+ def generate_debug_package(self, zip_name, progress_fn=None):
+ zip_path = Path(zip_name)
+
+ if progress_fn is not None:
+ progress_fn({"status": "Generating debug package..."})
+
+ try:
+ with self.critical_region_write():
+ config_json_string = json.dumps(self.config.to_pack_dict(secrets=False))
+ scripts.generate_debug_report.create_debug_package(str(zip_path), config_json_string)
+ if progress_fn is not None:
+ progress_fn({"status": f"Debug package saved to {zip_path.name}"})
+ except Exception as e:
+ self.log("critical", traceback.format_exc())
+ if progress_fn is not None:
+ progress_fn({"status": f"Error generating debug package: {e}"})
diff --git a/modules/ui/models/TimestepGenerator.py b/modules/ui/models/TimestepGenerator.py
new file mode 100644
index 000000000..7f60274b5
--- /dev/null
+++ b/modules/ui/models/TimestepGenerator.py
@@ -0,0 +1,51 @@
+
+from modules.modelSetup.mixin.ModelSetupNoiseMixin import (
+ ModelSetupNoiseMixin,
+)
+from modules.util.config.TrainConfig import TrainConfig
+from modules.util.enum.TimestepDistribution import TimestepDistribution
+
+import torch
+from torch import Tensor
+
+
+class TimestepGenerator(ModelSetupNoiseMixin):
+
+ def __init__(
+ self,
+ timestep_distribution: TimestepDistribution,
+ min_noising_strength: float,
+ max_noising_strength: float,
+ noising_weight: float,
+ noising_bias: float,
+ timestep_shift: float,
+ ):
+ super().__init__()
+
+ self.timestep_distribution = timestep_distribution
+ self.min_noising_strength = min_noising_strength
+ self.max_noising_strength = max_noising_strength
+ self.noising_weight = noising_weight
+ self.noising_bias = noising_bias
+ self.timestep_shift = timestep_shift
+
+ def generate(self) -> Tensor:
+ generator = torch.Generator()
+ generator.seed()
+
+ config = TrainConfig.default_values()
+ config.timestep_distribution = self.timestep_distribution
+ config.min_noising_strength = self.min_noising_strength
+ config.max_noising_strength = self.max_noising_strength
+ config.noising_weight = self.noising_weight
+ config.noising_bias = self.noising_bias
+ config.timestep_shift = self.timestep_shift
+
+
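+        # A large batch is sampled in a single call; presumably the samples feed a histogram of the timestep distribution.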
+ return self._get_timestep_discrete(
+ num_train_timesteps=1000,
+ deterministic=False,
+ generator=generator,
+ batch_size=1000000,
+ config=config,
+ )
diff --git a/modules/ui/models/TrainingModel.py b/modules/ui/models/TrainingModel.py
new file mode 100644
index 000000000..042ba3cef
--- /dev/null
+++ b/modules/ui/models/TrainingModel.py
@@ -0,0 +1,146 @@
+import datetime
+import time
+import traceback
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.ui.models.StateModel import StateModel
+from modules.util import create
+from modules.util.callbacks.TrainCallbacks import TrainCallbacks
+from modules.util.commands.TrainCommands import TrainCommands
+from modules.util.config.TrainConfig import TrainConfig
+from modules.util.torch_util import torch_gc
+from modules.util.TrainProgress import TrainProgress
+
+import torch
+
+
+class TrainingModel(SingletonConfigModel):
+    def __init__(self):
+        super().__init__(None)
+        self.training_commands = None
+        self.training_callbacks = None
+        self.train_config = None
+        self.progress_fn = None  # assigned in train(); initialized here so stop_training() is safe before train()
+        self.training_thread = None
+        self.start_time = None
+
+ def backup_now(self):
+ with self.critical_region_read():
+ if self.training_commands:
+ self.training_commands.backup()
+
+ def save_now(self):
+ with self.critical_region_read():
+ if self.training_commands:
+ self.training_commands.save()
+
+ def sample_now(self):
+ with self.critical_region_read():
+ if self.training_commands:
+ self.training_commands.sample_default()
+
+ def train(self, reattach=False, progress_fn=None):
+ self.progress_fn = progress_fn
+
+ cfg = StateModel.instance().bulk_read("tensorboard", "tensorboard_always_on",
+ "cloud.enabled",
+ as_dict=True)
+
+
+ StateModel.instance().save_default()
+
+ if cfg["tensorboard"] and not cfg["tensorboard_always_on"] and StateModel.instance().tensorboard_subprocess is not None:
+ StateModel.instance().stop_tensorboard()
+
+ with self.critical_region_write():
+ self.training_commands = TrainCommands()
+
+ self.training_callbacks = TrainCallbacks(
+ on_update_train_progress=self.__on_update_train_progress(),
+ on_update_status=self.__on_update_status(),
+ )
+
+ with StateModel.instance().critical_region_read():
+ self.train_config = TrainConfig.default_values().from_dict(StateModel.instance().config.to_dict())
+
+ torch_gc()
+
+ error_caught = False
+
+
+ trainer = create.create_trainer(self.train_config, self.training_callbacks, self.training_commands,
+ reattach=reattach)
+ try:
+ trainer.start()
+ if cfg["cloud.enabled"]:
+ StateModel.instance().set_state("secrets.cloud", self.train_config.secrets.cloud)
+ self.start_time = time.monotonic()
+ trainer.train()
+ except Exception:
+ if cfg["cloud.enabled"]:
+ StateModel.instance().set_state("secrets.cloud", self.train_config.secrets.cloud)
+ error_caught = True
+ self.log("critical", traceback.format_exc())
+
+ trainer.end()
+
+ # clear gpu memory
+ del trainer
+
+ self.training_thread = None
+ with self.critical_region_write():
+ self.training_commands = None
+ torch.clear_autocast_cache()
+ torch_gc()
+
+ if error_caught:
+ if self.progress_fn is not None:
+ self.progress_fn({"status": "Error: check the console for details", "event": "cancelled"})
+ else:
+ if self.progress_fn is not None:
+ self.progress_fn({"status": "Stopped", "event": "finished"})
+
+ if cfg["tensorboard_always_on"] and StateModel.instance().tensorboard_subprocess is not None:
+ StateModel.instance().start_tensorboard()
+
+
+ def __on_update_train_progress(self):
+ def f(train_progress: TrainProgress, max_steps: int, max_epoch: int):
+ if self.progress_fn is not None:
+ self.progress_fn({"epoch": train_progress.epoch,
+ "max_epochs": max_epoch,
+ "step": train_progress.epoch_step,
+ "max_steps": max_steps,
+ "eta": self.__calculate_eta_string(train_progress, max_steps, max_epoch)})
+ return f
+
+ def __on_update_status(self):
+ def f(status: str):
+ if self.progress_fn is not None:
+ self.progress_fn({"status": status})
+ return f
+
+    def __calculate_eta_string(self, train_progress: TrainProgress, max_step: int, max_epoch: int) -> str | None:
+        # Wait a few steps before estimating; this also avoids a division by zero on the very first step.
+        if train_progress.global_step <= 30:
+            return "Estimating ..."
+
+        spent_total = time.monotonic() - self.start_time
+        steps_done = train_progress.epoch * max_step + train_progress.epoch_step
+        remaining_steps = (max_epoch - train_progress.epoch - 1) * max_step + (max_step - train_progress.epoch_step)
+        total_eta = spent_total / steps_done * remaining_steps
+
+ td = datetime.timedelta(seconds=total_eta)
+ days = td.days
+ hours, remainder = divmod(td.seconds, 3600)
+ minutes, seconds = divmod(remainder, 60)
+ if days > 0:
+ return f"{days}d {hours}h"
+ elif hours > 0:
+ return f"{hours}h {minutes}m"
+ elif minutes > 0:
+ return f"{minutes}m {seconds}s"
+ else:
+ return f"{seconds}s"
+
+ def stop_training(self):
+ if self.progress_fn is not None:
+ self.progress_fn({"event": "stopping", "status": "Stopping..."})
+ if self.training_commands is not None:
+ self.training_commands.stop()
diff --git a/modules/ui/models/VideoModel.py b/modules/ui/models/VideoModel.py
new file mode 100644
index 000000000..f8a6f8d7b
--- /dev/null
+++ b/modules/ui/models/VideoModel.py
@@ -0,0 +1,503 @@
+import concurrent.futures
+import math
+import os
+import pathlib
+import random
+import shlex
+import shutil
+import subprocess
+
+from modules.ui.models.SingletonConfigModel import SingletonConfigModel
+from modules.util.path_util import SUPPORTED_VIDEO_EXTENSIONS
+
+import cv2
+import scenedetect
+
+
+class VideoModel(SingletonConfigModel):
+ def __init__(self):
+ super().__init__({
+ "clips": {
+ "single_video": "",
+ "range_start": "00:00:00",
+ "range_end": "99:99:99",
+ "directory": "",
+ "output": "",
+ "output_to_subdirectories": False,
+ "split_at_cuts": False,
+ "max_length": 0,
+ "fps": 0,
+ "remove_borders": False,
+ "crop_variation": 0
+ },
+ "images": {
+ "single_video": "",
+ "range_start": "00:00:00",
+ "range_end": "99:99:99",
+ "directory": "",
+ "output": "",
+ "output_to_subdirectories": False,
+ "capture_rate": 0,
+ "blur_removal": 0,
+ "remove_borders": False,
+ "crop_variation": 0
+ },
+ "download": {
+ "single_link": "",
+ "link_list": "",
+ "output": "",
+ "additional_args": "--quiet --no-warnings --progress"
+ }
+ })
+
+ def __get_vid_paths(self, batch_mode: bool, input_path_single: str, input_path_dir: str):
+ input_videos = []
+ if not batch_mode:
+ path = pathlib.Path(input_path_single)
+ if path.is_file():
+ vid = cv2.VideoCapture(str(path))
+ ok = False
+ try:
+ if vid.isOpened():
+ ok, _ = vid.read()
+ finally:
+ vid.release()
+ if ok:
+ return [path]
+ else:
+ self.log("error", "Invalid video file!")
+ return []
+ else:
+ self.log("error", "No file specified, or invalid file path!")
+ return []
+ else:
+ input_videos = []
+ if not pathlib.Path(input_path_dir).is_dir() or input_path_dir == "":
+ self.log("error", "Invalid input directory!")
+ return []
+ # Only traverse supported extensions to avoid opening every file.
+ lower_exts = {e.lower() for e in SUPPORTED_VIDEO_EXTENSIONS}
+ for path in pathlib.Path(input_path_dir).rglob("*"):
+ if path.is_file() and path.suffix.lower() in lower_exts:
+ vid = cv2.VideoCapture(str(path))
+ ok = False
+ try:
+ if vid.isOpened():
+ ok, _ = vid.read()
+ finally:
+ vid.release()
+ if ok:
+ input_videos.append(path)
+ self.log("info", f'Found {len(input_videos)} videos to process')
+ return input_videos
+
+ def __get_random_aspect(self, height : int, width : int, variation: float) -> tuple[int, int, int, int]:
+ if variation == 0:
+ return 0, height, 0, width
+
+ old_aspect = height/width
+ variation_scaled = old_aspect*variation
+ if old_aspect > 1.2:
+ new_aspect = min(4.0, max(1.0, random.triangular(old_aspect-(variation_scaled*1.5), old_aspect+(variation_scaled/2), old_aspect)))
+ elif old_aspect < 0.85:
+ new_aspect = max(0.25, min(1.0, random.triangular(old_aspect-(variation_scaled/2), old_aspect+(variation_scaled*1.5), old_aspect)))
+ else:
+ new_aspect = random.triangular(old_aspect-variation_scaled, old_aspect+variation_scaled)
+
+ new_aspect = round(new_aspect, 2)
+ if new_aspect > old_aspect:
+ new_height = int(height)
+ new_width = int(width*(old_aspect/new_aspect))
+ elif new_aspect < old_aspect:
+ new_height = int(height*(new_aspect/old_aspect))
+ new_width = int(width)
+ else:
+ new_height = int(height)
+ new_width = int(width)
+
+ position_x = random.randint(0, width-new_width)
+ position_y = random.randint(0, height-new_height)
+ #print(new_aspect)
+ #print(position_y, new_height, position_x, new_width)
+ return position_y, new_height, position_x, new_width
+
+ def __find_main_contour(self, frame):
+ frame_grayscale = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ _, frame_thresh = cv2.threshold(frame_grayscale, 15, 255, cv2.THRESH_BINARY)
+ frame_contours, _ = cv2.findContours(frame_thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ if frame_contours:
+ frame_maincontour = max(frame_contours, key=cv2.contourArea)
+ x1, y1, w1, h1 = cv2.boundingRect(frame_maincontour)
+ else: #fallback if no contours detected
+ x1 = 0
+ y1 = 0
+ h1, w1, _ = frame.shape
+ if not frame_contours or h1 < 10 or w1 < 10: #if bounding box did not detect the correct area, likely due to black frame
+ x1 = 0
+ y1 = 0
+ h1, w1, _ = frame.shape
+ return x1, y1, w1, h1
+
+ def extract_clips_multi(self, batch_mode: bool):
+ cfg = self.bulk_read("clips.output", "clips.max_length", "clips.crop_variation",
+ "clips.fps", "clips.single_video", "clips.directory",
+ "clips.output_to_subdirectories", "clips.split_at_cuts",
+ "clips.remove_borders", "clips.range_start", "clips.range_end",
+ as_dict=True)
+
+
+ # if not pathlib.Path(cfg["clips.output"]).is_dir() or cfg["clips.output"] == "":
+ # self.log("error", "Invalid output directory!")
+ # return
+ pathlib.Path(cfg["clips.output"]).mkdir(parents=True, exist_ok=True)
+
+
+ # validate numeric inputs
+ try:
+ max_length = float(cfg["clips.max_length"])
+ crop_variation = float(cfg["clips.crop_variation"])
+ target_fps = int(cfg["clips.fps"])
+ except ValueError:
+ self.log("error", "Invalid numeric input for Max Length, Crop Variation, or FPS.")
+ return
+ # if max_length <= 0.25:
+ # self.log("error", "Max Length of clips must be > 0.25 seconds.")
+ # return
+ # if target_fps < 0:
+ # self.log("error", "Target FPS must be a positive integer (or 0 to skip fps re-encoding).")
+ # return
+ # if not (0.0 <= crop_variation < 1.0):
+ # self.log("error", "Crop Variation must be between 0.0 and 1.0.")
+ # return
+
+ input_videos = self.__get_vid_paths(batch_mode, cfg["clips.single_video"], cfg["clips.directory"])
+ if len(input_videos) == 0: # exit if no paths found
+ return
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ for video_path in input_videos:
+ if cfg["clips.output_to_subdirectories"] and batch_mode:
+ output_directory = os.path.join(cfg["clips.output"],
+ os.path.splitext(os.path.relpath(video_path, cfg["clips.directory"]))[0])
+ elif cfg["clips.output_to_subdirectories"] and not batch_mode:
+ output_directory = os.path.join(cfg["clips.output"],
+ os.path.splitext(os.path.basename(video_path))[0])
+ else:
+ output_directory = cfg["clips.output"]
+
+ if batch_mode:
+ executor.submit(self.__extract_clips,
+ str(video_path), "00:00:00", "99:99:99", max_length, cfg["clips.split_at_cuts"],
+ cfg["clips.remove_borders"], crop_variation, target_fps, output_directory)
+ else:
+ executor.submit(self.__extract_clips,
+ str(video_path), str(cfg["clips.range_start"]), str(cfg["clips.range_end"]), max_length, cfg["clips.split_at_cuts"],
+ cfg["clips.remove_borders"], crop_variation, target_fps, output_directory)
+
+ if batch_mode:
+ self.log("info", f'Clip extraction from all videos in {cfg["clips.directory"]} complete')
+ else:
+ self.log("info", f'Clip extraction from {cfg["clips.single_video"]} complete')
+
+ def __extract_clips(self, video_path: str, timestamp_min: str, timestamp_max: str, max_length: float,
+ split_at_cuts: bool, remove_borders : bool, crop_variation: float, target_fps: int, output_dir: str):
+ video = cv2.VideoCapture(video_path)
+ fps = video.get(cv2.CAP_PROP_FPS) or 0.0
+ if fps <= 0:
+ self.log("warning", f'Could not read FPS for "{os.path.basename(video_path)}". Falling back to 30 FPS.') # fallback to some sane FPS value
+ fps = 30.0
+ max_length_frames = int(max_length * fps) #convert max length from seconds to frames
+ min_length_frames = max(int(0.25*fps), 1) #minimum clip length of 1/4 second or 1 frame
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
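+        # Convert the "HH:MM:SS" timestamps into frame indices, then clamp them to the actual video length.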
+ timestamp_max_frame = int(sum(int(x) * 60 ** i for i, x in enumerate(reversed(timestamp_max.split(':')))) * fps)
+ timestamp_max_frame = min(timestamp_max_frame, max(total_frames - 1, 0))
+ timestamp_min_frame = int(sum(int(x) * 60 ** i for i, x in enumerate(reversed(timestamp_min.split(':')))) * fps)
+ timestamp_min_frame = min(timestamp_min_frame, timestamp_max_frame)
+
+ if split_at_cuts:
+ #use scenedetect to find cuts, based on start/end frame number
+ timecode_list = scenedetect.detect(
+ str(video_path),
+ scenedetect.AdaptiveDetector(),
+ start_time=int(timestamp_min_frame),
+ end_time=int(timestamp_max_frame))
+ scene_list = [(x[0].get_frames(), x[1].get_frames()) for x in timecode_list]
+ if len(scene_list) == 0:
+ scene_list = [(timestamp_min_frame, timestamp_max_frame)] # use start/end frames if no scenes detected
+ else:
+ scene_list = [(timestamp_min_frame, timestamp_max_frame)] # default if not using cuts, start and end of time range
+
+ scene_list_split = []
+ for scene in scene_list:
+ length = scene[1]-scene[0]
+ if length > max_length_frames: #check for any scenes longer than max length
+ n = math.ceil(length/max_length_frames) #divide into n new scenes
+ new_length = int(length/n)
+ new_splits = range(scene[0], scene[1]+min_length_frames, new_length) #divide clip into closest chunks to max_length
+ for i, _n in enumerate(new_splits[:-1]):
+ if new_splits[i+1] - new_splits[i] > min_length_frames:
+ scene_list_split += [(new_splits[i], new_splits[i+1])]
+ else:
+ if length > (min_length_frames+2):
+ scene_list_split += [(scene[0]+1, scene[1]-1)] #trim first and last frame from detected scenes to avoid transition artifacts
+
+ self.log("info", f'Video "{os.path.basename(video_path)}" being split into {len(scene_list_split)} clips in {output_dir}...')
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ for scene in scene_list_split:
+ executor.submit(self.__save_clip, scene, video_path, target_fps, remove_borders, crop_variation, output_dir)
+
+ video.release()
+
+ def __save_clip(self, scene : tuple[int, int], video_path : str, target_fps : int, remove_borders : bool, crop_variation : float, output_dir : str):
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ basename, ext = os.path.splitext(os.path.basename(video_path))
+ video = cv2.VideoCapture(str(video_path))
+ fps = video.get(cv2.CAP_PROP_FPS) or 0.0
+ if fps <= 0:
+ self.log("warning", f'Could not read FPS for "{os.path.basename(video_path)}". Falling back to 30 FPS.')
+ fps = 30.0
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ output_name = f'{output_dir}{os.sep}{basename}_{scene[0]}-{scene[1]}'
+ output_ext = ".mp4"
+
+ video.set(cv2.CAP_PROP_POS_FRAMES, (scene[1] + scene[0])//2)
+ frame_number = int(video.get(cv2.CAP_PROP_POS_FRAMES))
+ success, frame = video.read()
+ if not success or frame is None:
+ self.log("error", f'Failed to read frame from "{os.path.basename(video_path)}" at {int(frame_number)}. Skipping clip.')
+ video.release()
+ return
+
+ #crop out borders of frame - blends five random frames from the scene to get "average" image
+ #helps prevent incorrect cropping when sampled frame may be all black or otherwise detect incorrect border
+ if remove_borders:
+ frame_blend = frame
+ for i in range(5): # blend 5 random frames to get average
+ random_frame = random.randint(scene[0], scene[1])
+ video.set(cv2.CAP_PROP_POS_FRAMES, random_frame)
+ success, frame = video.read()
+ if not success or frame is None:
+ continue
+ a = 1/(i+1)
+ b = 1-a
+ frame_blend = cv2.addWeighted(frame, a, frame_blend, b, 0)
+ x1, y1, w1, h1 = self.__find_main_contour(frame_blend)
+ else:
+ x1 = 0
+ y1 = 0
+ h1, w1, _ = frame.shape
+
+ y2, h2, x2, w2 = self.__get_random_aspect(h1, w1, crop_variation)
+ writer = cv2.VideoWriter(output_name+output_ext, fourcc, fps, (w2, h2))
+ video.set(cv2.CAP_PROP_POS_FRAMES, scene[0])
+ frame_number = int(video.get(cv2.CAP_PROP_POS_FRAMES))
+ success, frame = video.read()
+
+ while success and (frame_number < scene[1]): # loop through frames within each scene
+ frame_trimmed = frame[y1:y1+h1, x1:x1+w1] # cut out black borders if applicable
+ writer.write(frame_trimmed[y2:y2+h2, x2:x2+w2]) # save frame with random crop variation if applicable
+ success, frame = video.read()
+ frame_number += 1
+ writer.release()
+ video.release()
+
+ if target_fps > 0: # use ffmpeg to change to set framerate - saves copy and deletes original
+ if int(round(fps)) == target_fps:
+ # Already at desired fps; skip re-encode.
+ return
+ cmd = [
+ "ffmpeg", "-y",
+ "-i", f"{output_name}{output_ext}",
+ "-filter:v", f"fps={target_fps}",
+ "-an",
+ f"{output_name}_{target_fps}fps{output_ext}",
+ ]
+ proc = subprocess.run(cmd, stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
+ if proc.returncode == 0:
+ try:
+ os.remove(output_name + output_ext)
+ except OSError:
+ self.log("error", f"Failed to remove conversion placeholder {output_name + output_ext}, remove manually or check folder permissions.")
+
+ def extract_images_multi(self, batch_mode : bool):
+ cfg = self.bulk_read("images.output", "images.capture_rate", "images.blur_removal",
+ "images.crop_variation", "images.single_video", "images.directory",
+ "images.output_to_subdirectories", "images.directory", "images.remove_borders",
+ "images.range_start", "images.range_end",
+ as_dict=True)
+
+
+ # if not pathlib.Path(cfg["images.output"]).is_dir() or cfg["images.output"] == "":
+ # self.log("error", "Invalid output directory!")
+ # return
+ pathlib.Path(cfg["images.output"]).mkdir(parents=True, exist_ok=True)
+
+ # validate numeric inputs
+ try:
+ capture_rate = float(cfg["images.capture_rate"])
+ blur_threshold = float(cfg["images.blur_removal"])
+ crop_variation = float(cfg["images.crop_variation"])
+ except ValueError:
+ self.log("error", "Invalid numeric input for Images/sec, Blur Removal, or Crop Variation.")
+ return
+ # if capture_rate <= 0:
+ # self.log("error", "Images/sec must be > 0.")
+ # return
+ # if not (0.0 <= blur_threshold < 1.0):
+ # self.log("error", "Blur Removal must be between 0.0 and 1.0.")
+ # return
+ # if not (0.0 <= crop_variation < 1.0):
+ # self.log("error", "Crop Variation must be between 0.0 and 1.0.")
+ # return
+
+ input_videos = self.__get_vid_paths(batch_mode, cfg["images.single_video"], cfg["images.directory"])
+ if len(input_videos) == 0: #exit if no paths found
+ return
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ for video_path in input_videos:
+ if cfg["images.output_to_subdirectories"] and batch_mode:
+ output_directory = os.path.join(cfg["images.output"],
+ os.path.splitext(os.path.relpath(video_path, cfg["images.directory"]))[0])
+ elif cfg["images.output_to_subdirectories"] and not batch_mode:
+ output_directory = os.path.join(cfg["images.output"],
+ os.path.splitext(os.path.basename(video_path))[0])
+ else:
+ output_directory = cfg["images.output"]
+
+ if batch_mode:
+ executor.submit(self.__save_frames,
+ str(video_path), "00:00:00", "99:99:99", capture_rate,
+ blur_threshold, cfg["images.remove_borders"], crop_variation, output_directory)
+ else:
+ executor.submit(self.__save_frames,
+ str(video_path), str(cfg["images.range_start"]), str(cfg["images.range_end"]), capture_rate,
+ blur_threshold, cfg["images.remove_borders"], crop_variation, output_directory)
+ if batch_mode:
+ self.log("info", f'Image extraction from all videos in {cfg["images.directory"]} complete')
+ else:
+ self.log("info", f'Image extraction from {cfg["images.single_video"]} complete')
+
+ def __save_frames(self, video_path: str, timestamp_min: str, timestamp_max: str, capture_rate: float,
+ blur_threshold: float, remove_borders : bool, crop_variation: float, output_dir: str):
+ video = cv2.VideoCapture(video_path)
+ fps = video.get(cv2.CAP_PROP_FPS) or 0.0
+ if fps <= 0:
+ self.log("warning", f'Could not read FPS for "{os.path.basename(video_path)}". Falling back to 30 FPS.')
+ fps = 30.0
+ if capture_rate <= 0:
+ self.log("error", "Images/sec must be > 0.")
+ video.release()
+ return
+ image_rate = max(int(fps / capture_rate), 1) # frames between captures (min 1)
+ total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) or 0
+ timestamp_max_frame = int(sum(int(x) * 60 ** i for i, x in enumerate(reversed(timestamp_max.split(':')))) * fps)
+ timestamp_max_frame = min(timestamp_max_frame, max(total_frames - 1, 0))
+ timestamp_min_frame = int(sum(int(x) * 60 ** i for i, x in enumerate(reversed(timestamp_min.split(':')))) * fps)
+ timestamp_min_frame = min(timestamp_min_frame, timestamp_max_frame)
+ frame_range = range(timestamp_min_frame, timestamp_max_frame, image_rate)
+ frame_list = []
+
+ for n in frame_range:
+ frame = abs(int(random.triangular(n-(image_rate/2), n+(image_rate/2)))) #random triangular distribution around center
+ frame = max(0, min(frame, max(total_frames - 1, 0)))
+ frame_list.append(frame)
+
+ self.log("info", f'Video "{os.path.basename(video_path)}" will be split into {len(frame_list)} images in {output_dir}...')
+
+ output_list = []
+ for f in frame_list:
+ video.set(cv2.CAP_PROP_POS_FRAMES, f)
+ success, frame = video.read()
+ if success and frame is not None:
+ frame_grayscale = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+ frame_sharpness = cv2.Laplacian(frame_grayscale, cv2.CV_64F).var() #get sharpness of greyscale pic
+ output_list.append((f, frame_sharpness))
+
+ if not output_list:
+ self.log("warning", f'No frames extracted from {os.path.basename(video_path)} in the selected range.')
+ video.release()
+ return
+
+ output_list_sorted = sorted(output_list, key=lambda x: x[1])
+ cutoff = int(blur_threshold*len(output_list_sorted)) #calculate cutoff as portion of total frames
+ output_list_cut = output_list_sorted[cutoff:] # keep all frames above cutoff
+ self.log("info", f'{cutoff} blurriest images have been dropped from {os.path.basename(video_path)}')
+
+ basename, ext = os.path.splitext(os.path.basename(video_path))
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+        for f in output_list_cut:
+            filename = f'{output_dir}{os.sep}{basename}_{f[0]}.jpg'
+            video.set(cv2.CAP_PROP_POS_FRAMES, f[0])
+            success, frame = video.read()
+            if not success or frame is None:
+                continue  # skip unreadable frames instead of cropping with stale coordinates
+
+            #crop out borders of frame
+            if remove_borders:
+                x1, y1, w1, h1 = self.__find_main_contour(frame)
+                frame_cropped = frame[y1:y1+h1, x1:x1+w1]
+            else:
+                frame_cropped = frame
+            h1, w1, _ = frame_cropped.shape
+
+            y2, h2, x2, w2 = self.__get_random_aspect(h1, w1, crop_variation)
+
+            cv2.imwrite(filename, frame_cropped[y2:y2+h2, x2:x2+w2]) #save image
+ video.release()
+
+ def download_multi(self, batch_mode: bool):
+ cfg = self.bulk_read("download.output", "download.single_link",
+ "download.link_list", "download.additional_args",
+ as_dict=True)
+
+ # if not pathlib.Path(cfg["download.output"]).is_dir() or cfg["download.output"] == "":
+ # self.log("error", "Invalid output directory!")
+ # return
+ pathlib.Path(cfg["download.output"]).mkdir(parents=True, exist_ok=True)
+
+        if not batch_mode:
+            ydl_urls = [cfg["download.single_link"]]
+        else:
+            ydl_path = pathlib.Path(cfg["download.link_list"])
+ if ydl_path.is_file() and ydl_path.suffix.lower() == ".txt":
+ with open(ydl_path) as file:
+ ydl_urls = file.readlines()
+ else:
+ self.log("error", "Invalid link list!")
+ return
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ for url in ydl_urls:
+ executor.submit(self.__download_video,
+ url.strip(), cfg["download.output"], cfg["download.additional_args"])
+
+ self.log("info", f'Completed {len(ydl_urls)} downloads.')
+
+ def __download_video(self, url: str, output_dir: str, output_args: str):
+ url = (url or "").strip()
+ if not url:
+ self.log("warning", "Empty URL, skipping download.")
+ return
+
+ additional_args = shlex.split(output_args.strip()) if output_args and output_args.strip() else [] # Respect quotes and split into list
+
+ yt_dlp = shutil.which("yt-dlp")
+ if yt_dlp is not None:
+ cmd = [yt_dlp, "-o", "%(title)s.%(ext)s", "-P", output_dir] + additional_args + [url]
+
+ self.log("info", f'Downloading {url}...')
+ exitcode = subprocess.run(cmd).returncode
+ if exitcode == 0:
+ self.log("info", f'Download {url} done!')
+ else:
+ self.log("error", f'Failed to download {url} (process terminated with exit code {exitcode})')
+ else:
+ self.log("critical", 'yt-dlp executable not found in $PATH')
diff --git a/modules/ui/utils/FigureWidget.py b/modules/ui/utils/FigureWidget.py
new file mode 100644
index 000000000..ed559ff23
--- /dev/null
+++ b/modules/ui/utils/FigureWidget.py
@@ -0,0 +1,232 @@
+from modules.ui.utils.MaskDrawingToolbar import MaskDrawingToolbar
+from modules.util.enum.EditMode import EditMode
+from modules.util.enum.MouseButton import MouseButton
+from modules.util.enum.ToolType import ToolType
+
+import PySide6.QtGui as QtG
+from matplotlib.backend_bases import MouseButton as MplMouseButton
+from matplotlib.backends.backend_qtagg import FigureCanvas
+from matplotlib.figure import Figure
+from PySide6.QtCore import QCoreApplication as QCA
+from PySide6.QtCore import Signal
+
+# This class creates a FigureWidget associated with a matplotlib drawing area (self.figure) and a toolbar.
+# The toolbar can contain a set of default tools for zooming (zoom_tools=True), and arbitrary tools specified as a list of dictionaries with the following keys:
+# "type" (mandatory): tool type defined in modules.util.enum.ToolType
+# "fn" (optional): callback invoked when the tool is used (clicked signal for buttons, valueChanged for spinboxes)
+# "tool" (only for CHECKABLE_BUTTON): modules.util.enum.EditMode value associated with the tool. It will handle mutual exclusion of tools automatically
+# "text" (optional): the tool's text
+# "icon" (optional): the tool's icon
+# "tooltip" (optional): the tool's tooltip text
+# "shortcut" (optional): the tool's shortcut
+# "name" (mandatory for spinboxes): the tool's QWidget objectName, to be used with findChild()
+# "spinbox_range" (optional, for spinboxes only): (minimum, maximum, stepSize) for the spinbox
+# "value" (optional, for spinboxes only): the spinbox' default value
+#
+# The FigureWidget has two event handling mechanisms:
+# - Mutually exclusive tools (i.e., those associated with a ToolType.CHECKABLE_BUTTON defining a "tool" field) can use registerTool() to automatically invoke functions when clicked, released or moved events are fired.
+# - The class can optionally emit QT6 signals (emit_clicked, emit_released, emit_wheel, emit_moved passed to __init__) which can be handled externally.
+# registerTool() accepts two types of callbacks:
+# - use_mpl_event=True: callbacks will receive raw matplotlib events
+# - use_mpl_event=False: callbacks will receive the same arguments as the QT6 signals.
+#
+# The QT6 signals exposed are:
+# wheelUp()
+# wheelDown()
+# clicked(modules.util.enum.MouseButton, int_x, int_y)
+# released(modules.util.enum.MouseButton, int_x, int_y)
+# moved(modules.util.enum.MouseButton, int_start_x, int_start_y, int_end_x, int_end_y)
+# Coordinates are either in absolute canvas pixels (use_data_coordinates=False), or relative to the data loaded on the canvas (e.g., image or plot coordinates).
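+#
+# Example (illustrative sketch; EditMode.DRAW and the callbacks are hypothetical):
+#
+#     widget = FigureWidget(zoom_tools=True, emit_wheel=True, other_tools=[{
+#         "type": ToolType.CHECKABLE_BUTTON,
+#         "tool": EditMode.DRAW,
+#         "tooltip": "Draw mask",
+#     }])
+#     widget.registerTool(EditMode.DRAW, clicked_fn=on_draw, moved_fn=on_drag)
+#     widget.wheelUp.connect(on_zoom_in)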
+
+class FigureWidget(FigureCanvas):
+ clicked = Signal(MouseButton, int, int) # x, y
+ released = Signal(MouseButton, int, int) # x, y
+ wheelUp = Signal()
+ wheelDown = Signal()
+ moved = Signal(MouseButton, int, int, int, int) # x0, y0, x1, y1.
+ # Note: signals cannot be declared with unions like "int | None". So we either declare them as object to allow emitting None values, or use -1 for events outside the image (the latter approach is safer).
+
+ def __init__(self, parent=None, width=5, height=4, dpi=100, zoom_tools=False, other_tools=None, emit_clicked=False, emit_released=False, emit_wheel=False, emit_moved=False, use_data_coordinates=True):
+ super().__init__(Figure(figsize=(width, height), layout="tight", dpi=dpi))
+ self.toolbar = MaskDrawingToolbar(self, parent=parent)
+ self.event_handlers = {}
+ self.theme = "dark" if QtG.QGuiApplication.styleHints().colorScheme() == QtG.Qt.ColorScheme.Dark else "light"
+
+ tools = []
+ if zoom_tools:
+ tools.extend([{
+ "type": ToolType.BUTTON,
+ "fn": self.toolbar.home,
+ "icon": f"resources/icons/buttons/{self.theme}/house.svg",
+ "tooltip": QCA.translate("toolbar_item", "Reset original view (CTRL+H)"),
+ "shortcut": "Ctrl+H",
+ },
+ {
+ "type": ToolType.CHECKABLE_BUTTON,
+ "tool": EditMode.PAN,
+ "icon": f"resources/icons/buttons/{self.theme}/move.svg",
+ "tooltip": QCA.translate("toolbar_item", "Left button pans, Right button zooms (CTRL+P)"),
+ "shortcut": "Ctrl+P",
+ },
+ {
+ "type": ToolType.CHECKABLE_BUTTON,
+ "tool": EditMode.ZOOM,
+ "icon": f"resources/icons/buttons/{self.theme}/search.svg",
+ "tooltip": QCA.translate("toolbar_item", "Zoom to rectangle (CTRL+Q)"),
+ "shortcut": "Ctrl+Q",
+ }])
+
+ if other_tools is not None:
+ if zoom_tools:
+ tools.append({"type": ToolType.SEPARATOR})
+ tools.extend(other_tools)
+
+ self.toolbar.addTools(tools)
+
+ if zoom_tools:
+ self.registerTool(EditMode.PAN, clicked_fn=self.toolbar.press_pan, released_fn=self.toolbar.release_pan, use_mpl_event=True)
+ self.registerTool(EditMode.ZOOM, clicked_fn=self.toolbar.press_zoom, released_fn=self.toolbar.release_zoom, use_mpl_event=True)
+
+ self.use_data_coordinates = use_data_coordinates
+ self.last_x = self.last_y = None
+
+
+ self.emit_clicked = emit_clicked
+ self.emit_released = emit_released
+ self.emit_wheel = emit_wheel
+ self.emit_moved = emit_moved
+
+ self.mpl_connect("button_press_event", self.__eventHandler())
+ self.mpl_connect("button_release_event", self.__eventHandler())
+ self.mpl_connect("scroll_event", self.__eventHandler())
+ self.mpl_connect("motion_notify_event", self.__eventHandler())
+
+    def __eventHandler(self):
+        def f(event):
+            if event.name == "button_press_event":
+                args = self.__onClicked(event)
+                self.__dispatch("clicked", event, args)
+            elif event.name == "button_release_event":
+                args = self.__onReleased(event)
+                self.__dispatch("released", event, args)
+            elif event.name == "scroll_event":
+                self.__onWheel(event)
+            elif event.name == "motion_notify_event":
+                args = self.__onMoved(event)
+                self.__dispatch("moved", event, args)
+
+        return f
+
+    # Invoke the handler registered for the currently active tool, if any.
+    def __dispatch(self, kind, event, args):
+        handler = self.event_handlers.get(str(self.toolbar.mode))
+        if handler is not None and handler[kind] is not None:
+            if handler["use_mpl_event"]:
+                handler[kind](event)
+            else:
+                handler[kind](*args)
+
+
+ def registerTool(self, tool_mode, clicked_fn=None, released_fn=None, moved_fn=None, use_mpl_event=False):
+ self.event_handlers[str(tool_mode)] = {"clicked": clicked_fn, "released": released_fn, "moved": moved_fn, "use_mpl_event": use_mpl_event}
+
+
+    # Map a matplotlib mouse button onto this module's MouseButton enum.
+    def __mapButton(self, event):
+        if event.button == MplMouseButton.LEFT:
+            return MouseButton.LEFT
+        elif event.button == MplMouseButton.RIGHT:
+            return MouseButton.RIGHT
+        elif event.button == MplMouseButton.MIDDLE:
+            return MouseButton.MIDDLE
+        return MouseButton.NONE
+
+    # Process matplotlib event into a more abstract interface, and optionally emit a signal.
+    def __onClicked(self, event):
+        if self.use_data_coordinates:
+            x, y = event.xdata, event.ydata
+        else:
+            x, y = event.x, event.y
+        btn = self.__mapButton(event)
+
+ self.last_x, self.last_y = x, y
+
+ x = int(x) if x is not None else -1
+ y = int(y) if y is not None else -1
+
+ if self.emit_clicked:
+ self.clicked.emit(btn, x, y)
+
+ return btn, x, y
+
+
+ def __onReleased(self, event):
+ if self.use_data_coordinates:
+ x, y = event.xdata, event.ydata
+ else:
+ x, y = event.x, event.y
+        btn = self.__mapButton(event)
+
+ self.last_x = self.last_y = None
+
+ x = int(x) if x is not None else -1
+ y = int(y) if y is not None else -1
+
+ if self.emit_released:
+ self.released.emit(btn, x, y)
+
+ return btn, x, y
+
+
+ def __onMoved(self, event):
+ if self.use_data_coordinates:
+ x1, y1 = event.xdata, event.ydata
+ else:
+ x1, y1 = event.x, event.y
+        btn = self.__mapButton(event)
+
+ x0, y0 = self.last_x, self.last_y
+
+ self.last_x, self.last_y = x1, y1
+
+ x0 = int(x0) if x0 is not None else -1
+ y0 = int(y0) if y0 is not None else -1
+ x1 = int(x1) if x1 is not None else -1
+ y1 = int(y1) if y1 is not None else -1
+
+ if self.emit_moved:
+ self.moved.emit(btn, x0, y0, x1, y1) # If -1, either start or finish is outside the canvas.
+
+ return btn, x0, y0, x1, y1
+
+ def __onWheel(self, event):
+ if self.emit_wheel:
+ if event.button == "up":
+ self.wheelUp.emit()
+ elif event.button == "down":
+ self.wheelDown.emit()
+
+ return event.button
diff --git a/modules/ui/utils/MaskDrawingToolbar.py b/modules/ui/utils/MaskDrawingToolbar.py
new file mode 100644
index 000000000..153951409
--- /dev/null
+++ b/modules/ui/utils/MaskDrawingToolbar.py
@@ -0,0 +1,88 @@
+from modules.util.enum.EditMode import EditMode
+from modules.util.enum.ToolType import ToolType
+
+import PySide6.QtGui as QtG
+import PySide6.QtWidgets as QtW
+from matplotlib.backends.backend_qtagg import NavigationToolbar2QT as NavigationToolbar
+from PySide6.QtCore import Qt
+
+
+# Toolbar class for the FigureWidget.
+class MaskDrawingToolbar(NavigationToolbar):
+ toolitems = [] # Override default matplotlib tools.
+
+ def __init__(self, canvas, parent):
+ super().__init__(canvas, parent, coordinates=False)
+ self.tools = {}
+ self.mode = EditMode.NONE
+
+    def toggleTool(self, tool_mode):
+        if self.canvas.widgetlock.available(self):
+            if self.mode == tool_mode:
+                self.mode = EditMode.NONE
+                self.canvas.widgetlock.release(self)
+            else:
+                self.mode = tool_mode
+                self.canvas.widgetlock(self)
+
+            # Check the button of the now-active tool; comparing against self.mode (not tool_mode)
+            # also unchecks the clicked button when its tool is being toggled off.
+            for k, v in self.tools.items():
+                v.setChecked(k == self.mode)
+
+ def addTools(self, tools):
+ for t in tools:
+ self.__addTool(t)
+
+ def __addTool(self, tool):
+ if tool["type"] == ToolType.SEPARATOR:
+ self.addSeparator()
+ elif tool["type"] == ToolType.SPINBOX or tool["type"] == ToolType.DOUBLE_SPINBOX:
+ range = tool.get("spinbox_range", (0.05, 1.0, 0.05, 0.05))
+ value = tool.get("value", range[0])
+ wdg = QtW.QLabel(self.canvas, text=tool.get("text", None))
+ if tool["type"] == ToolType.DOUBLE_SPINBOX:
+ wdg2 = QtW.QDoubleSpinBox(self.canvas, objectName=tool.get("name", None))
+ else:
+ wdg2 = QtW.QSpinBox(self.canvas, objectName=tool.get("name", None))
+ wdg2.setMinimum(range[0])
+ wdg2.setMaximum(range[1])
+ wdg2.setSingleStep(range[2])
+ wdg2.setValue(value)
+ if "fn" in tool:
+ wdg2.valueChanged.connect(tool["fn"])
+ tool["fn"](value)
+
+ wdg.setBuddy(wdg2)
+ if "icon" in tool:
+ wdg.setPixmap(QtG.QPixmap(tool["icon"]))
+ if "tooltip" in tool:
+ wdg2.setToolTip(tool["tooltip"])
+ self.addWidget(wdg)
+ self.addWidget(wdg2)
+ else:
+ wdg = QtW.QToolButton(self.canvas, objectName=tool.get("name", None))
+ if "shortcut" in tool:
+ scut = QtG.QShortcut(QtG.QKeySequence(tool["shortcut"]), self.canvas)
+ scut.setAutoRepeat(False)
+ scut.activated.connect(wdg.click)
+
+ if tool["type"] == ToolType.CHECKABLE_BUTTON:
+ wdg.setCheckable(True)
+ if "tool" in tool:
+ self.tools[tool["tool"]] = wdg
+ wdg.clicked.connect(lambda: self.toggleTool(tool["tool"]))
+ if "fn" in tool:
+ wdg.clicked.connect(tool["fn"])
+
+ if "text" in tool and "icon" in tool:
+ wdg.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
+
+ if "text" in tool:
+ wdg.setText(tool["text"])
+
+ if "icon" in tool:
+ wdg.setIcon(QtG.QIcon(tool["icon"]))
+
+ if "tooltip" in tool:
+ wdg.setToolTip(tool["tooltip"])
+
+ self.addWidget(wdg)
diff --git a/modules/ui/utils/OneTrainerApplication.py b/modules/ui/utils/OneTrainerApplication.py
new file mode 100644
index 000000000..2e8958e30
--- /dev/null
+++ b/modules/ui/utils/OneTrainerApplication.py
@@ -0,0 +1,26 @@
+from modules.util.enum.ModelType import ModelType
+from modules.util.enum.Optimizer import Optimizer
+from modules.util.enum.TrainingMethod import TrainingMethod
+
+from PySide6.QtCore import Signal
+from PySide6.QtWidgets import QApplication
+
+
+class OnetrainerApplication(QApplication):
+ # Signal for global UI invalidation (e.g., when a config file is reloaded from disk).
+ initialized = Signal()
+ stateChanged = Signal()
+ savedConfig = Signal(str)
+
+ # Signals for dynamic widget lists and sub windows. The passed value is the currently selected element.
+ conceptsChanged = Signal(bool) # If True, save changes.
+ openConcept = Signal(int)
+
+ samplesChanged = Signal()
+ openSample = Signal(int)
+
+ embeddingsChanged = Signal()
+
+ # Signals used to update only a subset of elements, passing relevant data for redrawing.
+    modelChanged = Signal(ModelType, TrainingMethod)  # Signal for a changed model/training method. Emitted as emit(new_model, new_method) so that receivers can consume the new values directly.
+ optimizerChanged = Signal(Optimizer)
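+
+    # Example (illustrative; the enum members and rebuild_tabs are hypothetical):
+    #     app.modelChanged.connect(rebuild_tabs)
+    #     app.modelChanged.emit(ModelType.STABLE_DIFFUSION_15, TrainingMethod.LORA)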
diff --git a/modules/ui/utils/WorkerPool.py b/modules/ui/utils/WorkerPool.py
new file mode 100644
index 000000000..4e2f14fc5
--- /dev/null
+++ b/modules/ui/utils/WorkerPool.py
@@ -0,0 +1,198 @@
+import inspect
+import sys
+import threading
+import traceback
+import uuid
+
+from modules.ui.models.StateModel import StateModel
+
+from PySide6.QtCore import QObject, QRunnable, QThreadPool, Signal, Slot
+
+
+class BaseWorker(QObject):
+ initialized = Signal()
+ finished = Signal(str)
+ errored = Signal(tuple)
+ aborted = Signal()
+ result = Signal(object)
+ progress = Signal(dict) # Arbitrary key-value pairs.
+
+ def __init__(self):
+ super().__init__()
+
+
+ def progressCallback(self):
+ def f(data):
+ self.progress.emit(data)
+ return f
+
+    def _threadWrapper(self):
+        try:
+            self.initialized.emit()
+            if self.inject_progress_callback:
+                if "progress_fn" not in inspect.signature(self.fn).parameters:
+                    StateModel.instance().log("warning", "callable function has no progress_fn parameter. Invoking the function without it.")
+                    out = self.fn(**self.kwargs)
+                else:
+                    out = self.fn(progress_fn=self.progressCallback(), **self.kwargs)
+            else:
+                out = self.fn(**self.kwargs)
+        except Exception:
+            StateModel.instance().log("critical", traceback.format_exc())
+            exctype, value = sys.exc_info()[:2]
+            self.errored.emit((exctype, value, traceback.format_exc()))
+        else:
+            # Emit either aborted or result, never both; previously an aborted run also emitted its result.
+            if self.abort_flag is not None and self.abort_flag.is_set():
+                self.abort_flag.clear()
+                self.aborted.emit()
+            else:
+                self.result.emit(out)
+        finally:
+            self.finished.emit(self.name)
+
+ def connectCallbacks(self, init_fn=None, result_fn=None, finished_fn=None, errored_fn=None, aborted_fn=None, progress_fn=None):
+ if init_fn is not None:
+ self.connections["initialized"].append(self.initialized.connect(init_fn))
+ if result_fn is not None:
+ self.connections["result"].append(self.result.connect(result_fn))
+ if errored_fn is not None:
+ self.connections["errored"].append(self.errored.connect(errored_fn))
+ if finished_fn is not None:
+ self.connections["finished"].append(self.finished.connect(finished_fn))
+ if aborted_fn is not None:
+ self.connections["aborted"].append(self.aborted.connect(aborted_fn))
+ if progress_fn is not None:
+ self.connections["progress"].append(self.progress.connect(progress_fn))
+
+ def disconnectAll(self):
+ for v in self.connections.values():
+ for v2 in v:
+ v2.disconnect()
+ self.connections = {"initialized": [], "result": [], "errored": [], "finished": [], "aborted": [], "progress": []}
+
+
+# Thread worker based on QRunnable (it cannot be joined, but it is automatically enqueued on QT6's QThreadPool, which balances loads automatically).
+# IMPORTANT: For severe exceptions (e.g., CUDA errors) it may crash the entire application with SIGSEGV.
+# According to this: https://stackoverflow.com/questions/59837773/qtcore-qrunnable-causes-sigsev-pyqt5
+# the problem may be that multiple inheritance can sometimes lead to accessing already-freed memory.
+class RunnableWorker(QRunnable, BaseWorker):
+    def __init__(self, fn, name, abort_flag=None, inject_progress_callback=False, **kwargs):
+        QRunnable.__init__(self)
+        BaseWorker.__init__(self)  # explicitly initialize the QObject side too, otherwise the signals are not set up
+
+        self.fn = fn
+        self.name = name
+        self.abort_flag = abort_flag
+        self.kwargs = kwargs
+        self.inject_progress_callback = inject_progress_callback
+
+        self.connections = {"initialized": [], "result": [], "errored": [], "finished": [], "aborted": [], "progress": []}
+        self.destroyed.connect(lambda _: self.disconnectAll())  # call the method; previously the lambda only referenced it
+
+ @Slot()
+ def run(self):
+ self._threadWrapper()
+
+# Thread worker based on threading.Thread: a manually managed thread with join capabilities.
+# It *should* survive severe exceptions, as it is a pure Python implementation.
+class PoolLessWorker(BaseWorker):
+ def __init__(self, fn, name, abort_flag=None, inject_progress_callback=False, daemon=False, **kwargs):
+ BaseWorker.__init__(self)
+
+ self.fn = fn
+ self.name = name
+ self.abort_flag = abort_flag
+ self.kwargs = kwargs
+ self.inject_progress_callback = inject_progress_callback
+
+ self.connections = {"initialized": [], "result": [], "errored": [], "finished": [], "aborted": [], "progress": []}
+        self.destroyed.connect(lambda _: self.disconnectAll())  # call the method instead of merely referencing it
+
+ self._thread = threading.Thread(target=self._threadWrapper, daemon=daemon)
+
+ def start(self):
+ self._thread.start()
+
+ def join(self, timeout=None):
+ self._thread.join(timeout)
+
+ def isAlive(self):
+ return self._thread.is_alive()
+
+
+
+# Simple worker pool class. It allows enqueuing arbitrary functions, executed on a QThreadPool. All the function parameters must be passed BY NAME (kwargs).
+# If a job is associated with a name (createNamed()), it is deduplicated: attempting to enqueue the same named job while one is still registered creates it only once.
+# Workers (returned by createNamed and createAnonymous) expose initialized(), finished(), aborted(), result(function output) and errored(exception, value, traceback) signals.
+# Abort events are a responsibility of the function, which can optionally be associated with a threading.Event() object (the aborted signal will be emitted if, at the end of the execution, the event is_set()).
+# IMPORTANT: the finished signal also removes the worker reference from this class; therefore, unless a reference is saved somewhere else, the worker will be garbage collected.
+# Using the worker's connectCallbacks() method should avoid errors due to connections still active after garbage collection.
+#
+# A typical worker life-cycle is:
+# worker_object, worker_id = WorkerPool.instance().createNamed(...)/createAnonymous(...)
+# worker_object.connectCallbacks(...)
+# WorkerPool.instance().start(worker_id)
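+#
+# Concrete example (illustrative sketch; train_fn and the callbacks are hypothetical):
+#     worker, worker_id = WorkerPool.instance().createNamed(train_fn, "training", inject_progress_callback=True)
+#     if worker is not None:  # None means a job with this name is already registered
+#         worker.connectCallbacks(progress_fn=on_progress, errored_fn=on_error, finished_fn=on_finished)
+#         WorkerPool.instance().start(worker_id)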
+
+class WorkerPool:
+ _instance = None
+
+ @classmethod
+ def instance(cls):
+ if cls._instance is None:
+ cls._instance = cls()
+ return cls._instance
+
+ def __init__(self):
+ self.pool = QThreadPool()
+        self.named_workers = {}  # refuses to register a new worker under an already-used name
+        self.anonymous_workers = {}  # can grow arbitrarily; keyed by random UUIDs
+ self.poolless_workers = {}
+
+    def __len__(self):
+        return len(self.anonymous_workers) + len(self.named_workers) + len(self.poolless_workers)
+
+ def createAnonymous(self, fn, abort_flag=None, **kwargs):
+ id = str(uuid.uuid4())
+ worker = RunnableWorker(fn, id, abort_flag=abort_flag, **kwargs)
+ worker.connectCallbacks(finished_fn=self.__removeFinished(is_named=False))
+ self.anonymous_workers[id] = worker
+ return worker, id
+
+    def createNamed(self, fn, name, poolless=False, daemon=False, abort_flag=None, **kwargs):
+        # Check both registries, so a poolless job cannot be duplicated under a pooled name and vice versa.
+        if name in self.named_workers or name in self.poolless_workers:
+            return None, None
+
+        if poolless:
+            worker = PoolLessWorker(fn, name, abort_flag=abort_flag, daemon=daemon, **kwargs)
+            self.poolless_workers[name] = worker
+        else:
+            worker = RunnableWorker(fn, name, abort_flag=abort_flag, **kwargs)
+            self.named_workers[name] = worker
+        worker.connectCallbacks(finished_fn=self.__removeFinished(is_named=True))
+        return worker, name
+
+
+ def start(self, worker_id):
+ ok = False
+ if worker_id in self.named_workers:
+ ok = True
+ self.pool.start(self.named_workers[worker_id])
+ elif worker_id in self.poolless_workers:
+ ok = True
+ self.poolless_workers[worker_id].start()
+ elif worker_id in self.anonymous_workers:
+ ok = True
+ self.pool.start(self.anonymous_workers[worker_id])
+
+ return ok
+
+ def __removeFinished(self, is_named):
+ def f(name):
+ if is_named:
+ if name in self.named_workers:
+ self.named_workers.pop(name)
+ elif name in self.poolless_workers:
+ self.poolless_workers.pop(name)
+ else:
+ self.anonymous_workers.pop(name)
+ return f
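+
+
+# A minimal join sketch for a poolless worker (long_task is illustrative):
+#
+#   worker, worker_id = WorkerPool.instance().createNamed(long_task, "long_task", poolless=True, daemon=True)
+#   WorkerPool.instance().start(worker_id)
+#   worker.join(timeout=5.0)
+#   if worker.isAlive():
+#       print("long_task is still running after 5 seconds")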
diff --git a/modules/ui/views/tabs/additional_embeddings.ui b/modules/ui/views/tabs/additional_embeddings.ui
new file mode 100644
index 000000000..be203e625
--- /dev/null
+++ b/modules/ui/views/tabs/additional_embeddings.ui
@@ -0,0 +1,38 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 409
+ 120
+
+
+
+ Form
+
+
+ -
+
+
+ Add Embedding
+
+
+
+ -
+
+
+ Enable All
+
+
+
+ -
+
+
+
+
+
+
+
diff --git a/modules/ui/views/tabs/backup.ui b/modules/ui/views/tabs/backup.ui
new file mode 100644
index 000000000..a9ed5ebaa
--- /dev/null
+++ b/modules/ui/views/tabs/backup.ui
@@ -0,0 +1,241 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 622
+ 298
+
+
+
+ Form
+
+
+ -
+
+
+ Backup After
+
+
+ backupSbx
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ The interval used when automatically creating model backups during training
+
+
+
+ -
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 108
+ 0
+
+
+
+
+ -
+
+
+ Backup Now
+
+
+
+ -
+
+
+ If rolling backups are enabled, older backups are deleted automatically
+
+
+ Rolling Backup
+
+
+
+ -
+
+
+ Rolling Backup Count
+
+
+ rollingCountSbx
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Defines the number of backups to keep if rolling backups are enabled
+
+
+
+
+
+
+ -
+
+
+ Create a full backup before saving the final model
+
+
+ Backup Before Save
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+ QSizePolicy::Policy::Fixed
+
+
+
+ 20
+ 20
+
+
+
+
+ -
+
+
+ Save Every
+
+
+ saveSbx
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ The interval used when automatically saving the model during training
+
+
+
+ -
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 108
+ 0
+
+
+
+
+ -
+
+
+ Save Now
+
+
+
+ -
+
+
+ Skip First
+
+
+ skipSbx
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Start saving automatically after this interval has elapsed
+
+
+
+ -
+
+
+ Save Filename Prefix
+
+
+ savePrefixLed
+
+
+
+ -
+
+
+ The prefix for filenames used when saving the model during training
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 55
+
+
+
+
+
+
+
+ backupSbx
+ backupCmb
+ backupBtn
+ rollingBackupCbx
+ rollingCountSbx
+ backupBeforeSaveCbx
+ saveSbx
+ saveCmb
+ saveBtn
+ skipSbx
+ savePrefixLed
+
+
+
+
diff --git a/modules/ui/views/tabs/cloud.ui b/modules/ui/views/tabs/cloud.ui
new file mode 100644
index 000000000..4767aa8f9
--- /dev/null
+++ b/modules/ui/views/tabs/cloud.ui
@@ -0,0 +1,681 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 1346
+ 694
+
+
+
+ Form
+
+
+ -
+
+
+ true
+
+
+
+
+ 0
+ 0
+ 1326
+ 674
+
+
+
+
-
+
+
+ Enable cloud training
+
+
+ Enabled
+
+
+
+ -
+
+
+ false
+
+
+ QFrame::Shape::NoFrame
+
+
+ QFrame::Shadow::Plain
+
+
+
-
+
+
+ Type
+
+
+ cloudTypeCmb
+
+
+
+ -
+
+
+ Choose LINUX to connect to a Linux machine via SSH. Choose RUNPOD for additional functionality such as automatically creating and deleting pods
+
+
+
+ -
+
+
+ Remote Directory
+
+
+ remoteDirectoryLed
+
+
+
+ -
+
+
+ The directory on the cloud where files will be uploaded and downloaded
+
+
+
+ -
+
+
+ Automatically creates a new cloud instance if both Host:Port and Cloud ID are empty. Currently supported for RUNPOD
+
+
+ Create Cloud Via API
+
+
+
+ -
+
+
+ Create Cloud Via Website
+
+
+
+ -
+
+
+ File Sync Method
+
+
+ fileSyncMethodCmb
+
+
+
+ -
+
+
+ Choose NATIVE_SCP to use scp.exe to transfer files. FABRIC_SFTP uses the Paramiko/Fabric SFTP implementation for file transfers instead
+
+
+
+ -
+
+
+ OneTrainer Directory
+
+
+ onetrainerDirectoryLed
+
+
+
+ -
+
+
+ The directory for OneTrainer on the cloud
+
+
+
+ -
+
+
+ Cloud Name
+
+
+ cloudNameLed
+
+
+
+ -
+
+
+ The name of the new cloud instance
+
+
+
+ -
+
+
+ API Key
+
+
+ apiKeyLed
+
+
+
+ -
+
+
+ Cloud service API key for RUNPOD. Leave empty for LINUX. This value is stored separately, not saved to your configuration file
+
+
+
+ -
+
+
+ Huggingface Cache Directory
+
+
+ huggingfaceCacheLed
+
+
+
+ -
+
+
+ Huggingface models are downloaded to this remote directory
+
+
+
+ -
+
+
+ Type
+
+
+ subTypeCmb
+
+
+
+ -
+
+
+ Select the RunPod cloud type. See RunPod's website for details
+
+
+
+ -
+
+
+ Hostname
+
+
+ hostnameLed
+
+
+
+ -
+
+
+ SSH server hostname or IP. Leave empty if you have a Cloud ID or want to automatically create a new cloud
+
+
+
+ -
+
+
+ Automatically install OneTrainer from GitHub if the directory doesn't already exist
+
+
+ Install OneTrainer
+
+
+
+ -
+
+
+ GPU
+
+
+ gpuCmb
+
+
+
+ -
+
+
+ Select the GPU type. Enter an API key before pressing the button
+
+
+
+ -
+
+
+ Port
+
+
+ portSbx
+
+
+
+ -
+
+
+ SSH server port. Leave empty if you have a Cloud ID or want to automatically create a new cloud
+
+
+ QAbstractSpinBox::ButtonSymbols::NoButtons
+
+
+ 65535
+
+
+
+ -
+
+
+ Install Command
+
+
+ installCommandLed
+
+
+
+ -
+
+
+ The command for installing OneTrainer. Leave the default, unless you want to use a development branch of OneTrainer
+
+
+
+ -
+
+
+ Volume Size
+
+
+ volumeSizeSbx
+
+
+
+ -
+
+
+ Set the storage volume size in GB. This volume persists only until the cloud is deleted - not a RunPod network volume
+
+
+ 1024
+
+
+
+ -
+
+
+ User
+
+
+ userLed
+
+
+
+ -
+
+
+ SSH username. Use "root" for RUNPOD. Your SSH client must be set up to connect to the cloud using a public key, without a password. For RUNPOD, create an ed25519 key locally, and copy the contents of the public keyfile to your "SSH Public Keys" on the RunPod website
+
+
+
+ -
+
+
+ Update OneTrainer if it already exists on the cloud
+
+
+ Update OneTrainer
+
+
+
+ -
+
+
+ Min Download
+
+
+ minDownloadSbx
+
+
+
+ -
+
+
+ Set the minimum download speed of the cloud in Mbps
+
+
+ 1024
+
+
+
+ -
+
+
+ Cloud ID
+
+
+ cloudIdLed
+
+
+
+ -
+
+
+ RUNPOD Cloud ID. The cloud service must have a public IP and SSH service. Leave empty if you want to automatically create a new RUNPOD cloud, or if you're connecting to another cloud provider via SSH Hostname and Port
+
+
+
+ -
+
+
+ Instead of starting tensorboard locally, make a TCP tunnel to a tensorboard on the cloud
+
+
+ Tensorboard TCP Tunnel
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 221
+ 20
+
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
+ Allows the trainer to keep running even if your connection to the cloud is lost
+
+
+ Detach Remote Trainer
+
+
+
+ -
+
+
+ Reattach ID
+
+
+ reattachIdLed
+
+
+
+ -
+
+
+
+ 156
+ 0
+
+
+
+ An ID identifying the remotely running trainer. In case you have lost connection or closed OneTrainer, it will try to reattach to this ID instead of starting a new remote trainer
+
+
+
+ -
+
+
+ Reattach Now
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+ QSizePolicy::Policy::Minimum
+
+
+
+ 156
+ 20
+
+
+
+
+ -
+
+
+ Action On Finish
+
+
+ onFinishCmb
+
+
+
+ -
+
+
+ What to do when training finishes and the data has been fully downloaded: Stop or delete the cloud, or do nothing
+
+
+
+ -
+
+
+ Download samples from the remote workspace directory to your local machine
+
+
+ Download Samples
+
+
+
+ -
+
+
+ Action On Error
+
+
+ onErrorCmb
+
+
+
+ -
+
+
+ What to do if training stops due to an error: Stop or delete the cloud, or do nothing. Data may be lost
+
+
+
+ -
+
+
+ Download the final model after training. You can disable this if you plan to use an automatically saved checkpoint instead
+
+
+ Download Output Model
+
+
+
+ -
+
+
+ Action On Detached
+
+
+ onDetachedCmb
+
+
+
+ -
+
+
+ What to do when training finishes, but the client has been detached and cannot download data. Data may be lost
+
+
+
+ -
+
+
+ Download the automatically saved training checkpoints from the remote workspace directory to your local machine
+
+
+ Download Saved Checkpoints
+
+
+
+ -
+
+
+ Action On Detached Error
+
+
+ onDetachedErrorCmb
+
+
+
+ -
+
+
+ What to do if training stops due to an error, but the client has been detached and cannot download data. Data may be lost
+
+
+
+ -
+
+
+ Download backups from the remote workspace directory to your local machine. It's usually not necessary to download them, because as long as the backups are still available on the cloud, the training can be restarted using one of the cloud's backups
+
+
+ Download Backups
+
+
+
+ -
+
+
+ Download TensorBoard event logs from the remote workspace directory to your local machine. They can then be viewed locally in TensorBoard. It is recommended to disable "Sample to TensorBoard" to reduce the event log size
+
+
+ Download Tensorboard Logs
+
+
+
+ -
+
+
+ Delete the workspace directory on the cloud after training has finished successfully and data has been downloaded
+
+
+ Delete Remote Workspace
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+ QSizePolicy::Policy::MinimumExpanding
+
+
+
+ 20
+ 3
+
+
+
+
+ -
+
+
+ ...
+
+
+
+
+
+
+
+
+
+
+
+
+
+ scrollArea_2
+ enabledCbx
+ cloudTypeCmb
+ fileSyncMethodCmb
+ apiKeyLed
+ hostnameLed
+ portSbx
+ userLed
+ cloudIdLed
+ remoteDirectoryLed
+ onetrainerDirectoryLed
+ huggingfaceCacheLed
+ installOnetrainerCbx
+ installCommandLed
+ updateOnetrainerCbx
+ createCloudCbx
+ createCloudBtn
+ cloudNameLed
+ subTypeCmb
+ gpuCmb
+ gpuBtn
+ volumeSizeSbx
+ minDownloadSbx
+ tensorboardTcpTunnelCbx
+ detachRemoteTrainerCbx
+ downloadSamplesCbx
+ reattachIdLed
+ reattachBtn
+ downloadOutputModelCbx
+ downloadSavedCheckpointsCbx
+ downloadBackupsCbx
+ downloadTensorboardLogCbx
+ deleteRemoteWorkspaceCbx
+ onFinishCmb
+ onErrorCmb
+ onDetachedCmb
+ onDetachedErrorCmb
+
+
+
+
diff --git a/modules/ui/views/tabs/concepts.ui b/modules/ui/views/tabs/concepts.ui
new file mode 100644
index 000000000..c86ff4656
--- /dev/null
+++ b/modules/ui/views/tabs/concepts.ui
@@ -0,0 +1,119 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 997
+ 497
+
+
+
+ Form
+
+
+ -
+
+
+ -
+
+
+ Type
+
+
+ typeCmb
+
+
+
+ -
+
+
+ -
+
+
+ QAbstractItemView::DragDropMode::DragDrop
+
+
+ QAbstractItemView::SelectionMode::NoSelection
+
+
+ QListView::ResizeMode::Adjust
+
+
+ QListView::ViewMode::IconMode
+
+
+ true
+
+
+
+ -
+
+
+ Search
+
+
+ searchLed
+
+
+
+ -
+
+
+ Show Disabled
+
+
+ true
+
+
+
+ -
+
+
+ Add Concept
+
+
+
+ -
+
+
+ Disable
+
+
+
+ -
+
+
+ true
+
+
-
+
+ concepts
+
+
+
+
+ -
+
+
+ Clear Filters
+
+
+
+
+
+
+ presetCmb
+ addConceptBtn
+ toggleBtn
+ searchLed
+ typeCmb
+ showDisabledCbx
+ clearBtn
+ listWidget
+
+
+
+
diff --git a/modules/ui/views/tabs/data.ui b/modules/ui/views/tabs/data.ui
new file mode 100644
index 000000000..af0be3deb
--- /dev/null
+++ b/modules/ui/views/tabs/data.ui
@@ -0,0 +1,64 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 243
+ 143
+
+
+
+ Form
+
+
+ -
+
+
+ Aspect ratio bucketing enables training on images with different aspect ratios
+
+
+ Aspect Ratio Bucketing
+
+
+
+ -
+
+
+ Caching of intermediate training data that can be re-used between epochs
+
+
+ Latent Caching
+
+
+
+ -
+
+
+ Clears the cache directory before starting to train. Only disable this if you want to continue using the same cached data. Disabling this can lead to errors if other settings are changed during a restart
+
+
+ Clear Cache Before Training
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+
+
diff --git a/modules/ui/views/tabs/embeddings.ui b/modules/ui/views/tabs/embeddings.ui
new file mode 100644
index 000000000..720d73235
--- /dev/null
+++ b/modules/ui/views/tabs/embeddings.ui
@@ -0,0 +1,154 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 425
+ 215
+
+
+
+ Form
+
+
+ -
+
+
+ Base Embedding
+
+
+ baseEmbeddingLed
+
+
+
+ -
+
+
+ The base embedding to train on. Leave empty to create a new embedding
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Token Count
+
+
+ tokenSbx
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ The token count used when creating a new embedding. Leave empty to auto detect from the initial embedding text
+
+
+ 1
+
+
+ 75
+
+
+
+ -
+
+
+ Initial Embedding Text
+
+
+ initialEmbeddingLed
+
+
+
+ -
+
+
+ The initial embedding text used when creating a new embedding
+
+
+ *
+
+
+
+ -
+
+
+ Embedding Weight Data Type
+
+
+ embeddingDTypeCmb
+
+
+
+ -
+
+
+ The embedding weight data type used for training. This can reduce memory consumption, but reduces precision
+
+
+
+ -
+
+
+ Placeholder
+
+
+ placeholderLed
+
+
+
+ -
+
+
+ The placeholder used when using the embedding in a prompt
+
+
+ <embedding>
+
+
+
+ -
+
+
+ Output embeddings are calculated at the output of the text encoder, not the input. This can improve results for larger text encoders and lower VRAM usage
+
+
+ Output Embedding
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 4
+
+
+
+
+
+
+
+
+
diff --git a/modules/ui/views/tabs/general.ui b/modules/ui/views/tabs/general.ui
new file mode 100644
index 000000000..fd31dccd5
--- /dev/null
+++ b/modules/ui/views/tabs/general.ui
@@ -0,0 +1,433 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 949
+ 452
+
+
+
+ Form
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ false
+
+
+ true
+
+
+
+
+ 0
+ 0
+ 929
+ 432
+
+
+
+
-
+
+
+ Multi-GPU
+
+
+
+ -
+
+
+ Debug Directory
+
+
+ debugLed
+
+
+
+ -
+
+
+ The device used to temporarily offload models while they are not used. Default: "cpu"
+
+
+
+ -
+
+
+ -
+
+
+ 1048576
+
+
+ 100
+
+
+
+ -
+
+
+ The directory where debug data is saved
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
+
+
+ Temp Device
+
+
+ tempDeviceLed
+
+
+
+ -
+
+
+ -
+
+
+ Fused Gradient Reduce
+
+
+
+ -
+
+
+ Cache Directory
+
+
+ cacheLed
+
+
+
+ -
+
+
+ Dataloader Threads
+
+
+ dataloaderSbx
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Number of threads used for the data loader. Increase if your GPU has room during caching, decrease if it's going out of memory during caching
+
+
+ 1
+
+
+ 2
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
+ Save debug information during the training into the debug directory
+
+
+ Debug Mode
+
+
+
+ -
+
+
+ Validate After
+
+
+ validateSbx
+
+
+
+ -
+
+
+ Port to use for Tensorboard link
+
+
+ QAbstractSpinBox::ButtonSymbols::NoButtons
+
+
+ 1
+
+
+ 65535
+
+
+ 1
+
+
+ 6006
+
+
+
+ -
+
+
+ Enable validation steps and add a new graph in Tensorboard
+
+
+ Validation
+
+
+
+ -
+
+
+ The directory where all files of this training run are saved
+
+
+
+ -
+
+
+ Train Device
+
+
+ trainDeviceLed
+
+
+
+ -
+
+
+ Device Indexes
+
+
+ deviceIndexesLed
+
+
+
+ -
+
+
+ Automatically continues training from the last backup saved in <workspace>/backup
+
+
+ Continue From Last Backup
+
+
+
+ -
+
+
+ Tensorboard Port
+
+
+ tensorboardSbx
+
+
+
+ -
+
+
+ The device used for training. Can be "cuda", "cuda:0", "cuda:1" etc. Default: "cuda"
+
+
+ cuda
+
+
+
+ -
+
+
+ -
+
+
+ Only populate the cache, without any training
+
+
+ Only Cache
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Workspace Directory
+
+
+ workspaceLed
+
+
+
+ -
+
+
+ Starts the Tensorboard Web UI during training
+
+
+ Tensorboard
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ The interval used when validating during training
+
+
+
+ -
+
+
+ Buffer Size (MB)
+
+
+ bufferSbx
+
+
+
+ -
+
+
+ Async Gradient Reduce
+
+
+
+ -
+
+
+ Gradient Reduce Precision
+
+
+ gradientReduceCmb
+
+
+
+ -
+
+
+ Exposes Tensorboard Web UI to all network interfaces (makes it accessible from the network)
+
+
+ Expose Tensorboard
+
+
+
+ -
+
+
+ Always-On Tensorboard
+
+
+
+ -
+
+
+ The directory where cached data is saved
+
+
+
+ -
+
+
+ ...
+
+
+
+
+
+
+
+
+
+
+ scrollArea
+ workspaceLed
+ workspaceBtn
+ cacheLed
+ cacheBtn
+ continueCbx
+ onlyCacheCbx
+ debugCbx
+ debugLed
+ debugBtn
+ tensorboardCbx
+ alwaysOnTensorboardCbx
+ exposeTensorboardCbx
+ tensorboardSbx
+ validateCbx
+ validateSbx
+ validateCmb
+ dataloaderSbx
+ trainDeviceLed
+ multiGpuCbx
+ deviceIndexesLed
+ gradientReduceCmb
+ fusedGradientCbx
+ asyncGradientCbx
+ bufferSbx
+ tempDeviceLed
+
+
+
+
diff --git a/modules/ui/views/tabs/lora.ui b/modules/ui/views/tabs/lora.ui
new file mode 100644
index 000000000..c5e028d72
--- /dev/null
+++ b/modules/ui/views/tabs/lora.ui
@@ -0,0 +1,331 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 618
+ 301
+
+
+
+ Form
+
+
+ -
+
+
+ QFrame::Shape::NoFrame
+
+
+ QFrame::Shadow::Plain
+
+
+
-
+
+
+ Decompose LoRA Weights (aka, DoRA)
+
+
+ Decompose Weights (DoRA)
+
+
+ true
+
+
+
+ -
+
+
+ Add an epsilon to the norm division calculation in DoRA. Can aid in training stability, and also acts as regularization
+
+
+ Use Norm Epsilon (DoRA Only)
+
+
+
+ -
+
+
+ Apply the weight decomposition on the output axis instead of the input axis
+
+
+ Apply On Output Axis (DoRA Only)
+
+
+
+
+
+
+ -
+
+
+ QFrame::Shape::NoFrame
+
+
+ QFrame::Shadow::Raised
+
+
+
-
+
+
+ Rank
+
+
+ rankSbx
+
+
+
+ -
+
+
+ The rank parameter used when creating a new LoRA
+
+
+ 512
+
+
+
+ -
+
+
+ OFT Block Size
+
+
+ oftBlockSizeSbx
+
+
+
+ -
+
+
+ 512
+
+
+
+ -
+
+
+ Alpha
+
+
+ alphaSbx
+
+
+
+ -
+
+
+ The alpha parameter used when creating a new LoRA
+
+
+
+ -
+
+
+ Dropout Probability
+
+
+ DropoutSbx
+
+
+
+ -
+
+
+ Dropout probability. This percentage of model nodes will be randomly ignored at each training step. Helps with overfitting. 0 disables dropout, 1 is the maximum
+
+
+ 1.000000000000000
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ Weight Data Type
+
+
+ weightDTypeCmb
+
+
+
+ -
+
+
+ The LoRA weight data type used for training. This can reduce memory consumption, but reduces precision
+
+
+
+ -
+
+
+ Bundles any additional embeddings into the LoRA output file, rather than as separate files
+
+
+ Bundle Embeddings
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 5
+
+
+
+
+
+
+
+ -
+
+
+ The base LoRA to train on. Leave empty to create a new LoRA
+
+
+
+ -
+
+
+ Type
+
+
+ typeCmb
+
+
+
+ -
+
+
+ The type of low-parameter finetuning method
+
+
+ false
+
+
+
+
+
+ -1
+
+
+
+ -
+
+
+ Base Model
+
+
+ baseModelLed
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ QFrame::Shape::NoFrame
+
+
+ QFrame::Shadow::Raised
+
+
+
-
+
+
+ Share the OFT parameters between blocks. A single rotation matrix is shared across all blocks within a layer, drastically cutting the number of trainable parameters and yielding very compact adapter files, potentially improving generalization but at the cost of significant expressiveness, which can lead to underfitting on more complex or diverse tasks
+
+
+ Block Share
+
+
+
+ -
+
+
+ The control strength of COFT. Only has an effect if COFT is enabled
+
+
+
+ -
+
+
+ COFT Epsilon
+
+
+ coftLed
+
+
+
+ -
+
+
+ Use the constrained variant of OFT. This constrains the learned rotation to stay very close to the identity matrix, limiting adaptation to only small changes. This improves training stability, helps prevent overfitting on small datasets, and better preserves the base model's original knowledge, but it may lack expressiveness for tasks requiring substantial adaptation and introduces an additional hyperparameter (COFT Epsilon) that needs tuning
+
+
+ Constrained OFT (COFT)
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 27
+ 20
+
+
+
+
+
+
+
+ typeCmb
+ baseModelLed
+ baseModelBtn
+ decomposeCbx
+ normCbx
+ outputAxisCbx
+ coftCbx
+ coftLed
+ blockShareCbx
+
+
+
+
diff --git a/modules/ui/views/tabs/model.ui b/modules/ui/views/tabs/model.ui
new file mode 100644
index 000000000..2328f2da4
--- /dev/null
+++ b/modules/ui/views/tabs/model.ui
@@ -0,0 +1,595 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 765
+ 585
+
+
+
+ Form
+
+
+ -
+
+
+ true
+
+
+
+
+ 0
+ 0
+ 745
+ 565
+
+
+
+
-
+
+
+ Enter your Hugging Face access token if you have used a protected Hugging Face repository below.
+This value is stored separately, not saved to your configuration file. Go to https://huggingface.co/settings/tokens to create an access token
+
+
+
+ -
+
+
+ Include Config
+
+
+ configCmb
+
+
+
+ -
+
+
+ Output Format
+
+
+ outputFormatCmb
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Overrides the decoder vqgan weight data type
+
+
+
+ -
+
+
+ Base Model
+
+
+ baseModelLed
+
+
+
+ -
+
+
+ Override Text Encoder 2 Data Type
+
+
+ te2DTypeCmb
+
+
+
+ -
+
+
+ Overrides the effnet encoder weight data type
+
+
+
+ -
+
+
+ Directory or Hugging Face repository of a VAE model in diffusers format. Can be used to override the VAE included in the base model. Using a safetensors VAE file will cause an error, because such a model cannot be loaded
+
+
+
+ -
+
+
+ Filename, directory or Hugging Face repository of the text encoder 4 model
+
+
+
+ -
+
+
+ Override Text Encoder 1 Data Type
+
+
+ te1DTypeCmb
+
+
+
+ -
+
+
+ Huggingface Token
+
+
+ huggingfaceLed
+
+
+
+ -
+
+
+ Filename, directory or Hugging Face repository of the effnet encoder model
+
+
+
+ -
+
+
+ Override VAE Data Type
+
+
+ vaeDTypeCmb
+
+
+
+ -
+
+
+ Overrides the unet weight data type
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ VAE Override
+
+
+ vaeLed
+
+
+
+ -
+
+
+ Prior Model
+
+
+ priorLed
+
+
+
+ -
+
+
+ Override Decoder Text Encoder Data Type
+
+
+ decTeDTypeCmb
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Overrides the text encoder 3 weight data type
+
+
+
+ -
+
+
+ Filename, directory or Hugging Face repository of the prior model
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 10
+
+
+
+
+ -
+
+
+ Overrides the transformer weight data type
+
+
+
+ -
+
+
+ Override Text Encoder 3 Data Type
+
+
+ te3DTypeCmb
+
+
+
+ -
+
+
+ Override Transformer Data Type
+
+
+ transformerDTypeCmb
+
+
+
+ -
+
+
+ Overrides the text encoder 4 weight data type
+
+
+
+ -
+
+
+ Override UNet Data Type
+
+
+ unetDTypeCmb
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Override Effnet Encoder Data Type
+
+
+ effnetDTypeCmb
+
+
+
+ -
+
+
+ Include the training configuration in the final model. Only supported for safetensors files. None: No config is included. Settings: All training settings are included. All: All settings, including samples and concepts, are included
+
+
+
+ -
+
+
+ Override Decoder VQGAN Data Type
+
+
+ vqganDTypeCmb
+
+
+
+ -
+
+
+ Filename or directory where the output model is saved
+
+
+
+ -
+
+
+ Format to use when saving the output model
+
+
+
+ -
+
+
+ Override Transformer / GGUF
+
+
+ transformerLed
+
+
+
+ -
+
+
+ Can be used to override the transformer in the base model. Safetensors and GGUF files are supported, both locally and on Hugging Face. If a GGUF file is used, the data type must also be set to GGUF
+
+
+
+ -
+
+
+ Overrides the text encoder 2 weight data type
+
+
+
+ -
+
+
+ Weight Data Type
+
+
+ weightDTypeCmb
+
+
+
+ -
+
+
+ Overrides the decoder text encoder weight data type
+
+
+
+ -
+
+
+ Override Prior Data Type
+
+
+ priorDTypeCmb
+
+
+
+ -
+
+
+ Uses torch.compile and Triton to significantly speed up training. Only applies to transformer/unet. Disable in case of compatibility issues
+
+
+ Compile Transformer Blocks
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Overrides the prior weight data type
+
+
+
+ -
+
+
+ Overrides the vae weight data type
+
+
+
+ -
+
+
+ Precision to use when saving the output model
+
+
+
+ -
+
+
+ Overrides the text encoder weight data type
+
+
+
+ -
+
+
+ Output Data Type
+
+
+ outputDTypeCmb
+
+
+
+ -
+
+
+ The base model weight data type used for training. This can reduce memory consumption, but reduces precision
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
+
+
+ Override Text Encoder 4 Data Type
+
+
+ te4DTypeCmb
+
+
+
+ -
+
+
+ Filename, directory or Hugging Face repository of the base model
+
+
+
+ -
+
+
+ Text Encoder 4 Override
+
+
+ te4Led
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Model Output Destination
+
+
+ modelOutputLed
+
+
+
+ -
+
+
+ Effnet Encoder Model
+
+
+ effnetLed
+
+
+
+ -
+
+
+ Overrides the decoder weight data type
+
+
+
+ -
+
+
+ Override Decoder Data Type
+
+
+ decDTypeCmb
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Filename, directory or Hugging Face repository of the decoder model
+
+
+
+ -
+
+
+ Decoder Model
+
+
+ decLed
+
+
+
+
+
+
+
+
+
+
+ scrollArea
+ huggingfaceLed
+ baseModelLed
+ baseModelBtn
+ weightDTypeCmb
+ unetDTypeCmb
+ priorLed
+ priorBtn
+ priorDTypeCmb
+ transformerLed
+ transformerBtn
+ transformerDTypeCmb
+ compileTransformerCbx
+ te1DTypeCmb
+ te2DTypeCmb
+ te3DTypeCmb
+ te4Led
+ te4Btn
+ te4DTypeCmb
+ vaeLed
+ vaeBtn
+ vaeDTypeCmb
+ effnetLed
+ effnetBtn
+ effnetDTypeCmb
+ decLed
+ decBtn
+ decDTypeCmb
+ vqganDTypeCmb
+ decTeDTypeCmb
+ modelOutputLed
+ modelOutputBtn
+ outputDTypeCmb
+ outputFormatCmb
+ configCmb
+
+
+
+
diff --git a/modules/ui/views/tabs/sampling.ui b/modules/ui/views/tabs/sampling.ui
new file mode 100644
index 000000000..c7ff74921
--- /dev/null
+++ b/modules/ui/views/tabs/sampling.ui
@@ -0,0 +1,208 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 828
+ 300
+
+
+
+ Form
+
+
+ -
+
+
+ Enable
+
+
+
+ -
+
+
+ Add Sample
+
+
+
+ -
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ Start sampling automatically after this interval has elapsed
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
+
+
+ Sample After
+
+
+ sampleAfterSbx
+
+
+
+ -
+
+
+ Format
+
+
+ formatCmb
+
+
+
+ -
+
+
+ File Format used when saving samples
+
+
+
+ -
+
+
+ Manual Sample
+
+
+
+ -
+
+
+ Whether to also generate non-EMA samples when using EMA
+
+
+ Non-EMA Sampling
+
+
+
+ -
+
+
+ Sample Now
+
+
+
+ -
+
+
+ Skip First
+
+
+ skipSbx
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+ The interval used when automatically sampling from the model during training
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
+
+
+ -
+
+
+ Whether to include sample images in the Tensorboard output
+
+
+ Samples To Tensorboard
+
+
+
+ -
+
+
+ Preset
+
+
+
+ -
+
+
+ true
+
+
+
+
+
+
+
+
+
+ sampleAfterSbx
+ sampleAfterCmb
+ skipSbx
+ formatCmb
+ sampleNowBtn
+ manualSampleBtn
+ nonEmaCbx
+ tensorboardCbx
+ configCmb
+ addSampleBtn
+ toggleBtn
+ listWidget
+
+
+
+
diff --git a/modules/ui/views/tabs/tools.ui b/modules/ui/views/tabs/tools.ui
new file mode 100644
index 000000000..e9afc0f2b
--- /dev/null
+++ b/modules/ui/views/tabs/tools.ui
@@ -0,0 +1,95 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 400
+ 300
+
+
+
+ Form
+
+
+ -
+
+
+ Dataset Tools
+
+
+
+ -
+
+
+ Mask Tools
+
+
+
+ -
+
+
+ Caption Tools
+
+
+
+ -
+
+
+ Bulk Image Edit Tools
+
+
+
+ -
+
+
+ Bulk Caption Edit Tools
+
+
+
+ -
+
+
+ Video Tools
+
+
+
+ -
+
+
+ Convert Model Tools
+
+
+
+ -
+
+
+ Sampling Tools
+
+
+
+ -
+
+
+ Profiling Tools
+
+
+
+
+
+
+ datasetBtn
+ maskBtn
+ captionBtn
+ imageBtn
+ bulkCaptionBtn
+ videoBtn
+ convertBtn
+ samplingBtn
+ profilingBtn
+
+
+
+
diff --git a/modules/ui/views/tabs/training.ui b/modules/ui/views/tabs/training.ui
new file mode 100644
index 000000000..6277a1b52
--- /dev/null
+++ b/modules/ui/views/tabs/training.ui
@@ -0,0 +1,1916 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 937
+ 538
+
+
+
+ Form
+
+
+ -
+
+
+ true
+
+
+
+
+ 0
+ -648
+ 1294
+ 1660
+
+
+
+
-
+
+
+
-
+
+
+ Masked Training
+
+
+
-
+
+
+ Unmasked Weight
+
+
+ unmaskedWeightSbx
+
+
+
+ -
+
+
+ When masked training is enabled, specifies the loss weight of areas outside the masked region
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ Masked Prior Preservation Weight
+
+
+ maskedPriorPreservationSbx
+
+
+
+ -
+
+
+ Preserves regions outside the mask using the original untrained model output as a target. Only available for LoRA training. If enabled, use a low unmasked weight
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ Unmasked Probability
+
+
+ unmaskedProbabilitySbx
+
+
+
+ -
+
+
+ When masked training is enabled, specifies the probability of performing a training step on the full unmasked sample
+
+
+ 1.000000000000000
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ <html><head/><body><p>When custom conditioning image is enabled, a PNG with the -condlabel postfix is used instead of an automatically generated conditioning image. This is suitable for special scenarios, such as object removal, allowing the model to learn a specific behavior concept</p></body></html>
+
+
+ Custom Conditioning Image
+
+
+
+ -
+
+
+ When masked training is enabled, normalizes the loss for each sample based on the size of the masked region
+
+
+ Normalize Masked Area Loss
+
+
+
+ -
+
+
+ Masks the training samples to let the model focus on certain parts of the image. When enabled, one mask image is loaded for each training sample
+
+
+ Masked Training
+
+
+
+
+
+
+ -
+
+
+ Regularization
+
+
+
-
+
+
+ Selects the type of loss scaling to use during training. Functionally equivalent to: Loss * selection
+
+
+
+ -
+
+
+ Variational lower-bound strength for custom loss settings. Should be set to 1 for variational diffusion models
+
+
+
+ -
+
+
+ Mean Absolute Error strength for custom loss settings. Strengths should generally sum to 1
+
+
+
+ -
+
+
+ Inverse strength of loss weighting. Range: 1-20, only applies to Min SNR and P2
+
+
+
+ -
+
+
+ VB Strength
+
+
+ vbLossSbx
+
+
+
+ -
+
+
+ MSE Strength
+
+
+ mseSbx
+
+
+
+ -
+
+
+ Mean Squared Error strength for custom loss settings. Strengths should generally sum to 1
+
+
+
+ -
+
+
+ Log-Cosh Strength
+
+
+ logcoshSbx
+
+
+
+ -
+
+
+ Gamma
+
+
+ gammaSbx
+
+
+
+ -
+
+
+ Loss Weight Function
+
+
+ lossWeightFunctionCmb
+
+
+
+ -
+
+
+ Loss Scaler
+
+
+ lossScalerCmb
+
+
+
+ -
+
+
+ Log-cosh (logarithm of the hyperbolic cosine) error strength for custom loss settings
+
+
+
+ -
+
+
+ Huber Strength
+
+
+ huberStrengthSbx
+
+
+
+ -
+
+
+ Choice of loss weight function. Can help the model learn details more accurately
+
+
+
+ -
+
+
+ MAE Strength
+
+
+ maeSbx
+
+
+
+ -
+
+
+ Huber Delta
+
+
+ huberDeltaSbx
+
+
+
+ -
+
+
+ Huber loss strength for custom loss settings. Less sensitive to outliers than MSE. Strengths should generally sum to 1
+
+
+
+ -
+
+
+ Delta parameter for huber loss
+
+
+
+
+
+
+ -
+
+
+ Text Encoder 1 Settings
+
+
+
-
+
+
+ Enables training the text encoder 1 model
+
+
+ Train Text Encoder 1
+
+
+
+ -
+
+
+ Enables training embeddings for the text encoder 1 model
+
+
+ Train Text Encoder 1 Embedding
+
+
+
+ -
+
+
+ When to stop training the text encoder 1
+
+
+ 999999999
+
+
+
+ -
+
+
+ Text Encoder 1 Learning Rate
+
+
+ te1LearningRateLed
+
+
+
+ -
+
+
+ Dropout Probability
+
+
+ te1DropoutSbx
+
+
+
+ -
+
+
+ Text Encoder 1 Clip Skip
+
+
+ te1ClipSkipSbx
+
+
+
+ -
+
+
+ -
+
+
+ Stop Training After
+
+
+ te1StopTrainingSbx
+
+
+
+ -
+
+
+ The number of additional clip layers to skip. 0 = the model default
+
+
+
+ -
+
+
+ The probability of dropping the text encoder 1 conditioning
+
+
+ 1.000000000000000
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ Includes text encoder 1 in the training run
+
+
+ Include Text Encoder 1
+
+
+
+ -
+
+
+ The learning rate of the text encoder 1. Overrides the base learning rate
+
+
+
+
+
+
+ -
+
+
+ Text Encoder 2 Settings
+
+
+
-
+
+
+ Includes text encoder 2 in the training run
+
+
+ Include Text Encoder 2
+
+
+
+ -
+
+
+ Stop Training After
+
+
+ te2StopTrainingSbx
+
+
+
+ -
+
+
+ Dropout Probability
+
+
+ te2DropoutSbx
+
+
+
+ -
+
+
+ The learning rate of the text encoder 2. Overrides the base learning rate
+
+
+
+ -
+
+
+ The number of additional clip layers to skip. 0 = the model default
+
+
+
+ -
+
+
+ -
+
+
+ Text Encoder 2 Clip Skip
+
+
+ te2ClipSkipSbx
+
+
+
+ -
+
+
+ Enables training embeddings for the text encoder 2 model
+
+
+ Train Text Encoder 2 Embedding
+
+
+
+ -
+
+
+ Enables training the text encoder 2 model
+
+
+ Train Text Encoder 2
+
+
+
+ -
+
+
+ Text Encoder 2 Learning Rate
+
+
+ te2LearningRateLed
+
+
+
+ -
+
+
+ When to stop training the text encoder 2
+
+
+ 999999999
+
+
+
+ -
+
+
+ The probability of dropping the text encoder 2 conditioning
+
+
+ 1.000000000000000
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ Text Encoder 2 Sequence Length
+
+
+ te2SeqLenSbx
+
+
+
+ -
+
+
+ 512
+
+
+
+
+
+
+ -
+
+
+ Text Encoder 3 Settings
+
+
+
-
+
+
+ Enables training the text encoder 3 model
+
+
+ Train Text Encoder 3
+
+
+
+ -
+
+
+ Includes text encoder 3 in the training run
+
+
+ Include Text Encoder 3
+
+
+
+ -
+
+
+ Enables training embeddings for the text encoder 3 model
+
+
+ Train Text Encoder 3 Embedding
+
+
+
+ -
+
+
+ Stop Training After
+
+
+ te3StopTrainingSbx
+
+
+
+ -
+
+
+ Text Encoder 3 Learning Rate
+
+
+ te3LearningRateLed
+
+
+
+ -
+
+
+ The probability of dropping the text encoder 3 conditioning
+
+
+ 1.000000000000000
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ Dropout Probability
+
+
+ te3DropoutSbx
+
+
+
+ -
+
+
+ When to stop training the text encoder 3
+
+
+ 999999999
+
+
+
+ -
+
+
+ Text Encoder 3 Clip Skip
+
+
+ te3ClipSkipSbx
+
+
+
+ -
+
+
+ -
+
+
+ The number of additional clip layers to skip. 0 = the model default
+
+
+
+ -
+
+
+ The learning rate of the text encoder 3. Overrides the base learning rate
+
+
+
+
+
+
+ -
+
+
+ Text Encoder 4 Settings
+
+
+
-
+
+
+ Text Encoder 4 Clip Skip
+
+
+ te4ClipSkipSbx
+
+
+
+ -
+
+
+ Stop Training After
+
+
+ te4StopTrainingSbx
+
+
+
+ -
+
+
+ Text Encoder 4 Learning Rate
+
+
+ te4LearningRateLed
+
+
+
+ -
+
+
+ The probability of dropping the text encoder 4 conditioning
+
+
+ 1.000000000000000
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+ Includes text encoder 4 in the training run
+
+
+ Include Text Encoder 4
+
+
+
+ -
+
+
+ The number of additional clip layers to skip. 0 = the model default
+
+
+
+ -
+
+
+ Enables training embeddings for the text encoder 4 model
+
+
+ Train Text Encoder 4 Embedding
+
+
+
+ -
+
+
+ -
+
+
+ When to stop training the text encoder 4
+
+
+ 999999999
+
+
+
+ -
+
+
+ Dropout Probability
+
+
+ te4DropoutSbx
+
+
+
+ -
+
+
+ Enables training the text encoder 4 model
+
+
+ Train Text Encoder 4
+
+
+
+ -
+
+
+ The learning rate of the text encoder 4. Overrides the base learning rate
+
+
+
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+ -
+
+
+
-
+
+
+ UNet Settings
+
+
+
-
+
+
+ UNet Learning Rate
+
+
+ unetLearningRateLed
+
+
+
+ -
+
+
+ When to stop training the UNet
+
+
+ 999999999
+
+
+
+ -
+
+
+ Stop Training After
+
+
+ unetStopSbx
+
+
+
+ -
+
+
+ -
+
+
+ Rescales the noise scheduler to a zero terminal signal-to-noise ratio and switches the model to a v-prediction target
+
+
+ Rescale Noise Scheduler + V-pred
+
+
+
+ -
+
+
+ Enables training the UNet model
+
+
+ Train UNet
+
+
+
+ -
+
+
+ The learning rate of the UNet. Overrides the base learning rate
+
+
+
+
+
+
+ -
+
+
+ Transformer Settings
+
+
+
-
+
+
+ Enables training the Transformer model
+
+
+ Train Transformer
+
+
+
+ -
+
+
+ Stop Training After
+
+
+ transformerStopSbx
+
+
+
+ -
+
+
+ When to stop training the Transformer
+
+
+ 999999999
+
+
+
+ -
+
+
+ -
+
+
+ Guidance Scale
+
+
+ transformerGuidanceSbx
+
+
+
+ -
+
+
+ Transformer Learning Rate
+
+
+ transformerLearningRateLed
+
+
+
+ -
+
+
+ Force-enables passing a text embedding attention mask to the transformer. This can improve training on shorter captions
+
+
+ Force Attention Mask
+
+
+
+ -
+
+
+ The guidance scale of guidance distilled models passed to the transformer during training
+
+
+
+ -
+
+
+ The learning rate of the Transformer. Overrides the base learning rate
+
+
+
+
+
+
+ -
+
+
+ Prior Model Settings
+
+
+
-
+
+
+ Enables training the Prior model
+
+
+ Train Prior
+
+
+
+ -
+
+
+ Stop Training After
+
+
+ priorStopSbx
+
+
+
+ -
+
+
+ When to stop training the Prior
+
+
+ 999999999
+
+
+
+ -
+
+
+ -
+
+
+ Prior Learning Rate
+
+
+ priorLearningRateLed
+
+
+
+ -
+
+
+ The learning rate of the Prior. Overrides the base learning rate
+
+
+
+
+
+
+ -
+
+
+ Embeddings
+
+
+
-
+
+
+ The learning rate of embeddings. Overrides the base learning rate
+
+
+
+ -
+
+
+ Rescales each trained embedding to the median embedding norm
+
+
+ Preserve Embedding Norm
+
+
+
+ -
+
+
+ Embeddings Learning Rate
+
+
+ embeddingLearningRateLed
+
+
+
+
+
+
+ -
+
+
+ Trainable Layers
+
+
+
-
+
+
+ Layer Filter
+
+
+ layerFilterCmb
+
+
+
+ -
+
+
+ -
+
+
+ Select a preset defining which layers to train, or select 'Custom' to define your own.
+A blank 'custom' field or 'Full' will train all layers
+
+
+
+ -
+
+
+ Use Regex
+
+
+
+
+
+
+ -
+
+
+ Noise
+
+
+
-
+
+
+ Timestep Distribution
+
+
+ timestepDistributionCmb
+
+
+
+ -
+
+
+ Max Noising Strength
+
+
+ maxNoisingStrengthSbx
+
+
+
+ -
+
+
+ Perturbation Noise Weight
+
+
+ perturbationNoiseWeightSbx
+
+
+
+ -
+
+
+ Selects the function to sample timesteps during training
+
+
+
+ -
+
+
+ Specifies the maximum noising strength used during training. This can be useful to reduce overfitting, but also reduces the impact of training samples on the overall image composition
+
+
+
+ -
+
+
+ Min Noising Strength
+
+
+ minNoisingStrengthSbx
+
+
+
+ -
+
+
+ Noising Weight
+
+
+ noisingWeightSbx
+
+
+
+ -
+
+
+ Generalized Offset Noise
+
+
+
+ -
+
+
+ Timestep Shift
+
+
+ timestepShiftSbx
+
+
+
+ -
+
+
+ Controls the bias parameter of the timestep distribution function. Use the preview to see more details
+
+
+
+ -
+
+
+ Controls the weight parameter of the timestep distribution function. Use the preview to see more details
+
+
+
+ -
+
+
+ The guidance scale of guidance distilled models passed to the transformer during training
+
+
+
+ -
+
+
+ Noising Bias
+
+
+ noisingBiasSbx
+
+
+
+ -
+
+
+ Update Preview
+
+
+
+ -
+
+
+ Specifies the minimum noising strength used during training. This can help to improve composition, but prevents finer details from being trained
+
+
+
+ -
+
+
+ The weight of perturbation noise added to each training step
+
+
+
+ -
+
+
+ <html><head/><body><p>Shift the timestep distribution</p></body></html>
+
+
+
+ -
+
+
+ Offset Noise Weight
+
+
+ offsetNoiseWeightSbx
+
+
+
+ -
+
+
+ Dynamically shift the timestep distribution based on resolution. For the preview, a random resolution between 512 and 1024 is used, assuming a VAE scale factor of 8. During training, the actual resolution is used
+
+
+ Dynamic Timestep Shifting
+
+
+
+ -
+
+
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+ -
+
+
+
-
+
+
+ Optimizer
+
+
+
-
+
+
+ true
+
+
+ Key-value pairs for an argument in your scheduler. Some special values can be used, wrapped in percent signs: LR, EPOCHS, STEPS_PER_EPOCH, TOTAL_STEPS, SCHEDULER_STEPS. Note that OneTrainer calls step() after every individual learning step, not every epoch, so what Torch calls 'epoch' you should treat as 'step'
+
+
+ QAbstractItemView::EditTrigger::AnyKeyPressed|QAbstractItemView::EditTrigger::EditKeyPressed|QAbstractItemView::EditTrigger::SelectedClicked
+
+
+ QAbstractItemView::SelectionMode::NoSelection
+
+
+ true
+
+
+ 1
+
+
+ 2
+
+
+ true
+
+
+ false
+
+
+ false
+
+
+
+
+ Parameter
+
+
+
+
+ Value
+
+
+
+
+ -
+
+
+ Selects the type of learning rate scaling to use during training. Functionally equivalent to: LR * SQRT(selection)
+
+
+
+ -
+
+
+ Learning Rate Cycles
+
+
+ cyclesSbx
+
+
+
+ -
+
+
+ Learning Rate Scheduler
+
+
+ schedulerCmb
+
+
+
+ -
+
+
+ Learning Rate Warmup Steps
+
+
+ warmupStepsSbx
+
+
+
+ -
+
+
+ The number of learning rate cycles. This is only applicable if the learning rate scheduler supports cycles
+
+
+ 999999999
+
+
+
+ -
+
+
+ true
+
+
+ Class Name
+
+
+ schedulerClassLed
+
+
+
+ -
+
+
+ Clips the gradient norm. Leave empty to disable gradient clipping
+
+
+ 999999999.000000000000000
+
+
+
+ -
+
+
+ Learning Rate Min Factor
+
+
+ minFactorSbx
+
+
+
+ -
+
+
+ Local Batch Size
+
+
+ batchSizeSbx
+
+
+
+ -
+
+
+ The number of steps it takes to gradually increase the learning rate from 0 to the specified learning rate. Values >1 are interpreted as a fixed number of steps, values <=1 are interpreted as a percentage of the total training steps (ex. 0.2 = 20% of the total step count)
+
+
+ 999999999.000000000000000
+
+
+
+ -
+
+
+ Unit = float. Method = percentage. For a factor of 0.1, the final LR will be 10% of the initial LR. If the initial LR is 1e-4, the final LR will be 1e-5
+
+
+ 999999999.000000000000000
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Clip Grad Norm
+
+
+ clipGradNormSbx
+
+
+
+ -
+
+
+ Epochs
+
+
+ epochsSbx
+
+
+
+ -
+
+
+ Python class module and name for the custom scheduler class, in the form of <module>.<class_name>
+
+
+
+ -
+
+
+ Number of accumulation steps. Increase this number to trade batch size for training speed
+
+
+ 999999999
+
+
+
+ -
+
+
+ The number of epochs for a full training run
+
+
+ 999999999
+
+
+
+ -
+
+
+ Optimizer
+
+
+ optimizerCmb
+
+
+
+ -
+
+
+ Learning Rate
+
+
+
+ -
+
+
+ Learning Rate Scaler
+
+
+ scalerCmb
+
+
+
+ -
+
+
+ The base learning rate
+
+
+
+ -
+
+
+ Learning rate scheduler that automatically changes the learning rate during training
+
+
+
+ -
+
+
+ Accumulation Steps
+
+
+ accumulationStepsSbx
+
+
+
+ -
+
+
+ The batch size of one training step
+
+
+ 999999999
+
+
+
+ -
+
+
+ The type of optimizer
+
+
+
+
+
+
+ -
+
+
+ General Settings
+
+
+
-
+
+
+ Enables circular padding for all conv layers to better train seamless images
+
+
+ Force Circular Padding
+
+
+
+ -
+
+
+ Number of steps between EMA update steps
+
+
+ 999999999
+
+
+
+ -
+
+
+ Enables Asynchronous offloading
+
+
+ Async Offloading
+
+
+
+ -
+
+
+ Enables Activation Offloading
+
+
+ Offload Activations
+
+
+
+ -
+
+
+ Gradient Checkpointing
+
+
+ gradientCheckpointingCmb
+
+
+
+ -
+
+
+ EMA Update Step Interval
+
+
+ emaUpdateIntervalSbx
+
+
+
+ -
+
+
+ Frames
+
+
+ framesSbx
+
+
+
+ -
+
+
+ Fallback Train Data Type
+
+
+ fallbackDTypeCmb
+
+
+
+ -
+
+
+ EMA averages the training progress over many steps, better preserving different concepts in big datasets
+
+
+
+ -
+
+
+ Resolution
+
+
+ resolutionLed
+
+
+
+ -
+
+
+ EMA
+
+
+ emaCmb
+
+
+
+ -
+
+
+ Decay parameter of the EMA model. Higher numbers will average more steps. For datasets of hundreds or thousands of images, set this to 0.9999. For smaller datasets, set it to 0.999 or even 0.998
+
+
+ 999999999.000000000000000
+
+
+
+ -
+
+
+ The number of frames used for training
+
+
+ 999
+
+
+
+ -
+
+
+ Train Data Type
+
+
+ trainDTypeCmb
+
+
+
+ -
+
+
+ The mixed precision data type used for training. This can increase training speed, but reduces precision
+
+
+
+ -
+
+
+ The resolution used for training. Optionally specify multiple resolutions separated by a comma, or a single exact resolution in the format <width>x<height>
+
+
+
+ -
+
+
+ Layer Offload Fraction
+
+
+ layerOffloadFractionSbx
+
+
+
+ -
+
+
+ Enables the autocast cache. Disabling this reduces memory usage, but increases training time
+
+
+ Autocast Cache
+
+
+
+ -
+
+
+ Enables gradient checkpointing. This reduces memory usage, but increases training time
+
+
+
+ -
+
+
+ EMA Decay
+
+
+ emaDecaySbx
+
+
+
+ -
+
+
+ Enables offloading of individual layers during training to reduce VRAM usage. Increases training time and uses more RAM. Only available if checkpointing is set to CPU_OFFLOADED. Values between 0 and 1; 0 = disabled
+
+
+ 1.000000000000000
+
+
+ 0.100000000000000
+
+
+
+ -
+
+
+ The mixed precision data type used for training stages that don't support float16 data types. This can increase training speed, but reduces precision
+
+
+
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ scrollArea
+ optimizerCmb
+ optimizerBtn
+ schedulerCmb
+ schedulerClassLed
+ tableWidget
+ warmupStepsSbx
+ minFactorSbx
+ cyclesSbx
+ epochsSbx
+ batchSizeSbx
+ accumulationStepsSbx
+ scalerCmb
+ clipGradNormSbx
+ emaCmb
+ emaDecaySbx
+ emaUpdateIntervalSbx
+ gradientCheckpointingCmb
+ asyncOffloadCbx
+ offloadActivationsCbx
+ layerOffloadFractionSbx
+ trainDTypeCmb
+ fallbackDTypeCmb
+ autocastCacheCbx
+ resolutionLed
+ framesSbx
+ circularPaddingCbx
+ maskedTrainingCbx
+ unmaskedProbabilitySbx
+ unmaskedWeightSbx
+ normalizeMaskedAreaCbx
+ maskedPriorPreservationSbx
+ customConditioningImageCbx
+ mseSbx
+ maeSbx
+ huberStrengthSbx
+ huberDeltaSbx
+ logcoshSbx
+ vbLossSbx
+ lossWeightFunctionCmb
+ gammaSbx
+ lossScalerCmb
+ te1IncludeCbx
+ te1TrainCbx
+ te1TrainEmbCbx
+ te1DropoutSbx
+ te1StopTrainingSbx
+ te1StopTrainingCmb
+ te1ClipSkipSbx
+ te2IncludeCbx
+ te2TrainCbx
+ te2TrainEmbCbx
+ te2DropoutSbx
+ te2StopTrainingSbx
+ te2StopTrainingCmb
+ te2ClipSkipSbx
+ te2SeqLenSbx
+ te3IncludeCbx
+ te3TrainCbx
+ te3TrainEmbCbx
+ te3DropoutSbx
+ te3StopTrainingSbx
+ te3StopTrainingCmb
+ te3ClipSkipSbx
+ te4IncludeCbx
+ te4TrainCbx
+ te4TrainEmbCbx
+ te4DropoutSbx
+ te4StopTrainingSbx
+ te4StopTrainingCmb
+ te4ClipSkipSbx
+ unetTrainCbx
+ unetStopSbx
+ unetStopCmb
+ unetRescaleCbx
+ transformerTrainCbx
+ transformerStopSbx
+ transformerStopCmb
+ transformerAttnMaskCbx
+ transformerGuidanceSbx
+ priorTrainCbx
+ priorStopSbx
+ priorStopCmb
+ embeddingNormCbx
+ layerFilterCmb
+ layerFilterLed
+ layerFilterRegexCbx
+ offsetNoiseWeightSbx
+ generalizedOffsetNoiseCbx
+ perturbationNoiseWeightSbx
+ timestepDistributionCmb
+ minNoisingStrengthSbx
+ maxNoisingStrengthSbx
+ noisingWeightSbx
+ noisingBiasSbx
+ timestepShiftSbx
+ dynamicTimestepShiftingCbx
+ updatePreviewBtn
+
+
+
+
diff --git a/modules/ui/views/widgets/concept.ui b/modules/ui/views/widgets/concept.ui
new file mode 100644
index 000000000..40fda7ae4
--- /dev/null
+++ b/modules/ui/views/widgets/concept.ui
@@ -0,0 +1,79 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 144
+ 192
+
+
+
+
+ 0
+ 0
+
+
+
+ Form
+
+
+ -
+
+
+
+
+
+ true
+
+
+
+ -
+
+
+
+ 0
+ 0
+
+
+
+
+ 128
+ 128
+
+
+
+
+
+
+
+ 128
+ 128
+
+
+
+ true
+
+
+
+ -
+
+
+ X
+
+
+
+ -
+
+
+ +
+
+
+
+
+
+
+
+
diff --git a/modules/ui/views/widgets/embedding.ui b/modules/ui/views/widgets/embedding.ui
new file mode 100644
index 000000000..c30d5f78a
--- /dev/null
+++ b/modules/ui/views/widgets/embedding.ui
@@ -0,0 +1,184 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 1024
+ 78
+
+
+
+
+ 0
+ 0
+
+
+
+ Form
+
+
+ -
+
+
+ Stop Training After
+
+
+ stopTrainingSbx
+
+
+
+ -
+
+
+ Train this embedding
+
+
+ Train
+
+
+ true
+
+
+
+ -
+
+
+ When to stop training the embedding
+
+
+ 999
+
+
+
+ -
+
+
+ Base Embedding
+
+
+ baseEmbeddingLed
+
+
+
+ -
+
+
+ +
+
+
+
+ -
+
+
+ The base embedding to train on. Leave empty to create a new embedding
+
+
+
+ -
+
+
+ The placeholder used when using the embedding in a prompt
+
+
+ <embedding>
+
+
+
+ -
+
+
+ Initial Embedding Text
+
+
+ initialEmbeddingLed
+
+
+
+ -
+
+
+ Placeholder
+
+
+ placeholderLed
+
+
+
+ -
+
+
+ Output embeddings are calculated at the output of the text encoder, not the input. This can improve results for larger text encoders and lower VRAM usage
+
+
+ Output Embedding
+
+
+
+ -
+
+
+ -
+
+
+ The token count used when creating a new embedding. Leave empty to auto detect from the initial embedding text
+
+
+ 1
+
+
+ 75
+
+
+
+ -
+
+
+ X
+
+
+
+ -
+
+
+ Token Count
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ The initial embedding text used when creating a new embedding
+
+
+ *
+
+
+
+
+
+
+ deleteBtn
+ cloneBtn
+ baseEmbeddingLed
+ baseEmbeddingBtn
+ placeholderLed
+ tokenSbx
+ trainCbx
+ outputEmbeddingCbx
+ stopTrainingSbx
+ stopTrainingCmb
+ initialEmbeddingLed
+
+
+
+
diff --git a/modules/ui/views/widgets/sample.ui b/modules/ui/views/widgets/sample.ui
new file mode 100644
index 000000000..41fd47fbb
--- /dev/null
+++ b/modules/ui/views/widgets/sample.ui
@@ -0,0 +1,134 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 648
+ 47
+
+
+
+ Form
+
+
+ -
+
+
+ X
+
+
+
+ -
+
+
+ +
+
+
+
+ -
+
+
+
+
+
+ true
+
+
+
+ -
+
+
+ Width
+
+
+ widthSbx
+
+
+
+ -
+
+
+ QAbstractSpinBox::ButtonSymbols::NoButtons
+
+
+ 8096
+
+
+ 512
+
+
+
+ -
+
+
+ Height
+
+
+ heightSbx
+
+
+
+ -
+
+
+ QAbstractSpinBox::ButtonSymbols::NoButtons
+
+
+ 8096
+
+
+ 512
+
+
+
+ -
+
+
+ Seed
+
+
+ seedLed
+
+
+
+ -
+
+
+ -
+
+
+ Prompt
+
+
+ promptLed
+
+
+
+ -
+
+
+ -
+
+
+ ...
+
+
+
+
+
+
+ deleteBtn
+ cloneBtn
+ enabledCbx
+ widthSbx
+ heightSbx
+ seedLed
+ promptLed
+ openWindowBtn
+
+
+
+
diff --git a/modules/ui/views/widgets/sampling_params.ui b/modules/ui/views/widgets/sampling_params.ui
new file mode 100644
index 000000000..4a2c803de
--- /dev/null
+++ b/modules/ui/views/widgets/sampling_params.ui
@@ -0,0 +1,269 @@
+
+
+ Form
+
+
+
+ 0
+ 0
+ 496
+ 302
+
+
+
+ Form
+
+
+ -
+
+
+ Prompt
+
+
+ promptLed
+
+
+
+ -
+
+
+ -
+
+
+ Negative Prompt
+
+
+ negativePromptLed
+
+
+
+ -
+
+
+ -
+
+
+ Width
+
+
+ widthSbx
+
+
+
+ -
+
+
+ 8096
+
+
+ 512
+
+
+
+ -
+
+
+ Height
+
+
+ heightSbx
+
+
+
+ -
+
+
+ 8096
+
+
+ 512
+
+
+
+ -
+
+
+ Frames
+
+
+ framesSbx
+
+
+
+ -
+
+
+ Number of frames to generate. Only used when generating videos
+
+
+
+ -
+
+
+ Length
+
+
+ lengthSbx
+
+
+
+ -
+
+
+ Length in seconds of audio output
+
+
+
+ -
+
+
+ Seed
+
+
+ seedLed
+
+
+
+ -
+
+
+ Random Seed
+
+
+
+ -
+
+
+ CFG Scale
+
+
+ cfgSbx
+
+
+
+ -
+
+
+ -
+
+
+ Steps
+
+
+ stepsSbx
+
+
+
+ -
+
+
+ -
+
+
+ Sampler
+
+
+ samplerCmb
+
+
+
+ -
+
+
+ -
+
+
+ Enables inpainting sampling. Only available when sampling from an inpainting model
+
+
+ Inpainting
+
+
+
+ -
+
+
+ Base Image Path
+
+
+ imagePathLed
+
+
+
+ -
+
+
+ The base image used when inpainting
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 0
+ 20
+
+
+
+
+ -
+
+
+ Mask Image Path
+
+
+ maskPathLed
+
+
+
+ -
+
+
+ The mask used when inpainting
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+
+
+
+ promptLed
+ negativePromptLed
+ widthSbx
+ heightSbx
+ framesSbx
+ lengthSbx
+ seedLed
+ randomSeedCbx
+ cfgSbx
+ stepsSbx
+ samplerCmb
+ inpaintingCbx
+ imagePathLed
+ imagePathBtn
+ maskPathLed
+ maskPathBtn
+
+
+
+
diff --git a/modules/ui/views/windows/bulk_caption.ui b/modules/ui/views/windows/bulk_caption.ui
new file mode 100644
index 000000000..b58a7c250
--- /dev/null
+++ b/modules/ui/views/windows/bulk_caption.ui
@@ -0,0 +1,171 @@
+
+
+ Dialog
+
+
+ Qt::WindowModality::WindowModal
+
+
+
+ 0
+ 0
+ 531
+ 578
+
+
+
+
+ 0
+ 0
+
+
+
+ Bulk Caption Edit Tools
+
+
+ false
+
+
+ true
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ -
+
+
+ Dataset Directory
+
+
+ directoryLed
+
+
+
+ -
+
+
+ QPlainTextEdit::LineWrapMode::NoWrap
+
+
+ true
+
+
+
+ -
+
+
+ Preview
+
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+ -
+
+
+ Remove Text
+
+
+
+ -
+
+
+ Regex Replace
+
+
+
+ -
+
+
+ Preview (First 10 captions)
+
+
+
+ -
+
+
+ -
+
+
+ With
+
+
+
+ -
+
+
+ With
+
+
+
+ -
+
+
+ Add Text
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Apply Changes
+
+
+
+ -
+
+
+ -
+
+
+ Replace
+
+
+
+ -
+
+
+ -
+
+
+
+
+
+ directoryLed
+ directoryBtn
+ addLed
+ addCmb
+ removeLed
+ replaceLed
+ replaceWithLed
+ regexLed
+ regexWithLed
+ previewBtn
+ applyBtn
+ previewTed
+
+
+
+
diff --git a/modules/ui/views/windows/bulk_image.ui b/modules/ui/views/windows/bulk_image.ui
new file mode 100644
index 000000000..a01da5751
--- /dev/null
+++ b/modules/ui/views/windows/bulk_image.ui
@@ -0,0 +1,185 @@
+
+
+ Dialog
+
+
+ Qt::WindowModality::WindowModal
+
+
+
+ 0
+ 0
+ 442
+ 519
+
+
+
+
+ 0
+ 0
+
+
+
+ Bulk Image Edit Tools
+
+
+ false
+
+
+ true
+
+
+ -
+
+
-
+
+
+ Process Files
+
+
+
+ -
+
+
+ Cancel
+
+
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Verify Images for Corruption
+
+
+
+ -
+
+
+ Image Optimization Type
+
+
+ optimizationCmb
+
+
+
+ -
+
+
+ Select the megapixel threshold for resizing
+
+
+
+ -
+
+
+ <html><head/><body><p>Select the type of image optimization to apply.</p><p>None: No image optimization will be applied,</p><p>Optimize PNGs: Optimize PNGs using PyOxiPNG (level 5, fix_errors=True),</p><p>Convert to WebP: Re-encode all images to WebP format at 90% quality,</p><p>Convert to JPEG XL: Encode images as JPEG XL at 90% quality or losslessly for JPEGs</p></body></html>
+
+
+
+ -
+
+
+ 24
+
+
+
+ -
+
+
+ Replace Transparency with Color
+
+
+
+ -
+
+
+ 1
+
+
+
+ -
+
+
+ Custom Megapixels
+
+
+ customSbx
+
+
+
+ -
+
+
+ Directory
+
+
+ directoryLed
+
+
+
+ -
+
+
+ Enter color name (e.g., 'white', 'black'), hex code (e.g., '#FFFFFF'), or 'random'/-1 for random color
+
+
+
+ -
+
+
+ Resize Images Above
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Sequential Renaming (1.txt, 2.txt, etc.)
+
+
+
+ -
+
+
+ -
+
+
+ true
+
+
+
+
+
+
+ directoryLed
+ directoryBtn
+ verifyCbx
+ renameCbx
+ replaceColorCbx
+ colorLed
+ resizeCbx
+ resizeCmb
+ customSbx
+ optimizationCmb
+ processBtn
+ cancelBtn
+
+
+
+
diff --git a/modules/ui/views/windows/concept.ui b/modules/ui/views/windows/concept.ui
new file mode 100644
index 000000000..f2e01170a
--- /dev/null
+++ b/modules/ui/views/windows/concept.ui
@@ -0,0 +1,1380 @@
+
+
+ Dialog
+
+
+ Qt::WindowModality::ApplicationModal
+
+
+
+ 0
+ 0
+ 841
+ 730
+
+
+
+
+ 0
+ 0
+
+
+
+ Edit Concept
+
+
+ false
+
+
+ true
+
+
+ -
+
+
+ Ok
+
+
+
+ -
+
+
+ 0
+
+
+
+ General
+
+
+
-
+
+
+ Name of the concept
+
+
+
+ -
+
+
+ true
+
+
+
+
+
+
-
+
+
+ Concept Type
+
+
+ conceptTypeCmb
+
+
+
+ -
+
+
+ STANDARD: Standard finetuning with the sample as training target
+VALIDATION: Use concept for validation instead of training
+PRIOR_PREDICTION: Use the sample to make a prediction with the model as it was before training. This prediction is then used as the training target. This can be used as regularisation and to preserve prior model knowledge while finetuning the model on other concepts. Only implemented for LoRA
+
+
+
+ -
+
+
+ Path
+
+
+ pathLed
+
+
+
+ -
+
+
+ Path where the training data is located
+
+
+
+ -
+
+
+ ...
+
+
+
+ -
+
+
+ Download Now
+
+
+
+ -
+
+
+ Prompt Source
+
+
+ promptSourceCmb
+
+
+
+ -
+
+
+ The source for prompts used during training. When selecting "From single text file", select a text file that contains a list of prompts
+
+
+
+ -
+
+
+ false
+
+
+
+ -
+
+
+ false
+
+
+ ...
+
+
+
+ -
+
+
+ Includes images from subdirectories into the dataset
+
+
+ Include Subdirectories
+
+
+
+ -
+
+
+ Image Variations
+
+
+ imageVariationsSbx
+
+
+
+ -
+
+
+ The number of different image versions to cache if latent caching is enabled
+
+
+
+ -
+
+
+ Text Variations
+
+
+ textVariationsSbx
+
+
+
+ -
+
+
+ The number of different text versions to cache if latent caching is enabled
+
+
+
+ -
+
+
+ Balancing
+
+
+ balancingSbx
+
+
+
+ -
+
+
+ The number of samples used during training. Use repeats to multiply the concept, or samples to specify an exact number of samples used in each epoch
+
+
+
+ -
+
+
+ -
+
+
+ Loss Weight
+
+
+ lossWeightSbx
+
+
+
+ -
+
+
+ The loss multiplier for this concept
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+
+
+
+ -
+
+
+ Name
+
+
+ nameLed
+
+
+
+ -
+
+
+ true
+
+
+ Enable or disable this concept
+
+
+ Enabled
+
+
+ true
+
+
+
+
+
+
+
+ Image Augmentation
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ -1.000000000000000
+
+
+ 0.100000000000000
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Randomly adjusts the contrast of the sample during training
+
+
+ Random Contrast
+
+
+
+ -
+
+
+ -
+
+
+ Randomly adjusts the saturation of the sample during training
+
+
+ Random Saturation
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Enable this augmentation with fixed values
+
+
+ Fixed
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Automatically create circular masks for masked training
+
+
+ Circular Mask Generation
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Override the resolution for this concept. Optionally specify multiple resolutions separated by a comma, or a single exact resolution in the format <width>x<height>
+
+
+ Resolution Override
+
+
+
+ -
+
+
+ -1.000000000000000
+
+
+ 0.100000000000000
+
+
+
+ -
+
+
+ Randomly rotates the sample during training
+
+
+ Random Rotation
+
+
+
+ -
+
+
+ Enable this augmentation with random values
+
+
+ Random
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ -1.000000000000000
+
+
+ 1.000000000000000
+
+
+ 0.050000000000000
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Randomly adjusts the hue of the sample during training
+
+
+ Random Hue
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ -360.000000000000000
+
+
+ 360.000000000000000
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Randomly rotate the training samples and crop to the masked region
+
+
+ Random Rotate And Crop
+
+
+
+ -
+
+
+ Randomly flip the sample during training
+
+
+ Random Flip
+
+
+
+ -
+
+
+ Enables random cropping of samples
+
+
+ Crop Jitter
+
+
+
+ -
+
+
+ Qt::Orientation::Horizontal
+
+
+
+ 40
+ 20
+
+
+
+
+ -
+
+
+ -1.000000000000000
+
+
+ 0.100000000000000
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Randomly adjusts the brightness of the sample during training
+
+
+ Random Brightness
+
+
+
+ -
+
+
+ Preview
+
+
+
-
+
+
+
+ 0
+ 0
+
+
+
+
+ 300
+ 300
+
+
+
+
+ 300
+ 300
+
+
+
+
+
+
+ Qt::AlignmentFlag::AlignCenter
+
+
+
+ -
+
+
+ <
+
+
+
+ -
+
+
+ Update Preview
+
+
+
+ -
+
+
+ >
+
+
+
+ -
+
+
+ Show Augmentations
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ true
+
+
+
+
+
+
+
+
+
+
+ Text Augmentation
+
+
+ -
+
+
+ Keep Tag Count
+
+
+ keepTagCountSbx
+
+
+
+ -
+
+
+ The number of tags at the start of the caption that are not shuffled or dropped
+
+
+
+ -
+
+
+ -
+
+
+ List of tags which will be whitelisted/blacklisted by dropout. 'Whitelist' tags will never be dropped but all others may be; 'Blacklist' tags may be dropped but all others will never be; 'None' may drop any tags. Can specify either a delimiter-separated list in the field, or a file path to a .txt or .csv file with entries separated by newlines
+
+
+
+ -
+
+
+ Enables random dropout for tags in the captions
+
+
+ Tag Dropout
+
+
+
+ -
+
+
+ The delimiter between tags
+
+
+
+ -
+
+
+ Qt::Orientation::Vertical
+
+
+
+ 20
+ 40
+
+
+
+
+ -
+
+
+ Capitalization Mode
+
+
+ capitalizationModeLed
+
+
+
+ -
+
+
+ Probability to randomize capitalization of each tag, from 0 to 1
+
+
+ 1.000000000000000
+
+
+ 0.100000000000000
+
+
+
+ -
+
+
+ Tag Delimiter
+
+
+ tagDelimiterLed
+
+
+
+ -
+
+
+ If enabled, converts the caption to lowercase before any further processing
+
+
+ Force Lowercase
+
+
+
+ -
+
+
+ Enables tag shuffling
+
+
+ Tag Shuffling
+
+
+
+ -
+
+
+ Special Dropout Tags
+
+
+ specialDropoutTagsCmb
+
+
+
+ -
+
+
+ Method used to drop captions. 'Full' will drop the entire caption past the 'kept' tags with a certain probability, 'Random' will drop individual tags with the set probability, and 'Random Weighted' will linearly increase the probability of dropping tags, more likely to preseve tags near the front with full probability to drop at the end
+
+
+
+ -
+
+
+ Dropout Mode
+
+
+ dropoutModeCmb
+
+
+
+ -
+
+
+ Interpret special tags with regex, such as 'photo.*' to match 'photo, photograph, photon' but not 'telephoto'. Includes exception for '/(' and '/)' syntax found in many booru/e6 tags
+
+
+ Special Tags Regex
+
+
+
+ -
+
+
+ Probability
+
+
+
+ -
+
+
+ Comma-separated list of types of capitalization randomization to perform. 'capslock' for ALL CAPS, 'title' for First Letter Of Every Word, 'first' for First word only, 'random' for rAndOMiZeD lEtTERs
+
+
+
+ -
+
+
+ Probability to drop tags, from 0 to 1
+
+
+ 1.000000000000000
+
+
+ 0.100000000000000
+
+
+
+ -
+
+
+ Enables randomization of capitalization for tags in the caption
+
+
+ Randomize Capitalization
+
+
+
+ -
+
+
+ Probability
+
+
+ dropoutProbabilitySbx
+
+
+
+
+
+
+
+ Statistics
+
+
+ -
+
+
+ true
+
+
+
+
+ 0
+ 0
+ 484
+ 534
+
+
+
+
-
+
+
+ Basic Stats
+
+
+
-
+
+
+ Directories:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Total Images:
+
+
+
+ -
+
+
+ Total Videos:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Total Captions:
+
+
+
+ -
+
+
+ Total Size:
+
+
+
+ -
+
+
+ Total Masks:
+
+
+
+
+
+
+ -
+
+
+ Advanced Stats
+
+
+
-
+
+
+ Min Pixels:
+
+
+
+ -
+
+
+ Avg Pixels:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Max Caption Length:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Images with Captions:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Unpaired Captions:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Max Pixels:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Avg FPS:
+
+
+
+ -
+
+
+ Max FPS:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Avg Length:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Max Length:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Avg Caption Length:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Min Length:
+
+
+
+ -
+
+
+ Min FPS:
+
+
+
+ -
+
+
+ Images with Masks:
+
+
+
+ -
+
+
+ Smallest Buckets:
+
+
+
+ -
+
+
+ Unpaired Masks:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Min Caption Length:
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Videos with Captions:
+
+
+
+ -
+
+
+
+
+
+
+
+
+
+
+
+
+
+ -
+
+
+ Abort Scan
+
+
+
+ -
+
+
+ Refresh Basic
+
+
+
+ -
+
+
+ Aspect Bucketing
+
+
+
-
+
+
+
+
+
+ -
+
+
+
+
+
+
+ -
+
+
+ Refresh Advanced
+
+
+
+
+
+
+
+
+
+
+ tabWidget
+ nameLed
+ enabledCbx
+ conceptTypeCmb
+ pathLed
+ pathBtn
+ downloadNowBtn
+ promptSourceCmb
+ promptSourceLed
+ promptSourceBtn
+ includeSubdirectoriesCbx
+ imageVariationsSbx
+ textVariationsSbx
+ balancingSbx
+ balancingCmb
+ lossWeightSbx
+ rndJitterCbx
+ rndFlipCbx
+ fixFlipCbx
+ rndRotationCbx
+ fixRotationCbx
+ rotationSbx
+ rndBrightnessCbx
+ fixBrightnessCbx
+ brightnessSbx
+ rndContrastCbx
+ fixContrastCbx
+ contrastSbx
+ rndSaturationCbx
+ fixSaturationCbx
+ saturationSbx
+ rndHueCbx
+ fixHueCbx
+ hueSbx
+ rndCircularMaskCbx
+ rndRotateCropCbx
+ fixResolutionOverrideCbx
+ resolutionOverrideLed
+ prevBtn
+ updatePreviewBtn
+ nextBtn
+ showAugmentationsCbx
+ tagShufflingCbx
+ tagDelimiterLed
+ keepTagCountSbx
+ tagDropoutCbx
+ dropoutModeCmb
+ dropoutProbabilitySbx
+ specialDropoutTagsCmb
+ specialDropoutTagsLed
+ specialTagsRegexCbx
+ randomizeCapitalizationCbx
+ forceLowercaseCbx
+ capitalizationModeLed
+ capitalizationProbabilitySbx
+ refreshBasicBtn
+ refreshAdvancedBtn
+ abortScanBtn
+ okBtn
+ promptTed
+
+
+
+
+ enabledCbx
+ toggled(bool)
+ conceptBox
+ setEnabled(bool)
+
+
+ 66
+ 96
+
+
+ 120
+ 111
+
+
+
+
+
diff --git a/modules/ui/views/windows/convert.ui b/modules/ui/views/windows/convert.ui
new file mode 100644
index 000000000..a37991a73
--- /dev/null
+++ b/modules/ui/views/windows/convert.ui
@@ -0,0 +1,179 @@
+ [Qt Designer XML for the window-modal "Convert Tools" dialog (400x282); markup lost in extraction. Recoverable content: Model Type (modelTypeCmb; "The type of model to convert"), Training Method (trainingMethodCmb), Input Name (inputLed; "Filename, directory or Hugging Face repository of the base model") with a browse button, Output Data Type (outputDTypeCmb; "Precision to use when saving the output model"), Model Output Destination (outputLed; "Filename or directory where the output model is saved") with a browse button, Output Format ("Format to use when saving the output model"), and a Convert button. Tab order: modelTypeCmb, trainingMethodCmb, inputLed, inputBtn, outputDTypeCmb, outputLed, outputBtn, convertBtn.]
diff --git a/modules/ui/views/windows/dataset.ui b/modules/ui/views/windows/dataset.ui
new file mode 100644
index 000000000..5564f2d51
--- /dev/null
+++ b/modules/ui/views/windows/dataset.ui
@@ -0,0 +1,222 @@
+ [Qt Designer XML for the window-modal "Dataset Tools" dialog (1398x650); markup lost in extraction. Recoverable content: an Open button with Include Subdirectories ("Include subdirectories when loading images") and Browse; File/Path Filter and Caption Filter rows with With Mask Only and With Caption Only checkboxes; a file tree (fileTreeWdg); a caption editor (captionTed, QPlainTextEdit with line wrapping off) with Save Caption, Delete Caption, and Reset Caption buttons. Tab order: openBtn, includeSubdirCbx, browseBtn, fileFilterLed, fileFilterCmb, captionFilterLed, captionFilterCmb, maskFilterCbx, captionFilterCbx, fileTreeWdg, captionTed, saveCaptionBtn, deleteCaptionBtn, resetCaptionBtn.]
diff --git a/modules/ui/views/windows/generate_caption.ui b/modules/ui/views/windows/generate_caption.ui
new file mode 100644
index 000000000..a36f41571
--- /dev/null
+++ b/modules/ui/views/windows/generate_caption.ui
@@ -0,0 +1,157 @@
+ [Qt Designer XML for the window-modal "Generate Captions" dialog (382x297); markup lost in extraction. Recoverable content: Model (modelCmb), Folder (folderLed) with a browse button, Initial Caption (initialCaptionLed), Caption Prefix (captionPrefixLed), Caption Postfix (captionPostfixLed), Mode, Include Subfolders, a 24px progress bar, and a Create Caption button. Tab order: modelCmb, folderLed, folderBtn, initialCaptionLed, captionPrefixLed, captionPostfixLed, modeCmb, includeSubfolderCbx, createMaskBtn.]
diff --git a/modules/ui/views/windows/generate_mask.ui b/modules/ui/views/windows/generate_mask.ui
new file mode 100644
index 000000000..770a0eb2f
--- /dev/null
+++ b/modules/ui/views/windows/generate_mask.ui
@@ -0,0 +1,220 @@
+ [Qt Designer XML for the window-modal "Generate Masks" dialog (382x365); markup lost in extraction. Recoverable content: Model (modelCmb), Folder (folderLed) with a browse button, Prompt (promptLed; "Enter object to detect (e.g. 'person', 'dog', 'car')"), Mode (modeCmb), Threshold (thresholdSbx; "Confidence threshold: Lower values detect more objects but may include incorrect regions", 0..1, step 0.05), Smooth (smoothSbx; "Additional smoothing (0=use built-in smoothing, higher values for extra smoothing)"), Expand (expandSbx; "Expansion pixels: Expands mask boundaries outward"), Alpha (alphaSbx; "Blending strength when combining with existing masks", 0..1, step 0.05), Include Subfolders, a 24px progress bar, and a Create Mask button. Tab order: modelCmb, folderLed, folderBtn, promptLed, modeCmb, thresholdSbx, smoothSbx, expandSbx, alphaSbx, includeSubfolderCbx, createMaskBtn.]
diff --git a/modules/ui/views/windows/new_sample.ui b/modules/ui/views/windows/new_sample.ui
new file mode 100644
index 000000000..0b13cd352
--- /dev/null
+++ b/modules/ui/views/windows/new_sample.ui
@@ -0,0 +1,46 @@
+ [Qt Designer XML for the application-modal "Sample Settings" dialog (662x376); markup lost in extraction. Recoverable content: an embedded sample parameter frame and an Ok button.]
diff --git a/modules/ui/views/windows/onetrainer.ui b/modules/ui/views/windows/onetrainer.ui
new file mode 100644
index 000000000..db24ba54a
--- /dev/null
+++ b/modules/ui/views/windows/onetrainer.ui
@@ -0,0 +1,302 @@
+ [Qt Designer XML for the OneTrainer main window (1126x837); markup lost in extraction. Recoverable content: a header frame with a 50x50 icon (resources/icons/icon.png), a bold 16pt "OneTrainer" title, and Save Config and Wiki buttons; a central tab widget; a footer frame with Step and Epoch progress bars formatted "%p% (%v/%m)", a status label ("Current status of the training run"), and Export ("Export the current configuration as a script to run without a UI"), Debug, Tensorboard, and Start Training buttons.]
diff --git a/modules/ui/views/windows/optimizer.ui b/modules/ui/views/windows/optimizer.ui
new file mode 100644
index 000000000..d62544809
--- /dev/null
+++ b/modules/ui/views/windows/optimizer.ui
@@ -0,0 +1,86 @@
+ [Qt Designer XML for the application-modal "Optimizer Settings" dialog (480x367); markup lost in extraction. Recoverable content: an Optimizer combo ("The type of optimizer"), a Load Defaults button ("Load default settings for the selected optimizer"), and a scrollable parameter area (460x315 viewport). Tab order: optimizerCmb, loadDefaultsBtn.]
diff --git a/modules/ui/views/windows/profile.ui b/modules/ui/views/windows/profile.ui
new file mode 100644
index 000000000..dea165c5e
--- /dev/null
+++ b/modules/ui/views/windows/profile.ui
@@ -0,0 +1,73 @@
+ [Qt Designer XML for the window-modal "Profiling Tools" dialog (234x146); markup lost in extraction. Recoverable content: a Dump Stack button, a Start Profiling toggle ("Turns on/off Scalene profiling. Only works when OneTrainer is launched with Scalene!"), and an "Inactive" status label.]
diff --git a/modules/ui/views/windows/sample.ui b/modules/ui/views/windows/sample.ui
new file mode 100644
index 000000000..df627bfb6
--- /dev/null
+++ b/modules/ui/views/windows/sample.ui
@@ -0,0 +1,72 @@
+ [Qt Designer XML for the window-modal "Sampling Tools" dialog (1161x584); markup lost in extraction. Recoverable content: a sample parameter area, a Sample button (sampleBtn), and a 24px progress bar. Tab order: sampleBtn.]
diff --git a/modules/ui/views/windows/save.ui b/modules/ui/views/windows/save.ui
new file mode 100644
index 000000000..4172a809c
--- /dev/null
+++ b/modules/ui/views/windows/save.ui
@@ -0,0 +1,67 @@
+ [Qt Designer XML for the application-modal "Save Config" dialog (400x84); markup lost in extraction. Recoverable content: a Config Name combo (configCmb, editable) and Ok / Cancel buttons.]
diff --git a/modules/ui/views/windows/video.ui b/modules/ui/views/windows/video.ui
new file mode 100644
index 000000000..03fde20ac
--- /dev/null
+++ b/modules/ui/views/windows/video.ui
@@ -0,0 +1,653 @@
+ [Qt Designer XML for the window-modal "Video Tools" dialog (697x424); markup lost in extraction. Recoverable content, per tab:
+  - Extract Clips: Single Video input ("Link to single video file to process") with a Time Range pair (hh:mm:ss input masks, 00:00:00 to 99:99:99) and an Extract Single button; Directory input ("Path to directory with multiple videos to process, including in subdirectories") with an Extract Directory button; Output folder ("Path to folder where extracted clips will be saved"); options: Output To Subdirectories ("If enabled, files are saved to subfolders based on filename and input directory. Otherwise will all be saved to the top level of the output directory"), Split At Cuts ("If enabled, detect cuts in the input video and split at those points. Otherwise will split at any point, and clips may contain cuts"), Max Length (s) ("Maximum length in seconds for saved clips, larger clips will be broken into multiple small clips"), Set FPS ("FPS to convert output videos to, set to 0 to keep original rate"), Remove Borders, and Crop Variation ("Output clips will be randomly cropped to +- the base aspect ratio, somewhat biased towards making square videos. Set to 0 to use only base aspect").
+  - Extract Images: the same single-video, directory, time-range, and output inputs, plus Images/sec ("Number of images to capture per second of video. Images will be taken at semi-random frames around the specified frequency"), Blur Removal ("Threshold for removal of blurry images, relative to all others. For example at 0.2, the blurriest 20% of the final selected frames will not be saved"), Remove Borders, Crop Variation, and Output To Subdirectories.
+  - Download: Single Link ("Link to video/playlist to download. Uses yt-dlp, supports youtube, twitch, instagram, and many other sites") with a Download Link button; Link List ("Path to txt file with list of links separated by newlines") with a Download List button; Output folder ("Path to folder where downloaded videos will be saved"); Additional Args (additionalArgsTed; "Any additional arguments to pass to yt-dlp, for example '--restrict-filenames --force-overwrite'. Default args will hide most terminal outputs", default "--quiet --no-warnings --progress"); and a yt-dlp Info button.
+  - Tab order: tabWidget, singleVideo1Led through cropVariation1Sbx (Extract Clips controls), singleVideo2Led through cropVariation2Sbx (Extract Images controls), singleLinkLed, downloadLinkBtn, linkListLed, linkListBtn, downloadListBtn, output3Led, output3Btn, additionalArgsTed, infoBtn.]
diff --git a/modules/util/enum.legacy/AudioFormat.py b/modules/util/enum.legacy/AudioFormat.py
new file mode 100644
index 000000000..abdf7520c
--- /dev/null
+++ b/modules/util/enum.legacy/AudioFormat.py
@@ -0,0 +1,15 @@
+from enum import Enum
+
+
+class AudioFormat(Enum):
+ MP3 = 'MP3'
+
+ def __str__(self):
+ return self.value
+
+ def extension(self) -> str:
+ match self:
+ case AudioFormat.MP3:
+ return ".mp3"
+ case _:
+ return ""
diff --git a/modules/util/enum.legacy/BalancingStrategy.py b/modules/util/enum.legacy/BalancingStrategy.py
new file mode 100644
index 000000000..4a247f64d
--- /dev/null
+++ b/modules/util/enum.legacy/BalancingStrategy.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class BalancingStrategy(Enum):
+ REPEATS = 'REPEATS'
+ SAMPLES = 'SAMPLES'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/CloudAction.py b/modules/util/enum.legacy/CloudAction.py
new file mode 100644
index 000000000..a4cb639bc
--- /dev/null
+++ b/modules/util/enum.legacy/CloudAction.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class CloudAction(Enum):
+ NONE = 'NONE'
+ STOP = 'STOP'
+ DELETE = 'DELETE'
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/CloudFileSync.py b/modules/util/enum.legacy/CloudFileSync.py
new file mode 100644
index 000000000..dcb008f1d
--- /dev/null
+++ b/modules/util/enum.legacy/CloudFileSync.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class CloudFileSync(Enum):
+ FABRIC_SFTP = 'FABRIC_SFTP'
+ NATIVE_SCP = 'NATIVE_SCP'
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/CloudType.py b/modules/util/enum.legacy/CloudType.py
new file mode 100644
index 000000000..d20f61a07
--- /dev/null
+++ b/modules/util/enum.legacy/CloudType.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class CloudType(Enum):
+ RUNPOD = 'RUNPOD'
+ LINUX = 'LINUX'
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/ConceptType.py b/modules/util/enum.legacy/ConceptType.py
new file mode 100644
index 000000000..4bda9f49c
--- /dev/null
+++ b/modules/util/enum.legacy/ConceptType.py
@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class ConceptType(Enum):
+ STANDARD = 'STANDARD'
+ VALIDATION = 'VALIDATION'
+ PRIOR_PREDICTION = 'PRIOR_PREDICTION'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/ConfigPart.py b/modules/util/enum.legacy/ConfigPart.py
new file mode 100644
index 000000000..271e0ba1e
--- /dev/null
+++ b/modules/util/enum.legacy/ConfigPart.py
@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class ConfigPart(Enum):
+ NONE = 'NONE'
+ SETTINGS = 'SETTINGS'
+ ALL = 'ALL'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/DataType.py b/modules/util/enum.legacy/DataType.py
new file mode 100644
index 000000000..dd15e01d8
--- /dev/null
+++ b/modules/util/enum.legacy/DataType.py
@@ -0,0 +1,54 @@
+from enum import Enum
+
+import torch
+
+
+class DataType(Enum):
+ NONE = 'NONE'
+ FLOAT_8 = 'FLOAT_8'
+ FLOAT_16 = 'FLOAT_16'
+ FLOAT_32 = 'FLOAT_32'
+ BFLOAT_16 = 'BFLOAT_16'
+ TFLOAT_32 = 'TFLOAT_32'
+ INT_8 = 'INT_8'
+ NFLOAT_4 = 'NFLOAT_4'
+ GGUF = 'GGUF'
+
+ def __str__(self):
+ return self.value
+
+ def torch_dtype(
+ self,
+ supports_quantization: bool = True,
+ ):
+ if self.is_quantized() and not supports_quantization:
+ return torch.float16
+
+ match self:
+ case DataType.FLOAT_16:
+ return torch.float16
+ case DataType.FLOAT_32:
+ return torch.float32
+ case DataType.BFLOAT_16:
+ return torch.bfloat16
+ case DataType.TFLOAT_32:
+ return torch.float32
+ case _:
+ return None
+
+ def enable_tf(self):
+ return self == DataType.TFLOAT_32
+
+ def is_quantized(self):
+ return self in [DataType.FLOAT_8,
+ DataType.INT_8,
+ DataType.NFLOAT_4]
+
+ def quantize_fp8(self):
+ return self == DataType.FLOAT_8
+
+ def quantize_int8(self):
+ return self == DataType.INT_8
+
+ def quantize_nf4(self):
+ return self == DataType.NFLOAT_4
diff --git a/modules/util/enum.legacy/EMAMode.py b/modules/util/enum.legacy/EMAMode.py
new file mode 100644
index 000000000..2742c445d
--- /dev/null
+++ b/modules/util/enum.legacy/EMAMode.py
@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class EMAMode(Enum):
+ OFF = 'OFF'
+ GPU = 'GPU'
+ CPU = 'CPU'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/FileType.py b/modules/util/enum.legacy/FileType.py
new file mode 100644
index 000000000..e0fa4fe81
--- /dev/null
+++ b/modules/util/enum.legacy/FileType.py
@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class FileType(Enum):
+ IMAGE = 'IMAGE'
+ VIDEO = 'VIDEO'
+ AUDIO = 'AUDIO'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/GenerateCaptionsModel.py b/modules/util/enum.legacy/GenerateCaptionsModel.py
new file mode 100644
index 000000000..02da82d05
--- /dev/null
+++ b/modules/util/enum.legacy/GenerateCaptionsModel.py
@@ -0,0 +1,10 @@
+from enum import Enum
+
+
+class GenerateCaptionsModel(Enum):
+ BLIP = 'BLIP'
+ BLIP2 = 'BLIP2'
+ WD14_VIT_2 = 'WD14_VIT_2'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/GenerateMasksModel.py b/modules/util/enum.legacy/GenerateMasksModel.py
new file mode 100644
index 000000000..a81653754
--- /dev/null
+++ b/modules/util/enum.legacy/GenerateMasksModel.py
@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class GenerateMasksModel(Enum):
+ CLIPSEG = 'CLIPSEG'
+ REMBG = 'REMBG'
+ REMBG_HUMAN = 'REMBG_HUMAN'
+ COLOR = 'COLOR'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/GradientCheckpointingMethod.py b/modules/util/enum.legacy/GradientCheckpointingMethod.py
new file mode 100644
index 000000000..d3f05666a
--- /dev/null
+++ b/modules/util/enum.legacy/GradientCheckpointingMethod.py
@@ -0,0 +1,17 @@
+from enum import Enum
+
+
+class GradientCheckpointingMethod(Enum):
+ OFF = 'OFF'
+ ON = 'ON'
+ CPU_OFFLOADED = 'CPU_OFFLOADED'
+
+ def __str__(self):
+ return self.value
+
+ def enabled(self):
+ return self == GradientCheckpointingMethod.ON \
+ or self == GradientCheckpointingMethod.CPU_OFFLOADED
+
+ def offload(self):
+ return self == GradientCheckpointingMethod.CPU_OFFLOADED
diff --git a/modules/util/enum.legacy/GradientReducePrecision.py b/modules/util/enum.legacy/GradientReducePrecision.py
new file mode 100644
index 000000000..04eb2e526
--- /dev/null
+++ b/modules/util/enum.legacy/GradientReducePrecision.py
@@ -0,0 +1,39 @@
+from enum import Enum
+
+import torch
+
+
+class GradientReducePrecision(Enum):
+ WEIGHT_DTYPE = 'WEIGHT_DTYPE'
+ FLOAT_32 = 'FLOAT_32'
+ WEIGHT_DTYPE_STOCHASTIC = 'WEIGHT_DTYPE_STOCHASTIC'
+ FLOAT_32_STOCHASTIC = 'FLOAT_32_STOCHASTIC'
+
+ def torch_dtype(self, weight_dtype: torch.dtype) -> torch.dtype:
+ match self:
+ case GradientReducePrecision.WEIGHT_DTYPE:
+ return weight_dtype
+ case GradientReducePrecision.FLOAT_32:
+ return torch.float32
+ case GradientReducePrecision.WEIGHT_DTYPE_STOCHASTIC:
+ return weight_dtype
+ case GradientReducePrecision.FLOAT_32_STOCHASTIC:
+ return torch.float32
+ case _:
+ raise ValueError
+
+ def stochastic_rounding(self, weight_dtype: torch.dtype) -> bool:
+ match self:
+ case GradientReducePrecision.WEIGHT_DTYPE:
+ return False
+ case GradientReducePrecision.FLOAT_32:
+ return False
+ case GradientReducePrecision.WEIGHT_DTYPE_STOCHASTIC:
+ return weight_dtype == torch.bfloat16
+ case GradientReducePrecision.FLOAT_32_STOCHASTIC:
+ return weight_dtype == torch.bfloat16
+ case _:
+ raise ValueError
+
+ def __str__(self):
+ return self.value
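
As a spot check of the two methods (the values follow directly from the match arms above; stochastic rounding is only reported for bfloat16 weights, where accumulated rounding error from gradient reduction hurts most):

    import torch

    p = GradientReducePrecision.WEIGHT_DTYPE_STOCHASTIC
    p.torch_dtype(torch.bfloat16)           # torch.bfloat16 (reduce in the weight dtype)
    p.stochastic_rounding(torch.bfloat16)   # True
    p.stochastic_rounding(torch.float32)    # False (full precision needs no stochastic rounding)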
diff --git a/modules/util/enum.legacy/ImageFormat.py b/modules/util/enum.legacy/ImageFormat.py
new file mode 100644
index 000000000..15b6183dc
--- /dev/null
+++ b/modules/util/enum.legacy/ImageFormat.py
@@ -0,0 +1,27 @@
+from enum import Enum
+
+
+class ImageFormat(Enum):
+ PNG = 'PNG'
+ JPG = 'JPG'
+
+ def __str__(self):
+ return self.value
+
+ def extension(self) -> str:
+ match self:
+ case ImageFormat.PNG:
+ return ".png"
+ case ImageFormat.JPG:
+ return ".jpg"
+ case _:
+ return ""
+
+ def pil_format(self) -> str:
+ match self:
+ case ImageFormat.PNG:
+ return "PNG"
+ case ImageFormat.JPG:
+ return "JPEG"
+ case _:
+ return ""
diff --git a/modules/util/enum.legacy/LearningRateScaler.py b/modules/util/enum.legacy/LearningRateScaler.py
new file mode 100644
index 000000000..d8ce19ec5
--- /dev/null
+++ b/modules/util/enum.legacy/LearningRateScaler.py
@@ -0,0 +1,32 @@
+from enum import Enum
+
+import modules.util.multi_gpu_util as multi
+
+
+class LearningRateScaler(Enum):
+ NONE = 'NONE'
+ BATCH = 'BATCH'
+ GLOBAL_BATCH = 'GLOBAL_BATCH'
+ GRADIENT_ACCUMULATION = 'GRADIENT_ACCUMULATION'
+ BOTH = 'BOTH'
+ GLOBAL_BOTH = 'GLOBAL_BOTH'
+
+ def __str__(self):
+ return self.value
+
+ def get_scale(self, batch_size: int, accumulation_steps: int) -> int:
+ match self:
+ case LearningRateScaler.NONE:
+ return 1
+ case LearningRateScaler.BATCH:
+ return batch_size
+ case LearningRateScaler.GLOBAL_BATCH:
+ return batch_size * multi.world_size()
+ case LearningRateScaler.GRADIENT_ACCUMULATION:
+ return accumulation_steps
+ case LearningRateScaler.BOTH:
+ return accumulation_steps * batch_size
+ case LearningRateScaler.GLOBAL_BOTH:
+ return accumulation_steps * batch_size * multi.world_size()
+ case _:
+ raise ValueError
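
The six strategies compose batch size, accumulation steps, and (for the GLOBAL variants) the multi-GPU world size multiplicatively. A quick sketch of the arithmetic, assuming multi.world_size() reports 2 ranks and that the trainer multiplies the base learning rate by the returned scale:

    scaler = LearningRateScaler.GLOBAL_BOTH
    scale = scaler.get_scale(batch_size=4, accumulation_steps=2)
    # GLOBAL_BOTH: accumulation_steps * batch_size * world_size = 2 * 4 * 2 = 16
    effective_lr = 1e-4 * scale  # 1.6e-3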
diff --git a/modules/util/enum.legacy/LearningRateScheduler.py b/modules/util/enum.legacy/LearningRateScheduler.py
new file mode 100644
index 000000000..3e6292103
--- /dev/null
+++ b/modules/util/enum.legacy/LearningRateScheduler.py
@@ -0,0 +1,15 @@
+from enum import Enum
+
+
+class LearningRateScheduler(Enum):
+ CONSTANT = 'CONSTANT'
+ LINEAR = 'LINEAR'
+ COSINE = 'COSINE'
+ COSINE_WITH_RESTARTS = 'COSINE_WITH_RESTARTS'
+ COSINE_WITH_HARD_RESTARTS = 'COSINE_WITH_HARD_RESTARTS'
+ REX = 'REX'
+ ADAFACTOR = 'ADAFACTOR'
+ CUSTOM = 'CUSTOM'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/LossScaler.py b/modules/util/enum.legacy/LossScaler.py
new file mode 100644
index 000000000..4ac82095e
--- /dev/null
+++ b/modules/util/enum.legacy/LossScaler.py
@@ -0,0 +1,32 @@
+from enum import Enum
+
+import modules.util.multi_gpu_util as multi
+
+
+class LossScaler(Enum):
+ NONE = 'NONE'
+ BATCH = 'BATCH'
+ GLOBAL_BATCH = 'GLOBAL_BATCH'
+ GRADIENT_ACCUMULATION = 'GRADIENT_ACCUMULATION'
+ BOTH = 'BOTH'
+ GLOBAL_BOTH = 'GLOBAL_BOTH'
+
+ def __str__(self):
+ return self.value
+
+ def get_scale(self, batch_size: int, accumulation_steps: int) -> int:
+ match self:
+ case LossScaler.NONE:
+ return 1
+ case LossScaler.BATCH:
+ return batch_size
+ case LossScaler.GLOBAL_BATCH:
+ return batch_size * multi.world_size()
+ case LossScaler.GRADIENT_ACCUMULATION:
+ return accumulation_steps
+ case LossScaler.BOTH:
+ return accumulation_steps * batch_size
+ case LossScaler.GLOBAL_BOTH:
+ return accumulation_steps * batch_size * multi.world_size()
+ case _:
+ raise ValueError
diff --git a/modules/util/enum.legacy/LossWeight.py b/modules/util/enum.legacy/LossWeight.py
new file mode 100644
index 000000000..b47ea9ae4
--- /dev/null
+++ b/modules/util/enum.legacy/LossWeight.py
@@ -0,0 +1,16 @@
+from enum import Enum
+
+
+class LossWeight(Enum):
+ CONSTANT = 'CONSTANT'
+ P2 = 'P2'
+ MIN_SNR_GAMMA = 'MIN_SNR_GAMMA'
+ DEBIASED_ESTIMATION = 'DEBIASED_ESTIMATION'
+ SIGMA = 'SIGMA'
+
+ def supports_flow_matching(self) -> bool:
+ return self == LossWeight.CONSTANT \
+ or self == LossWeight.SIGMA
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/ModelFormat.py b/modules/util/enum.legacy/ModelFormat.py
new file mode 100644
index 000000000..597ad4442
--- /dev/null
+++ b/modules/util/enum.legacy/ModelFormat.py
@@ -0,0 +1,30 @@
+from enum import Enum
+
+
+class ModelFormat(Enum):
+ DIFFUSERS = 'DIFFUSERS'
+ CKPT = 'CKPT'
+ SAFETENSORS = 'SAFETENSORS'
+ LEGACY_SAFETENSORS = 'LEGACY_SAFETENSORS'
+
+ INTERNAL = 'INTERNAL' # an internal format that stores all information to resume training
+
+ def __str__(self):
+ return self.value
+
+ def file_extension(self) -> str:
+ match self:
+ case ModelFormat.DIFFUSERS:
+ return ''
+ case ModelFormat.CKPT:
+ return '.ckpt'
+ case ModelFormat.SAFETENSORS:
+ return '.safetensors'
+ case ModelFormat.LEGACY_SAFETENSORS:
+ return '.safetensors'
+ case _:
+ return ''
+
+ def is_single_file(self) -> bool:
+ return self.file_extension() != ''
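
is_single_file() piggybacks on the extension table: every format that maps to a non-empty extension is a single-file format, while DIFFUSERS (a directory layout) and INTERNAL fall through to '' and report False. For example:

    ModelFormat.SAFETENSORS.is_single_file()  # True  ('.safetensors')
    ModelFormat.DIFFUSERS.is_single_file()    # False (directory layout, empty extension)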
diff --git a/modules/util/enum.legacy/ModelType.py b/modules/util/enum.legacy/ModelType.py
new file mode 100644
index 000000000..cd6fd77bf
--- /dev/null
+++ b/modules/util/enum.legacy/ModelType.py
@@ -0,0 +1,152 @@
+from enum import Enum
+
+
+class ModelType(Enum):
+ STABLE_DIFFUSION_15 = 'STABLE_DIFFUSION_15'
+ STABLE_DIFFUSION_15_INPAINTING = 'STABLE_DIFFUSION_15_INPAINTING'
+ STABLE_DIFFUSION_20 = 'STABLE_DIFFUSION_20'
+ STABLE_DIFFUSION_20_BASE = 'STABLE_DIFFUSION_20_BASE'
+ STABLE_DIFFUSION_20_INPAINTING = 'STABLE_DIFFUSION_20_INPAINTING'
+ STABLE_DIFFUSION_20_DEPTH = 'STABLE_DIFFUSION_20_DEPTH'
+ STABLE_DIFFUSION_21 = 'STABLE_DIFFUSION_21'
+ STABLE_DIFFUSION_21_BASE = 'STABLE_DIFFUSION_21_BASE'
+
+ STABLE_DIFFUSION_3 = 'STABLE_DIFFUSION_3'
+ STABLE_DIFFUSION_35 = 'STABLE_DIFFUSION_35'
+
+ STABLE_DIFFUSION_XL_10_BASE = 'STABLE_DIFFUSION_XL_10_BASE'
+ STABLE_DIFFUSION_XL_10_BASE_INPAINTING = 'STABLE_DIFFUSION_XL_10_BASE_INPAINTING'
+
+ WUERSTCHEN_2 = 'WUERSTCHEN_2'
+ STABLE_CASCADE_1 = 'STABLE_CASCADE_1'
+
+ PIXART_ALPHA = 'PIXART_ALPHA'
+ PIXART_SIGMA = 'PIXART_SIGMA'
+
+ FLUX_DEV_1 = 'FLUX_DEV_1'
+ FLUX_FILL_DEV_1 = 'FLUX_FILL_DEV_1'
+
+ SANA = 'SANA'
+
+ HUNYUAN_VIDEO = 'HUNYUAN_VIDEO'
+
+ HI_DREAM_FULL = 'HI_DREAM_FULL'
+
+ CHROMA_1 = 'CHROMA_1'
+
+ QWEN = 'QWEN'
+
+ def __str__(self):
+ return self.value
+
+ def is_stable_diffusion(self):
+ return self == ModelType.STABLE_DIFFUSION_15 \
+ or self == ModelType.STABLE_DIFFUSION_15_INPAINTING \
+ or self == ModelType.STABLE_DIFFUSION_20 \
+ or self == ModelType.STABLE_DIFFUSION_20_BASE \
+ or self == ModelType.STABLE_DIFFUSION_20_INPAINTING \
+ or self == ModelType.STABLE_DIFFUSION_20_DEPTH \
+ or self == ModelType.STABLE_DIFFUSION_21 \
+ or self == ModelType.STABLE_DIFFUSION_21_BASE
+
+ def is_stable_diffusion_xl(self):
+ return self == ModelType.STABLE_DIFFUSION_XL_10_BASE \
+ or self == ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING
+
+ def is_stable_diffusion_3(self):
+ return self == ModelType.STABLE_DIFFUSION_3 \
+ or self == ModelType.STABLE_DIFFUSION_35
+
+ def is_stable_diffusion_3_5(self):
+ return self == ModelType.STABLE_DIFFUSION_35
+
+ def is_wuerstchen(self):
+ return self == ModelType.WUERSTCHEN_2 \
+ or self == ModelType.STABLE_CASCADE_1
+
+ def is_pixart(self):
+ return self == ModelType.PIXART_ALPHA \
+ or self == ModelType.PIXART_SIGMA
+
+ def is_pixart_alpha(self):
+ return self == ModelType.PIXART_ALPHA
+
+ def is_pixart_sigma(self):
+ return self == ModelType.PIXART_SIGMA
+
+ def is_flux(self):
+ return self == ModelType.FLUX_DEV_1 \
+ or self == ModelType.FLUX_FILL_DEV_1
+
+ def is_chroma(self):
+ return self == ModelType.CHROMA_1
+
+ def is_qwen(self):
+ return self == ModelType.QWEN
+
+ def is_sana(self):
+ return self == ModelType.SANA
+
+ def is_hunyuan_video(self):
+ return self == ModelType.HUNYUAN_VIDEO
+
+ def is_hi_dream(self):
+ return self == ModelType.HI_DREAM_FULL
+
+ def has_mask_input(self) -> bool:
+ return self == ModelType.STABLE_DIFFUSION_15_INPAINTING \
+ or self == ModelType.STABLE_DIFFUSION_20_INPAINTING \
+ or self == ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING \
+ or self == ModelType.FLUX_FILL_DEV_1
+
+ def has_conditioning_image_input(self) -> bool:
+ return self == ModelType.STABLE_DIFFUSION_15_INPAINTING \
+ or self == ModelType.STABLE_DIFFUSION_20_INPAINTING \
+ or self == ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING \
+ or self == ModelType.FLUX_FILL_DEV_1
+
+ def has_depth_input(self):
+ return self == ModelType.STABLE_DIFFUSION_20_DEPTH
+
+ def has_multiple_text_encoders(self):
+ return self.is_stable_diffusion_3() \
+ or self.is_stable_diffusion_xl() \
+ or self.is_flux() \
+ or self.is_hunyuan_video() \
+ or self.is_hi_dream()
+
+ def is_sd_v1(self):
+ return self == ModelType.STABLE_DIFFUSION_15 \
+ or self == ModelType.STABLE_DIFFUSION_15_INPAINTING
+
+ def is_sd_v2(self):
+ return self == ModelType.STABLE_DIFFUSION_20 \
+ or self == ModelType.STABLE_DIFFUSION_20_BASE \
+ or self == ModelType.STABLE_DIFFUSION_20_INPAINTING \
+ or self == ModelType.STABLE_DIFFUSION_20_DEPTH \
+ or self == ModelType.STABLE_DIFFUSION_21 \
+ or self == ModelType.STABLE_DIFFUSION_21_BASE
+
+ def is_wuerstchen_v2(self):
+ return self == ModelType.WUERSTCHEN_2
+
+ def is_stable_cascade(self):
+ return self == ModelType.STABLE_CASCADE_1
+
+ def is_flow_matching(self) -> bool:
+ return self.is_stable_diffusion_3() \
+ or self.is_flux() \
+ or self.is_chroma() \
+ or self.is_qwen() \
+ or self.is_sana() \
+ or self.is_hunyuan_video() \
+ or self.is_hi_dream()
+
+
+class PeftType(Enum):
+ LORA = 'LORA'
+ LOHA = 'LOHA'
+ OFT_2 = 'OFT_2'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/NoiseScheduler.py b/modules/util/enum.legacy/NoiseScheduler.py
new file mode 100644
index 000000000..62d9af9dd
--- /dev/null
+++ b/modules/util/enum.legacy/NoiseScheduler.py
@@ -0,0 +1,19 @@
+from enum import Enum
+
+
+class NoiseScheduler(Enum):
+ DDIM = 'DDIM'
+
+ EULER = 'EULER'
+ EULER_A = 'EULER_A'
+ DPMPP = 'DPMPP'
+ DPMPP_SDE = 'DPMPP_SDE'
+ UNIPC = 'UNIPC'
+
+ EULER_KARRAS = 'EULER_KARRAS'
+ DPMPP_KARRAS = 'DPMPP_KARRAS'
+ DPMPP_SDE_KARRAS = 'DPMPP_SDE_KARRAS'
+ UNIPC_KARRAS = 'UNIPC_KARRAS'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/Optimizer.py b/modules/util/enum.legacy/Optimizer.py
new file mode 100644
index 000000000..c5f0c1089
--- /dev/null
+++ b/modules/util/enum.legacy/Optimizer.py
@@ -0,0 +1,126 @@
+from enum import Enum
+
+import torch
+
+
+class Optimizer(Enum):
+ # Sorted by origin (BNB / torch first, then DADAPT), then by adapter name, then interleaved by variant.
+
+ # BNB Standard & 8-bit
+ ADAGRAD = 'ADAGRAD'
+ ADAGRAD_8BIT = 'ADAGRAD_8BIT'
+
+ # 32 bit is torch and not bnb
+ ADAM = 'ADAM'
+ ADAM_8BIT = 'ADAM_8BIT'
+
+ # 32 bit is torch and not bnb
+ ADAMW = 'ADAMW'
+ ADAMW_8BIT = 'ADAMW_8BIT'
+ ADAMW_ADV = 'ADAMW_ADV'
+
+ AdEMAMix = 'AdEMAMix'
+ AdEMAMix_8BIT = "AdEMAMix_8BIT"
+ SIMPLIFIED_AdEMAMix = "SIMPLIFIED_AdEMAMix"
+
+ ADOPT = 'ADOPT'
+ ADOPT_ADV = 'ADOPT_ADV'
+
+ LAMB = 'LAMB'
+ LAMB_8BIT = 'LAMB_8BIT'
+
+ LARS = 'LARS'
+ LARS_8BIT = 'LARS_8BIT'
+
+ LION = 'LION'
+ LION_8BIT = 'LION_8BIT'
+ LION_ADV = 'LION_ADV'
+
+ RMSPROP = 'RMSPROP'
+ RMSPROP_8BIT = 'RMSPROP_8BIT'
+
+ # 32 bit is torch and not bnb
+ SGD = 'SGD'
+ SGD_8BIT = 'SGD_8BIT'
+
+ # Schedule-free optimizers
+ SCHEDULE_FREE_ADAMW = 'SCHEDULE_FREE_ADAMW'
+ SCHEDULE_FREE_SGD = 'SCHEDULE_FREE_SGD'
+
+ # DADAPT
+ DADAPT_ADA_GRAD = 'DADAPT_ADA_GRAD'
+ DADAPT_ADAM = 'DADAPT_ADAM'
+ DADAPT_ADAN = 'DADAPT_ADAN'
+ DADAPT_LION = 'DADAPT_LION'
+ DADAPT_SGD = 'DADAPT_SGD'
+
+ # Prodigy
+ PRODIGY = 'PRODIGY'
+ PRODIGY_PLUS_SCHEDULE_FREE = 'PRODIGY_PLUS_SCHEDULE_FREE'
+ PRODIGY_ADV = 'PRODIGY_ADV'
+ LION_PRODIGY_ADV = 'LION_PRODIGY_ADV'
+
+ # ADAFACTOR
+ ADAFACTOR = 'ADAFACTOR'
+
+ # CAME
+ CAME = 'CAME'
+ CAME_8BIT = 'CAME_8BIT'
+
+ # PyTorch optimizers
+ ADABELIEF = 'ADABELIEF'
+ TIGER = 'TIGER'
+ AIDA = 'AIDA'
+ YOGI = 'YOGI'
+
+ @property
+ def is_adaptive(self):
+ return self in [
+ self.DADAPT_SGD,
+ self.DADAPT_ADAM,
+ self.DADAPT_ADAN,
+ self.DADAPT_ADA_GRAD,
+ self.DADAPT_LION,
+ self.PRODIGY,
+ self.PRODIGY_PLUS_SCHEDULE_FREE,
+ self.PRODIGY_ADV,
+ self.LION_PRODIGY_ADV,
+ ]
+
+ @property
+ def is_schedule_free(self):
+ return self in [
+ self.SCHEDULE_FREE_ADAMW,
+ self.SCHEDULE_FREE_SGD,
+ self.PRODIGY_PLUS_SCHEDULE_FREE,
+ ]
+
+ def supports_fused_back_pass(self):
+ return self in [
+ Optimizer.ADAFACTOR,
+ Optimizer.CAME,
+ Optimizer.CAME_8BIT,
+ Optimizer.ADAM,
+ Optimizer.ADAMW,
+ Optimizer.ADAMW_ADV,
+ Optimizer.ADOPT_ADV,
+ Optimizer.SIMPLIFIED_AdEMAMix,
+ Optimizer.PRODIGY_PLUS_SCHEDULE_FREE,
+ Optimizer.PRODIGY_ADV,
+ Optimizer.LION_ADV,
+ Optimizer.LION_PRODIGY_ADV,
+ ]
+
+ # Small helper for adjusting learning rates to adaptive optimizers.
+ def maybe_adjust_lrs(self, lrs: dict[str, float], optimizer: torch.optim.Optimizer):
+ if self.is_adaptive:
+ return {
+ # Return `effective_lr * d` if "effective_lr" key present, otherwise return `lr * d`
+ key: (optimizer.param_groups[i].get("effective_lr", lr) * optimizer.param_groups[i].get("d", 1.0)
+ if lr is not None else None)
+ for i, (key, lr) in enumerate(lrs.items())
+ }
+ return lrs
+
+ def __str__(self):
+ return self.value
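
For the adaptive optimizers, the user-supplied learning rate is only a multiplier on an internally estimated step size, so maybe_adjust_lrs reports lr * d (or effective_lr * d when the optimizer exposes one) per param group. A minimal sketch with a hypothetical stand-in object (SimpleNamespace stands in for a real Prodigy-style torch optimizer; the "d" and "effective_lr" keys are the conventions the helper reads):

    from types import SimpleNamespace

    opt = SimpleNamespace(param_groups=[
        {"lr": 1.0, "d": 5e-5},                       # lr * d = 5e-5
        {"lr": 1.0, "d": 2e-5, "effective_lr": 0.5},  # effective_lr * d = 1e-5
    ])
    lrs = {"unet": 1.0, "text_encoder": 1.0}
    Optimizer.PRODIGY.maybe_adjust_lrs(lrs, opt)
    # -> {'unet': 5e-05, 'text_encoder': 1e-05}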
diff --git a/modules/util/enum.legacy/TimeUnit.py b/modules/util/enum.legacy/TimeUnit.py
new file mode 100644
index 000000000..c06922d72
--- /dev/null
+++ b/modules/util/enum.legacy/TimeUnit.py
@@ -0,0 +1,20 @@
+from enum import Enum
+
+
+class TimeUnit(Enum):
+ EPOCH = 'EPOCH'
+ STEP = 'STEP'
+ SECOND = 'SECOND'
+ MINUTE = 'MINUTE'
+ HOUR = 'HOUR'
+
+ NEVER = 'NEVER'
+ ALWAYS = 'ALWAYS'
+
+ def __str__(self):
+ return self.value
+
+ def is_time_unit(self) -> bool:
+ return self == TimeUnit.SECOND \
+ or self == TimeUnit.MINUTE \
+ or self == TimeUnit.HOUR
diff --git a/modules/util/enum.legacy/TimestepDistribution.py b/modules/util/enum.legacy/TimestepDistribution.py
new file mode 100644
index 000000000..55d8efd02
--- /dev/null
+++ b/modules/util/enum.legacy/TimestepDistribution.py
@@ -0,0 +1,13 @@
+from enum import Enum
+
+
+class TimestepDistribution(Enum):
+ UNIFORM = 'UNIFORM'
+ SIGMOID = 'SIGMOID'
+ LOGIT_NORMAL = 'LOGIT_NORMAL'
+ HEAVY_TAIL = 'HEAVY_TAIL'
+ COS_MAP = 'COS_MAP'
+ INVERTED_PARABOLA = 'INVERTED_PARABOLA'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/TrainingMethod.py b/modules/util/enum.legacy/TrainingMethod.py
new file mode 100644
index 000000000..403a1ec10
--- /dev/null
+++ b/modules/util/enum.legacy/TrainingMethod.py
@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class TrainingMethod(Enum):
+ FINE_TUNE = 'FINE_TUNE'
+ LORA = 'LORA'
+ EMBEDDING = 'EMBEDDING'
+ FINE_TUNE_VAE = 'FINE_TUNE_VAE'
+
+ def __str__(self):
+ return self.value
diff --git a/modules/util/enum.legacy/VideoFormat.py b/modules/util/enum.legacy/VideoFormat.py
new file mode 100644
index 000000000..67d8e8242
--- /dev/null
+++ b/modules/util/enum.legacy/VideoFormat.py
@@ -0,0 +1,30 @@
+from enum import Enum
+
+
+class VideoFormat(Enum):
+ PNG_IMAGE_SEQUENCE = 'PNG_IMAGE_SEQUENCE'
+ JPG_IMAGE_SEQUENCE = 'JPG_IMAGE_SEQUENCE'
+ MP4 = 'MP4'
+
+ def __str__(self):
+ return self.value
+
+ def extension(self) -> str:
+ match self:
+ case VideoFormat.PNG_IMAGE_SEQUENCE:
+ return ".png"
+ case VideoFormat.JPG_IMAGE_SEQUENCE:
+ return ".jpg"
+ case VideoFormat.MP4:
+ return ".mp4"
+ case _:
+ return ""
+
+ def pil_format(self) -> str:
+ match self:
+ case VideoFormat.PNG_IMAGE_SEQUENCE:
+ return "PNG"
+ case VideoFormat.JPG_IMAGE_SEQUENCE:
+ return "JPEG"
+ case _:
+ return ""
diff --git a/modules/util/enum/AudioFormat.py b/modules/util/enum/AudioFormat.py
index abdf7520c..75914f87d 100644
--- a/modules/util/enum/AudioFormat.py
+++ b/modules/util/enum/AudioFormat.py
@@ -1,12 +1,9 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class AudioFormat(Enum):
+class AudioFormat(BaseEnum):
MP3 = 'MP3'
- def __str__(self):
- return self.value
-
def extension(self) -> str:
match self:
case AudioFormat.MP3:
diff --git a/modules/util/enum/BalancingStrategy.py b/modules/util/enum/BalancingStrategy.py
index 4a247f64d..383bb00d9 100644
--- a/modules/util/enum/BalancingStrategy.py
+++ b/modules/util/enum/BalancingStrategy.py
@@ -1,9 +1,6 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class BalancingStrategy(Enum):
+class BalancingStrategy(BaseEnum):
REPEATS = 'REPEATS'
SAMPLES = 'SAMPLES'
-
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/BaseEnum.py b/modules/util/enum/BaseEnum.py
new file mode 100644
index 000000000..7f9489341
--- /dev/null
+++ b/modules/util/enum/BaseEnum.py
@@ -0,0 +1,18 @@
+from enum import Enum
+
+
+class BaseEnum(Enum):
+ def __str__(self):
+ return self.value
+
+ def pretty_print(self):
+ # TODO: do we want this method to use translatable strings? If so, how to avoid introducing an undesirable Qt dependency in modules.util.enum?
+ return self.value.replace("_", " ").title()
+
+ @staticmethod
+ def is_enabled(value, context=None):
+ return True
+
+ @classmethod
+ def enabled_values(cls, context=None):
+ return [v for v in cls if cls.is_enabled(v, context)]
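
BaseEnum factors out the __str__ boilerplate that every enum below used to repeat, and adds two UI-facing hooks: pretty_print() for display text (overridable, defaulting to title-cased words) and is_enabled()/enabled_values() for context-dependent filtering. The defaults in action (ConfigPart is one of the enums migrated below):

    from modules.util.enum.ConfigPart import ConfigPart

    str(ConfigPart.SETTINGS)            # 'SETTINGS'  (BaseEnum.__str__)
    ConfigPart.SETTINGS.pretty_print()  # 'Settings'  (underscores to spaces, title case)
    ConfigPart.enabled_values()         # all members; is_enabled() defaults to True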
diff --git a/modules/util/enum/BulkEditMode.py b/modules/util/enum/BulkEditMode.py
new file mode 100644
index 000000000..67dc9db41
--- /dev/null
+++ b/modules/util/enum/BulkEditMode.py
@@ -0,0 +1,6 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class BulkEditMode(BaseEnum):
+ PREPEND = 'PREPEND'
+ APPEND = 'APPEND'
diff --git a/modules/util/enum/CaptionFilter.py b/modules/util/enum/CaptionFilter.py
new file mode 100644
index 000000000..824f15528
--- /dev/null
+++ b/modules/util/enum/CaptionFilter.py
@@ -0,0 +1,8 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class CaptionFilter(BaseEnum):
+ CONTAINS = "contains"
+ MATCHES = "matches"
+ EXCLUDES = "excludes"
+ REGEX = "regex"
diff --git a/modules/util/enum/CloudAction.py b/modules/util/enum/CloudAction.py
index a4cb639bc..1272eea28 100644
--- a/modules/util/enum/CloudAction.py
+++ b/modules/util/enum/CloudAction.py
@@ -1,9 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class CloudAction(Enum):
+class CloudAction(BaseEnum):
NONE = 'NONE'
STOP = 'STOP'
DELETE = 'DELETE'
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/CloudFileSync.py b/modules/util/enum/CloudFileSync.py
index dcb008f1d..9d496b47e 100644
--- a/modules/util/enum/CloudFileSync.py
+++ b/modules/util/enum/CloudFileSync.py
@@ -1,8 +1,6 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class CloudFileSync(Enum):
+class CloudFileSync(BaseEnum):
FABRIC_SFTP = 'FABRIC_SFTP'
NATIVE_SCP = 'NATIVE_SCP'
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/CloudSubtype.py b/modules/util/enum/CloudSubtype.py
new file mode 100644
index 000000000..76b1e85d0
--- /dev/null
+++ b/modules/util/enum/CloudSubtype.py
@@ -0,0 +1,7 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class CloudSubtype(BaseEnum):
+ NONE = ""
+ COMMUNITY = "COMMUNITY"
+ SECURE = "SECURE"
diff --git a/modules/util/enum/CloudType.py b/modules/util/enum/CloudType.py
index d20f61a07..832094bcf 100644
--- a/modules/util/enum/CloudType.py
+++ b/modules/util/enum/CloudType.py
@@ -1,8 +1,6 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class CloudType(Enum):
+class CloudType(BaseEnum):
RUNPOD = 'RUNPOD'
LINUX = 'LINUX'
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/ConceptType.py b/modules/util/enum/ConceptType.py
index 4bda9f49c..2c1f2cee5 100644
--- a/modules/util/enum/ConceptType.py
+++ b/modules/util/enum/ConceptType.py
@@ -1,10 +1,17 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class ConceptType(Enum):
+class ConceptType(BaseEnum):
+ ALL = 'ALL'
STANDARD = 'STANDARD'
VALIDATION = 'VALIDATION'
PRIOR_PREDICTION = 'PRIOR_PREDICTION'
- def __str__(self):
- return self.value
+ @staticmethod
+ def is_enabled(value, context=None):
+ if context == "all":
+ return True
+ elif context == "prior_pred_enabled":
+ return value in [ConceptType.STANDARD, ConceptType.VALIDATION, ConceptType.PRIOR_PREDICTION]
+ else: # prior_pred_disabled
+ return value in [ConceptType.STANDARD, ConceptType.VALIDATION]
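
The context argument lets UI code request only the member set that makes sense for the current view, e.g. hiding PRIOR_PREDICTION and the new ALL member unless the caller asks for them:

    ConceptType.enabled_values(context="prior_pred_disabled")
    # -> [ConceptType.STANDARD, ConceptType.VALIDATION]
    ConceptType.enabled_values(context="prior_pred_enabled")
    # -> [ConceptType.STANDARD, ConceptType.VALIDATION, ConceptType.PRIOR_PREDICTION]
    ConceptType.enabled_values(context="all")
    # -> all four members, including ConceptType.ALL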
diff --git a/modules/util/enum/ConfigPart.py b/modules/util/enum/ConfigPart.py
index 271e0ba1e..b591121c6 100644
--- a/modules/util/enum/ConfigPart.py
+++ b/modules/util/enum/ConfigPart.py
@@ -1,10 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class ConfigPart(Enum):
+class ConfigPart(BaseEnum):
NONE = 'NONE'
SETTINGS = 'SETTINGS'
ALL = 'ALL'
-
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/DataType.py b/modules/util/enum/DataType.py
index dd15e01d8..2bd7bd90f 100644
--- a/modules/util/enum/DataType.py
+++ b/modules/util/enum/DataType.py
@@ -1,9 +1,9 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
import torch
-class DataType(Enum):
+class DataType(BaseEnum):
NONE = 'NONE'
FLOAT_8 = 'FLOAT_8'
FLOAT_16 = 'FLOAT_16'
@@ -14,8 +14,58 @@ class DataType(Enum):
NFLOAT_4 = 'NFLOAT_4'
GGUF = 'GGUF'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ DataType.NONE: '',
+ DataType.FLOAT_8: 'Float8',
+ DataType.FLOAT_16: 'Float16',
+ DataType.FLOAT_32: 'Float32',
+ DataType.BFLOAT_16: 'BFloat16',
+ DataType.TFLOAT_32: 'TFloat32',
+ DataType.INT_8: 'Int8',
+ DataType.NFLOAT_4: 'NFloat4',
+ DataType.GGUF: 'GGUF'
+ }[self]
+
+ @staticmethod
+ def is_enabled(value, context=None):
+ if context == "embeddings" or context == "lora":
+ return value in [DataType.FLOAT_32, DataType.BFLOAT_16]
+ elif context == "convert_window":
+ return value in [DataType.FLOAT_32, DataType.FLOAT_16, DataType.BFLOAT_16]
+ elif context == "training_dtype":
+ return value in [DataType.FLOAT_32, DataType.FLOAT_16, DataType.BFLOAT_16, DataType.TFLOAT_32]
+ elif context == "training_fallback":
+ return value in [DataType.FLOAT_32, DataType.BFLOAT_16]
+ elif context == "output_dtype":
+ return value in [
+ DataType.FLOAT_16,
+ DataType.FLOAT_32,
+ DataType.BFLOAT_16,
+ DataType.FLOAT_8,
+ DataType.NFLOAT_4
+ ]
+ elif context == "transformer_dtype":
+ return value in [
+ DataType.FLOAT_32,
+ DataType.BFLOAT_16,
+ DataType.FLOAT_16,
+ DataType.FLOAT_8,
+ # DataType.INT_8, # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332
+ DataType.NFLOAT_4,
+ DataType.GGUF
+ ]
+ else: # model_dtypes
+ return value in [
+ DataType.FLOAT_32,
+ DataType.BFLOAT_16,
+ DataType.FLOAT_16,
+ DataType.FLOAT_8,
+ # DataType.INT_8, # TODO: reactivate when the int8 implementation is fixed in bitsandbytes: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1332
+ DataType.NFLOAT_4,
+ ]
+
def torch_dtype(
self,
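
Each context key corresponds to a UI field that consumes the dtype list; for example (results in enum definition order, assuming enabled_values from BaseEnum above):

    DataType.enabled_values(context="lora")
    # -> [DataType.FLOAT_32, DataType.BFLOAT_16]
    DataType.enabled_values(context="transformer_dtype")
    # -> [FLOAT_8, FLOAT_16, FLOAT_32, BFLOAT_16, NFLOAT_4, GGUF]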
diff --git a/modules/util/enum/DropoutMode.py b/modules/util/enum/DropoutMode.py
new file mode 100644
index 000000000..a14193665
--- /dev/null
+++ b/modules/util/enum/DropoutMode.py
@@ -0,0 +1,7 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class DropoutMode(BaseEnum):
+ FULL = 'FULL'
+ RANDOM = 'RANDOM'
+ RANDOM_WEIGHTED = 'RANDOM WEIGHTED'
diff --git a/modules/util/enum/EMAMode.py b/modules/util/enum/EMAMode.py
index 2742c445d..ad8aa2689 100644
--- a/modules/util/enum/EMAMode.py
+++ b/modules/util/enum/EMAMode.py
@@ -1,10 +1,14 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class EMAMode(Enum):
+class EMAMode(BaseEnum):
OFF = 'OFF'
GPU = 'GPU'
CPU = 'CPU'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ EMAMode.OFF: "Off",
+ EMAMode.GPU: "GPU",
+ EMAMode.CPU: "CPU",
+ }[self]
diff --git a/modules/util/enum/EditMode.py b/modules/util/enum/EditMode.py
new file mode 100644
index 000000000..b2590929b
--- /dev/null
+++ b/modules/util/enum/EditMode.py
@@ -0,0 +1,9 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class EditMode(BaseEnum):
+ NONE = "none"
+ DRAW = "draw"
+ FILL = "fill"
+ ZOOM = "zoom"
+ PAN = "pan"
diff --git a/modules/util/enum/FileFilter.py b/modules/util/enum/FileFilter.py
new file mode 100644
index 000000000..b78834537
--- /dev/null
+++ b/modules/util/enum/FileFilter.py
@@ -0,0 +1,7 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class FileFilter(BaseEnum):
+ FILE = "file"
+ PATH = "path"
+ BOTH = "both"
diff --git a/modules/util/enum/FileType.py b/modules/util/enum/FileType.py
index e0fa4fe81..93a890486 100644
--- a/modules/util/enum/FileType.py
+++ b/modules/util/enum/FileType.py
@@ -1,10 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class FileType(Enum):
+class FileType(BaseEnum):
IMAGE = 'IMAGE'
VIDEO = 'VIDEO'
AUDIO = 'AUDIO'
-
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/GenerateCaptionsModel.py b/modules/util/enum/GenerateCaptionsModel.py
index 02da82d05..f0ada64d6 100644
--- a/modules/util/enum/GenerateCaptionsModel.py
+++ b/modules/util/enum/GenerateCaptionsModel.py
@@ -1,10 +1,27 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class GenerateCaptionsModel(Enum):
+class GenerateCaptionsModel(BaseEnum):
BLIP = 'BLIP'
BLIP2 = 'BLIP2'
WD14_VIT_2 = 'WD14_VIT_2'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ GenerateCaptionsModel.BLIP: "BLIP",
+ GenerateCaptionsModel.BLIP2: "BLIP-2",
+ GenerateCaptionsModel.WD14_VIT_2: "WD 1.4 ViT Tagger V2",
+ }[self]
+
+
+class GenerateCaptionsAction(BaseEnum):
+ REPLACE = 'REPLACE'
+ CREATE = 'CREATE'
+ ADD = 'ADD'
+
+ def pretty_print(self):
+ return {
+ GenerateCaptionsAction.REPLACE: 'Replace all captions',
+ GenerateCaptionsAction.CREATE: 'Create if absent',
+ GenerateCaptionsAction.ADD: 'Add as new line'
+ }[self]
diff --git a/modules/util/enum/GenerateMasksModel.py b/modules/util/enum/GenerateMasksModel.py
index a81653754..b3a2888f5 100644
--- a/modules/util/enum/GenerateMasksModel.py
+++ b/modules/util/enum/GenerateMasksModel.py
@@ -1,11 +1,32 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class GenerateMasksModel(Enum):
+class GenerateMasksModel(BaseEnum):
CLIPSEG = 'CLIPSEG'
REMBG = 'REMBG'
REMBG_HUMAN = 'REMBG_HUMAN'
COLOR = 'COLOR'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ GenerateMasksModel.CLIPSEG: "CLIPSeg",
+ GenerateMasksModel.REMBG: "RemBG",
+ GenerateMasksModel.REMBG_HUMAN: "RemBG-Human",
+ GenerateMasksModel.COLOR: "Hex Color"
+ }[self]
+
+
+class GenerateMasksAction(BaseEnum):
+ REPLACE = 'REPLACE'
+ FILL = 'FILL'
+ ADD = 'ADD'
+ SUBTRACT = 'SUBTRACT'
+ BLEND = 'BLEND'
+
+ def pretty_print(self):
+ return {
+ GenerateMasksAction.REPLACE: 'Replace all masks',
+ GenerateMasksAction.FILL: 'Create if absent',
+ GenerateMasksAction.ADD: 'Add to existing',
+ GenerateMasksAction.SUBTRACT: 'Subtract from existing',
+ GenerateMasksAction.BLEND: 'Blend with existing'
+ }[self]
diff --git a/modules/util/enum/GradientCheckpointingMethod.py b/modules/util/enum/GradientCheckpointingMethod.py
index d3f05666a..64d52616f 100644
--- a/modules/util/enum/GradientCheckpointingMethod.py
+++ b/modules/util/enum/GradientCheckpointingMethod.py
@@ -1,13 +1,17 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class GradientCheckpointingMethod(Enum):
+class GradientCheckpointingMethod(BaseEnum):
OFF = 'OFF'
ON = 'ON'
CPU_OFFLOADED = 'CPU_OFFLOADED'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ GradientCheckpointingMethod.OFF: "Off",
+ GradientCheckpointingMethod.ON: "On",
+ GradientCheckpointingMethod.CPU_OFFLOADED: "CPU Offloaded",
+ }[self]
def enabled(self):
return self == GradientCheckpointingMethod.ON \
diff --git a/modules/util/enum/GradientReducePrecision.py b/modules/util/enum/GradientReducePrecision.py
index 04eb2e526..2f5a1eb70 100644
--- a/modules/util/enum/GradientReducePrecision.py
+++ b/modules/util/enum/GradientReducePrecision.py
@@ -1,14 +1,22 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
import torch
-class GradientReducePrecision(Enum):
+class GradientReducePrecision(BaseEnum):
WEIGHT_DTYPE = 'WEIGHT_DTYPE'
FLOAT_32 = 'FLOAT_32'
WEIGHT_DTYPE_STOCHASTIC = 'WEIGHT_DTYPE_STOCHASTIC'
FLOAT_32_STOCHASTIC = 'FLOAT_32_STOCHASTIC'
+ def pretty_print(self):
+ return {
+ GradientReducePrecision.WEIGHT_DTYPE: "Weight DType",
+ GradientReducePrecision.FLOAT_32: "Float32",
+ GradientReducePrecision.WEIGHT_DTYPE_STOCHASTIC: "Weight DType Stochastic",
+ GradientReducePrecision.FLOAT_32_STOCHASTIC: "Float32 Stochastic"
+ }[self]
+
def torch_dtype(self, weight_dtype: torch.dtype) -> torch.dtype:
match self:
case GradientReducePrecision.WEIGHT_DTYPE:
@@ -34,6 +42,3 @@ def stochastic_rounding(self, weight_dtype: torch.dtype) -> bool:
return weight_dtype == torch.bfloat16
case _:
raise ValueError
-
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/ImageFormat.py b/modules/util/enum/ImageFormat.py
index 15b6183dc..1393bd9a6 100644
--- a/modules/util/enum/ImageFormat.py
+++ b/modules/util/enum/ImageFormat.py
@@ -1,11 +1,12 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class ImageFormat(Enum):
+class ImageFormat(BaseEnum):
PNG = 'PNG'
JPG = 'JPG'
- def __str__(self):
+
+ def pretty_print(self):
return self.value
def extension(self) -> str:
diff --git a/modules/util/enum/ImageMegapixels.py b/modules/util/enum/ImageMegapixels.py
new file mode 100644
index 000000000..4a4598e8c
--- /dev/null
+++ b/modules/util/enum/ImageMegapixels.py
@@ -0,0 +1,21 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class ImageMegapixels(BaseEnum):
+ ONE_MEGAPIXEL = 1_048_576
+ COMPUTE_PROOF_MEGAPIXEL_THRESHOLD = 4_194_304
+ MIDDLEGROUND_MEGAPIXEL_THRESHOLD = 8_388_608
+ FUTURE_PROOF_MEGAPIXEL_THRESHOLD = 16_777_216
+ CUSTOM = -1
+
+ def __str__(self):
+ return str(self.value)
+
+ def pretty_print(self):
+ return {
+ ImageMegapixels.ONE_MEGAPIXEL: "1MP",
+ ImageMegapixels.COMPUTE_PROOF_MEGAPIXEL_THRESHOLD: "Compute Proof (4MP)",
+ ImageMegapixels.MIDDLEGROUND_MEGAPIXEL_THRESHOLD: "Middleground (8MP)",
+ ImageMegapixels.FUTURE_PROOF_MEGAPIXEL_THRESHOLD: "Future Proof (16MP)",
+ ImageMegapixels.CUSTOM: "Custom",
+ }[self]
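
Note that the member values are raw pixel counts (1_048_576 = 1024x1024) and that __str__ is overridden to return the number rather than the member name, presumably so serialized configs keep the numeric threshold while the UI shows the pretty label:

    str(ImageMegapixels.ONE_MEGAPIXEL)            # '1048576'
    ImageMegapixels.ONE_MEGAPIXEL.pretty_print()  # '1MP'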
diff --git a/modules/util/enum/ImageOperations.py b/modules/util/enum/ImageOperations.py
new file mode 100644
index 000000000..983ea60ae
--- /dev/null
+++ b/modules/util/enum/ImageOperations.py
@@ -0,0 +1,24 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class ImageOperations(BaseEnum):
+ NONE = "none"
+ VERIFY_IMG = "verify_img"
+ SEQUENTIAL_RENAME = "sequential_rename"
+ PROCESS_ALPHA = "process_alpha"
+ RESIZE_LARGE_IMG = "resize_large_image"
+ OPTIMIZE_PNG = "optimize_png"
+ CONVERT_WEBP = "convert_webp"
+ CONVERT_JXL = "convert_jxl"
+
+ def pretty_print(self):
+ return {
+ ImageOperations.NONE: "No operation",
+ ImageOperations.VERIFY_IMG: "Verifying images",
+ ImageOperations.SEQUENTIAL_RENAME: "Sequential renaming",
+ ImageOperations.PROCESS_ALPHA: "Processing transparent images",
+ ImageOperations.RESIZE_LARGE_IMG: "Resizing large images",
+ ImageOperations.OPTIMIZE_PNG: "Optimizing PNGs",
+ ImageOperations.CONVERT_WEBP: "Converting to WebP",
+ ImageOperations.CONVERT_JXL: "Converting to JPEG XL",
+ }[self]
diff --git a/modules/util/enum/ImageOptimization.py b/modules/util/enum/ImageOptimization.py
new file mode 100644
index 000000000..405a95020
--- /dev/null
+++ b/modules/util/enum/ImageOptimization.py
@@ -0,0 +1,16 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class ImageOptimization(BaseEnum):
+ NONE = "none"
+ PNG = "png"
+ WEBP = "webp"
+ JXL = "jxl"
+
+ def pretty_print(self):
+ return {
+ ImageOptimization.NONE: "None",
+ ImageOptimization.PNG: "Optimize PNGs",
+ ImageOptimization.WEBP: "Convert to WebP",
+ ImageOptimization.JXL: "Convert to JPEG XL",
+ }[self]
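
ImageOptimization (the user-facing choice) and ImageOperations (the executed dataset-tool steps) overlap by design. A hypothetical mapping helper, not part of this diff, makes the correspondence explicit:

from modules.util.enum.ImageOperations import ImageOperations
from modules.util.enum.ImageOptimization import ImageOptimization

# Hypothetical helper: map the user's optimization choice onto the
# operation that implements it.
_OPTIMIZATION_TO_OPERATION = {
    ImageOptimization.NONE: ImageOperations.NONE,
    ImageOptimization.PNG: ImageOperations.OPTIMIZE_PNG,
    ImageOptimization.WEBP: ImageOperations.CONVERT_WEBP,
    ImageOptimization.JXL: ImageOperations.CONVERT_JXL,
}


def operation_for(optimization: ImageOptimization) -> ImageOperations:
    return _OPTIMIZATION_TO_OPERATION[optimization]
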
diff --git a/modules/util/enum/LearningRateScaler.py b/modules/util/enum/LearningRateScaler.py
index d8ce19ec5..c55667233 100644
--- a/modules/util/enum/LearningRateScaler.py
+++ b/modules/util/enum/LearningRateScaler.py
@@ -1,9 +1,8 @@
-from enum import Enum
-
import modules.util.multi_gpu_util as multi
+from modules.util.enum.BaseEnum import BaseEnum
-class LearningRateScaler(Enum):
+class LearningRateScaler(BaseEnum):
NONE = 'NONE'
BATCH = 'BATCH'
GLOBAL_BATCH = 'GLOBAL_BATCH'
@@ -11,9 +10,6 @@ class LearningRateScaler(Enum):
BOTH = 'BOTH'
GLOBAL_BOTH = 'GLOBAL_BOTH'
- def __str__(self):
- return self.value
-
def get_scale(self, batch_size: int, accumulation_steps: int) -> int:
match self:
case LearningRateScaler.NONE:
diff --git a/modules/util/enum/LearningRateScheduler.py b/modules/util/enum/LearningRateScheduler.py
index 3e6292103..67d8e0421 100644
--- a/modules/util/enum/LearningRateScheduler.py
+++ b/modules/util/enum/LearningRateScheduler.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class LearningRateScheduler(Enum):
+class LearningRateScheduler(BaseEnum):
CONSTANT = 'CONSTANT'
LINEAR = 'LINEAR'
COSINE = 'COSINE'
@@ -10,6 +10,3 @@ class LearningRateScheduler(Enum):
REX = 'REX'
ADAFACTOR = 'ADAFACTOR'
CUSTOM = 'CUSTOM'
-
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/LossScaler.py b/modules/util/enum/LossScaler.py
index 4ac82095e..8bd163f65 100644
--- a/modules/util/enum/LossScaler.py
+++ b/modules/util/enum/LossScaler.py
@@ -1,9 +1,8 @@
-from enum import Enum
-
import modules.util.multi_gpu_util as multi
+from modules.util.enum.BaseEnum import BaseEnum
-class LossScaler(Enum):
+class LossScaler(BaseEnum):
NONE = 'NONE'
BATCH = 'BATCH'
GLOBAL_BATCH = 'GLOBAL_BATCH'
@@ -11,9 +10,6 @@ class LossScaler(Enum):
BOTH = 'BOTH'
GLOBAL_BOTH = 'GLOBAL_BOTH'
- def __str__(self):
- return self.value
-
def get_scale(self, batch_size: int, accumulation_steps: int) -> int:
match self:
case LossScaler.NONE:
diff --git a/modules/util/enum/LossWeight.py b/modules/util/enum/LossWeight.py
index b47ea9ae4..fc41f8b8a 100644
--- a/modules/util/enum/LossWeight.py
+++ b/modules/util/enum/LossWeight.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class LossWeight(Enum):
+class LossWeight(BaseEnum):
CONSTANT = 'CONSTANT'
P2 = 'P2'
MIN_SNR_GAMMA = 'MIN_SNR_GAMMA'
@@ -12,5 +12,18 @@ def supports_flow_matching(self) -> bool:
return self == LossWeight.CONSTANT \
or self == LossWeight.SIGMA
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ LossWeight.CONSTANT: 'Constant',
+ LossWeight.P2: 'P2',
+ LossWeight.MIN_SNR_GAMMA: 'Min SNR Gamma',
+ LossWeight.DEBIASED_ESTIMATION: 'Debiased Estimation',
+ LossWeight.SIGMA: 'Sigma',
+ }[self]
+
+ @staticmethod
+ def is_enabled(value, context=None):
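+ # Mirrors supports_flow_matching() above; keep the two lists in sync.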
+ if context == "flow_matching":
+ return value in [LossWeight.CONSTANT, LossWeight.SIGMA]
+ else:
+ return value in [LossWeight.CONSTANT, LossWeight.P2, LossWeight.MIN_SNR_GAMMA, LossWeight.DEBIASED_ESTIMATION]
diff --git a/modules/util/enum/ModelFlags.py b/modules/util/enum/ModelFlags.py
new file mode 100644
index 000000000..f82cdd73a
--- /dev/null
+++ b/modules/util/enum/ModelFlags.py
@@ -0,0 +1,193 @@
+from enum import Flag, auto
+
+from modules.modelSetup.BaseChromaSetup import PRESETS as chroma_presets
+from modules.modelSetup.BaseFluxSetup import PRESETS as flux_presets
+from modules.modelSetup.BaseHiDreamSetup import PRESETS as hidream_presets
+from modules.modelSetup.BaseHunyuanVideoSetup import PRESETS as hunyuan_video_presets
+from modules.modelSetup.BasePixArtAlphaSetup import PRESETS as pixart_presets
+from modules.modelSetup.BaseQwenSetup import PRESETS as qwen_presets
+from modules.modelSetup.BaseSanaSetup import PRESETS as sana_presets
+from modules.modelSetup.BaseStableDiffusion3Setup import PRESETS as sd3_presets
+from modules.modelSetup.BaseStableDiffusionSetup import PRESETS as sd_presets
+from modules.modelSetup.BaseStableDiffusionXLSetup import PRESETS as sdxl_presets
+from modules.modelSetup.BaseWuerstchenSetup import PRESETS as sc_presets
+from modules.util.enum.TrainingMethod import TrainingMethod
+
+
+class ModelFlags(Flag):
+ NONE = 0 # Empty flag set; only valid as the initial value.
+
+ # Model + training flags.
+ UNET = auto()
+ PRIOR = auto()
+ OVERRIDE_PRIOR = auto()
+ TRANSFORMER = auto()
+ OVERRIDE_TRANSFORMER = auto()
+ OVERRIDE_TE4 = auto()
+ TE1 = auto()
+ TE2 = auto()
+ TE3 = auto()
+ TE4 = auto()
+ VAE = auto()
+
+ # Model-only flags.
+ EFFNET = auto()
+ DEC = auto()
+ DEC_TE = auto()
+ ALLOW_SAFETENSORS = auto()
+ ALLOW_DIFFUSERS = auto()
+ ALLOW_LEGACY_SAFETENSORS = auto()
+
+ # Training-only flags.
+ TRAIN_TRANSFORMER = auto()
+ TRAIN_PRIOR = auto()
+ GENERALIZED_OFFSET_NOISE = auto()
+ TE_INCLUDE = auto()
+ VB_LOSS = auto()
+ GUIDANCE_SCALE = auto()
+ DYNAMIC_TIMESTEP_SHIFTING = auto()
+ DISABLE_FORCE_ATTN_MASK = auto()
+ DISABLE_CLIP_SKIP = auto()
+ VIDEO_TRAINING = auto()
+ DISABLE_TE4_LAYER_SKIP = auto()
+ OVERRIDE_SEQUENCE_LENGTH_TE2 = auto()
+
+ # Training method flags.
+ CAN_TRAIN_EMBEDDING = auto()
+ CAN_FINE_TUNE_VAE = auto()
+
+ @staticmethod
+ def getFlags(model_type, training_method):
+ flags = ModelFlags.NONE
+
+ if model_type.is_stable_diffusion(): # TODO simplify
+ flags = (ModelFlags.UNET | ModelFlags.TE1 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS | ModelFlags.GENERALIZED_OFFSET_NOISE |
+ ModelFlags.EFFNET | ModelFlags.CAN_FINE_TUNE_VAE | ModelFlags.CAN_TRAIN_EMBEDDING)
+ if training_method in [TrainingMethod.FINE_TUNE, TrainingMethod.FINE_TUNE_VAE]:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_stable_diffusion_3():
+ flags = ModelFlags.TE1 | ModelFlags.TE2 | ModelFlags.TE3 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS | ModelFlags.TE_INCLUDE | ModelFlags.TRANSFORMER
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_stable_diffusion_xl():
+ flags = ModelFlags.UNET | ModelFlags.TE1 | ModelFlags.TE2 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS | ModelFlags.GENERALIZED_OFFSET_NOISE
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_wuerstchen():
+ flags = ModelFlags.PRIOR | ModelFlags.TE1 | ModelFlags.TRAIN_TRANSFORMER | ModelFlags.DEC
+ if model_type.is_stable_cascade():
+ flags |= ModelFlags.OVERRIDE_PRIOR
+ else:
+ flags |= ModelFlags.DEC_TE
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method != TrainingMethod.FINE_TUNE or model_type.is_stable_cascade():
+ flags |= ModelFlags.ALLOW_SAFETENSORS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_pixart():
+ flags = ModelFlags.TRANSFORMER | ModelFlags.TE1 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS | ModelFlags.TRAIN_TRANSFORMER | ModelFlags.VB_LOSS
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_flux():
+ flags = (ModelFlags.OVERRIDE_TRANSFORMER | ModelFlags.TE1 | ModelFlags.TE2 | ModelFlags.VAE | ModelFlags.OVERRIDE_SEQUENCE_LENGTH_TE2 |
+ ModelFlags.ALLOW_SAFETENSORS | ModelFlags.TRANSFORMER | ModelFlags.TE_INCLUDE | ModelFlags.GUIDANCE_SCALE | ModelFlags.DYNAMIC_TIMESTEP_SHIFTING)
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_chroma():
+ flags = (ModelFlags.OVERRIDE_TRANSFORMER | ModelFlags.TE1 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS |
+ ModelFlags.DISABLE_FORCE_ATTN_MASK | ModelFlags.TRANSFORMER)
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_qwen():
+ flags = (ModelFlags.OVERRIDE_TRANSFORMER | ModelFlags.TE1 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS |
+ ModelFlags.DISABLE_FORCE_ATTN_MASK | ModelFlags.TRANSFORMER | ModelFlags.DYNAMIC_TIMESTEP_SHIFTING | ModelFlags.DISABLE_CLIP_SKIP)
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_sana():
+ flags = ModelFlags.TRANSFORMER | ModelFlags.TE1 | ModelFlags.VAE | ModelFlags.TRAIN_TRANSFORMER
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ else:
+ flags |= ModelFlags.ALLOW_SAFETENSORS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_hunyuan_video():
+ flags = (ModelFlags.TE1 | ModelFlags.TE2 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS |
+ ModelFlags.TE_INCLUDE | ModelFlags.VIDEO_TRAINING | ModelFlags.TRANSFORMER | ModelFlags.GUIDANCE_SCALE)
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ elif model_type.is_hi_dream():
+ flags = (ModelFlags.OVERRIDE_TE4 | ModelFlags.TE1 | ModelFlags.TE2 | ModelFlags.TE3 | ModelFlags.TE4 | ModelFlags.VAE | ModelFlags.ALLOW_SAFETENSORS |
+ ModelFlags.TRANSFORMER | ModelFlags.VIDEO_TRAINING | ModelFlags.DISABLE_TE4_LAYER_SKIP | ModelFlags.TE_INCLUDE)
+ if training_method == TrainingMethod.FINE_TUNE:
+ flags |= ModelFlags.ALLOW_DIFFUSERS
+ if training_method == TrainingMethod.LORA:
+ flags |= ModelFlags.ALLOW_LEGACY_SAFETENSORS
+
+ if model_type.is_stable_diffusion_3() \
+ or model_type.is_stable_diffusion_xl() \
+ or model_type.is_wuerstchen() \
+ or model_type.is_pixart() \
+ or model_type.is_flux() \
+ or model_type.is_sana() \
+ or model_type.is_hunyuan_video() \
+ or model_type.is_hi_dream() \
+ or model_type.is_chroma():
+ flags |= ModelFlags.CAN_TRAIN_EMBEDDING
+
+ return flags
+
+ @staticmethod
+ def getPresets(model_type):
+ if model_type.is_stable_diffusion(): # TODO simplify
+ presets = sd_presets
+ elif model_type.is_stable_diffusion_xl():
+ presets = sdxl_presets
+ elif model_type.is_stable_diffusion_3():
+ presets = sd3_presets
+ elif model_type.is_wuerstchen():
+ presets = sc_presets
+ elif model_type.is_pixart():
+ presets = pixart_presets
+ elif model_type.is_flux():
+ presets = flux_presets
+ elif model_type.is_qwen():
+ presets = qwen_presets
+ elif model_type.is_chroma():
+ presets = chroma_presets
+ elif model_type.is_sana():
+ presets = sana_presets
+ elif model_type.is_hunyuan_video():
+ presets = hunyuan_video_presets
+ elif model_type.is_hi_dream():
+ presets = hidream_presets
+ else:
+ presets = {"full": []}
+
+ return presets
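
Because ModelFlags derives from enum.Flag, capability checks on the result of getFlags are plain bitwise membership tests. An illustrative usage sketch; the variable names are made up:

from modules.util.enum.ModelFlags import ModelFlags
from modules.util.enum.ModelType import ModelType
from modules.util.enum.TrainingMethod import TrainingMethod

flags = ModelFlags.getFlags(ModelType.STABLE_DIFFUSION_XL_10_BASE, TrainingMethod.LORA)

show_vae_widget = ModelFlags.VAE in flags  # Flag membership test.
allow_legacy = bool(flags & ModelFlags.ALLOW_LEGACY_SAFETENSORS)  # Equivalent bitwise form.
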
diff --git a/modules/util/enum/ModelFormat.py b/modules/util/enum/ModelFormat.py
index 597ad4442..1ceb22915 100644
--- a/modules/util/enum/ModelFormat.py
+++ b/modules/util/enum/ModelFormat.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class ModelFormat(Enum):
+class ModelFormat(BaseEnum):
DIFFUSERS = 'DIFFUSERS'
CKPT = 'CKPT'
SAFETENSORS = 'SAFETENSORS'
@@ -9,9 +9,15 @@ class ModelFormat(Enum):
INTERNAL = 'INTERNAL' # an internal format that stores all information to resume training
- def __str__(self):
- return self.value
+ @staticmethod
+ def is_enabled(value, context=None):
+ if context == "convert_window":
+ return value in [ModelFormat.SAFETENSORS, ModelFormat.DIFFUSERS]
+ # TODO: restrict the list shown on the model tab as well.
+ return True
def file_extension(self) -> str:
match self:
diff --git a/modules/util/enum/ModelType.py b/modules/util/enum/ModelType.py
index cd6fd77bf..a86065bd6 100644
--- a/modules/util/enum/ModelType.py
+++ b/modules/util/enum/ModelType.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class ModelType(Enum):
+class ModelType(BaseEnum):
STABLE_DIFFUSION_15 = 'STABLE_DIFFUSION_15'
STABLE_DIFFUSION_15_INPAINTING = 'STABLE_DIFFUSION_15_INPAINTING'
STABLE_DIFFUSION_20 = 'STABLE_DIFFUSION_20'
@@ -36,8 +36,82 @@ class ModelType(Enum):
QWEN = 'QWEN'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ ModelType.STABLE_DIFFUSION_15: "SD1.5",
+ ModelType.STABLE_DIFFUSION_15_INPAINTING: "SD1.5 Inpainting",
+ ModelType.STABLE_DIFFUSION_20: "SD2.0",
+ #ModelType.STABLE_DIFFUSION_20_BASE: "SD2.0 Base",
+ ModelType.STABLE_DIFFUSION_20_INPAINTING: "SD2.0 Inpainting",
+ #ModelType.STABLE_DIFFUSION_20_DEPTH: "SD2.0 Depth",
+ ModelType.STABLE_DIFFUSION_21: "SD2.1",
+ #ModelType.STABLE_DIFFUSION_21_BASE: "SD2.1 Base",
+ ModelType.STABLE_DIFFUSION_3: "SD3",
+ ModelType.STABLE_DIFFUSION_35: "SD3.5",
+ ModelType.STABLE_DIFFUSION_XL_10_BASE: "SDXL 1.0 Base",
+ ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING: "SDXL 1.0 Base Inpainting",
+ ModelType.WUERSTCHEN_2: "Wuerstchen v2",
+ ModelType.STABLE_CASCADE_1: "Stable Cascade",
+ ModelType.PIXART_ALPHA: "PixArt Alpha",
+ ModelType.PIXART_SIGMA: "PixArt Sigma",
+ ModelType.FLUX_DEV_1: "Flux Dev",
+ ModelType.FLUX_FILL_DEV_1: "Flux Fill Dev",
+ ModelType.SANA: "Sana",
+ ModelType.HUNYUAN_VIDEO: "Hunyuan Video",
+ ModelType.HI_DREAM_FULL: "HiDream Full",
+ ModelType.CHROMA_1: "Chroma1",
+ ModelType.QWEN: "Qwen Image",
+ }[self]
+
+ @staticmethod
+ def is_enabled(value, context=None):
+ if context == "convert_window":
+ return value in [
+ ModelType.STABLE_DIFFUSION_15,
+ ModelType.STABLE_DIFFUSION_15_INPAINTING,
+ ModelType.STABLE_DIFFUSION_20,
+ ModelType.STABLE_DIFFUSION_20_INPAINTING,
+ ModelType.STABLE_DIFFUSION_21,
+ ModelType.STABLE_DIFFUSION_3,
+ ModelType.STABLE_DIFFUSION_35,
+ ModelType.STABLE_DIFFUSION_XL_10_BASE,
+ ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING,
+ ModelType.WUERSTCHEN_2,
+ ModelType.STABLE_CASCADE_1,
+ ModelType.PIXART_ALPHA,
+ ModelType.PIXART_SIGMA,
+ ModelType.FLUX_DEV_1,
+ ModelType.FLUX_FILL_DEV_1,
+ ModelType.HUNYUAN_VIDEO,
+ ModelType.CHROMA_1, # TODO does this just work? HiDream is not here
+ ModelType.QWEN, # TODO does this just work? HiDream is not here
+ ]
+ else: # main_window
+ return value in [
+ ModelType.STABLE_DIFFUSION_15,
+ ModelType.STABLE_DIFFUSION_15_INPAINTING,
+ ModelType.STABLE_DIFFUSION_20,
+ # ModelType.STABLE_DIFFUSION_20_BASE,
+ ModelType.STABLE_DIFFUSION_20_INPAINTING,
+ # ModelType.STABLE_DIFFUSION_20_DEPTH,
+ ModelType.STABLE_DIFFUSION_21,
+ # ModelType.STABLE_DIFFUSION_21_BASE,
+ ModelType.STABLE_DIFFUSION_3,
+ ModelType.STABLE_DIFFUSION_35,
+ ModelType.STABLE_DIFFUSION_XL_10_BASE,
+ ModelType.STABLE_DIFFUSION_XL_10_BASE_INPAINTING,
+ ModelType.WUERSTCHEN_2,
+ ModelType.STABLE_CASCADE_1,
+ ModelType.PIXART_ALPHA,
+ ModelType.PIXART_SIGMA,
+ ModelType.FLUX_DEV_1,
+ ModelType.FLUX_FILL_DEV_1,
+ ModelType.SANA,
+ ModelType.HUNYUAN_VIDEO,
+ ModelType.HI_DREAM_FULL,
+ ModelType.CHROMA_1,
+ ModelType.QWEN,
+ ]
def is_stable_diffusion(self):
return self == ModelType.STABLE_DIFFUSION_15 \
@@ -143,10 +217,14 @@ def is_flow_matching(self) -> bool:
or self.is_hi_dream()
-class PeftType(Enum):
+class PeftType(BaseEnum):
LORA = 'LORA'
LOHA = 'LOHA'
OFT_2 = 'OFT_2'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ PeftType.LORA: "LoRA",
+ PeftType.LOHA: "LoHA",
+ PeftType.OFT_2: "OFT 2",
+ }[self]
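
The pretty_print/is_enabled pair added throughout these enums matches what a Qt dropdown needs: show the label, store the member, and skip values disabled for the current context. A hedged sketch of that convention; QComboBox is real PySide6 API, while the population helper and the filtering convention are inferred from the signatures in this diff:

from PySide6.QtWidgets import QComboBox

from modules.util.enum.ModelType import ModelType


def populate(combo: QComboBox, enum_cls, context=None):
    combo.clear()
    for member in enum_cls:
        if enum_cls.is_enabled(member, context):
            # Human-readable label as text, enum member as item data.
            combo.addItem(member.pretty_print(), member)


# populate(model_type_combo, ModelType, context="convert_window")
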
diff --git a/modules/util/enum/MouseButton.py b/modules/util/enum/MouseButton.py
new file mode 100644
index 000000000..c3a63b43d
--- /dev/null
+++ b/modules/util/enum/MouseButton.py
@@ -0,0 +1,8 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class MouseButton(BaseEnum):
+ NONE = "none"
+ LEFT = "left"
+ MIDDLE = "middle"
+ RIGHT = "right"
diff --git a/modules/util/enum/NoiseScheduler.py b/modules/util/enum/NoiseScheduler.py
index 62d9af9dd..0762782fe 100644
--- a/modules/util/enum/NoiseScheduler.py
+++ b/modules/util/enum/NoiseScheduler.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class NoiseScheduler(Enum):
+class NoiseScheduler(BaseEnum):
DDIM = 'DDIM'
EULER = 'EULER'
@@ -15,5 +15,31 @@ class NoiseScheduler(Enum):
DPMPP_SDE_KARRAS = 'DPMPP_SDE_KARRAS'
UNIPC_KARRAS = 'UNIPC_KARRAS'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ NoiseScheduler.DDIM: 'DDIM',
+ NoiseScheduler.EULER: 'Euler',
+ NoiseScheduler.EULER_A: 'Euler A',
+ NoiseScheduler.DPMPP: 'DPM++',
+ NoiseScheduler.DPMPP_SDE: 'DPM++ SDE',
+ NoiseScheduler.UNIPC: 'UniPC',
+ NoiseScheduler.EULER_KARRAS: 'Euler Karras',
+ NoiseScheduler.DPMPP_KARRAS: 'DPM++ Karras',
+ NoiseScheduler.DPMPP_SDE_KARRAS: 'DPM++ SDE Karras',
+ NoiseScheduler.UNIPC_KARRAS: 'UniPC Karras',
+ }[self]
+
+ @staticmethod
+ def is_enabled(value, context=None):
+ return value in [
+ NoiseScheduler.DDIM,
+ NoiseScheduler.EULER,
+ NoiseScheduler.EULER_A,
+ # NoiseScheduler.DPMPP, # TODO: produces noisy samples
+ # NoiseScheduler.DPMPP_SDE, # TODO: produces noisy samples
+ NoiseScheduler.UNIPC,
+ NoiseScheduler.EULER_KARRAS,
+ NoiseScheduler.DPMPP_KARRAS,
+ NoiseScheduler.DPMPP_SDE_KARRAS,
+ # NoiseScheduler.UNIPC_KARRAS, # TODO: update diffusers to fix UNIPC_KARRAS (see https://github.com/huggingface/diffusers/pull/4581)
+ ]
diff --git a/modules/util/enum/Optimizer.py b/modules/util/enum/Optimizer.py
index c5f0c1089..5bfb35ac5 100644
--- a/modules/util/enum/Optimizer.py
+++ b/modules/util/enum/Optimizer.py
@@ -1,9 +1,9 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
import torch
-class Optimizer(Enum):
+class Optimizer(BaseEnum):
# Sorted by origin (BNB / torch first, then DADAPT), then by adapter name, then interleaved by variant.
# BNB Standard & 8-bit
@@ -73,6 +73,51 @@ class Optimizer(Enum):
AIDA = 'AIDA'
YOGI = 'YOGI'
+ def pretty_print(self):
+ return {
+ Optimizer.ADAGRAD: "AdaGrad",
+ Optimizer.ADAGRAD_8BIT: "AdaGrad 8 bit",
+ Optimizer.ADAM: "Adam",
+ Optimizer.ADAM_8BIT: "Adam 8 bit",
+ Optimizer.ADAMW: "AdamW",
+ Optimizer.ADAMW_8BIT: "AdamW 8 bit",
+ Optimizer.ADAMW_ADV: "AdamW Advanced",
+ Optimizer.AdEMAMix: "AdEMAMix",
+ Optimizer.AdEMAMix_8BIT: "AdEMAMix 8 bit",
+ Optimizer.SIMPLIFIED_AdEMAMix: "Simplified AdEMAMix",
+ Optimizer.ADOPT: "ADOPT",
+ Optimizer.ADOPT_ADV: "ADOPT Advanced",
+ Optimizer.LAMB: "LAMB",
+ Optimizer.LAMB_8BIT: "LAMB 8 bit",
+ Optimizer.LARS: "LARS",
+ Optimizer.LARS_8BIT: "LARS 8 bit",
+ Optimizer.LION: "Lion",
+ Optimizer.LION_8BIT: "Lion 8 bit",
+ Optimizer.LION_ADV: "Lion Advanced",
+ Optimizer.RMSPROP: "RMSProp",
+ Optimizer.RMSPROP_8BIT: "RMSProp 8 bit",
+ Optimizer.SGD: "SGD",
+ Optimizer.SGD_8BIT: "SGD 8 bit",
+ Optimizer.SCHEDULE_FREE_ADAMW: "Schedule Free AdamW",
+ Optimizer.SCHEDULE_FREE_SGD: "Schedule Free SGD",
+ Optimizer.DADAPT_ADA_GRAD: "DAdapt AdaGrad",
+ Optimizer.DADAPT_ADAM: "DAdapt Adam",
+ Optimizer.DADAPT_ADAN: "DAdapt ADAN",
+ Optimizer.DADAPT_LION: "DAdapt Lion",
+ Optimizer.DADAPT_SGD: "DAdapt SGD",
+ Optimizer.PRODIGY: "Prodigy",
+ Optimizer.PRODIGY_PLUS_SCHEDULE_FREE: "Prodigy Plus Schedule Free",
+ Optimizer.PRODIGY_ADV: "Prodigy Advanced",
+ Optimizer.LION_PRODIGY_ADV: "Lion Prodigy Advanced",
+ Optimizer.ADAFACTOR: "Adafactor",
+ Optimizer.CAME: "CAME",
+ Optimizer.CAME_8BIT: "CAME 8 bit",
+ Optimizer.ADABELIEF: "AdaBelief",
+ Optimizer.TIGER: "Tiger",
+ Optimizer.AIDA: "Aida",
+ Optimizer.YOGI: "Yogi",
+ }[self]
+
@property
def is_adaptive(self):
return self in [
@@ -121,6 +166,3 @@ def maybe_adjust_lrs(self, lrs: dict[str, float], optimizer: torch.optim.Optimiz
for i, (key, lr) in enumerate(lrs.items())
}
return lrs
-
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/PromptSource.py b/modules/util/enum/PromptSource.py
new file mode 100644
index 000000000..4170eeff7
--- /dev/null
+++ b/modules/util/enum/PromptSource.py
@@ -0,0 +1,14 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class PromptSource(BaseEnum):
+ SAMPLE = 'sample'
+ CONCEPT = 'concept'
+ FILENAME = 'filename'
+
+ def pretty_print(self):
+ return {
+ PromptSource.SAMPLE: "From text file per sample",
+ PromptSource.CONCEPT: "From single text file",
+ PromptSource.FILENAME: "From image file name",
+ }[self]
diff --git a/modules/util/enum/SpecialDropoutTags.py b/modules/util/enum/SpecialDropoutTags.py
new file mode 100644
index 000000000..d49cfd635
--- /dev/null
+++ b/modules/util/enum/SpecialDropoutTags.py
@@ -0,0 +1,7 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class SpecialDropoutTags(BaseEnum):
+ NONE = 'NONE'
+ BLACKLIST = 'BLACKLIST'
+ WHITELIST = 'WHITELIST'
diff --git a/modules/util/enum/TimeUnit.py b/modules/util/enum/TimeUnit.py
index c06922d72..cb6b0b2ce 100644
--- a/modules/util/enum/TimeUnit.py
+++ b/modules/util/enum/TimeUnit.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class TimeUnit(Enum):
+class TimeUnit(BaseEnum):
EPOCH = 'EPOCH'
STEP = 'STEP'
SECOND = 'SECOND'
@@ -11,9 +11,6 @@ class TimeUnit(Enum):
NEVER = 'NEVER'
ALWAYS = 'ALWAYS'
- def __str__(self):
- return self.value
-
def is_time_unit(self) -> bool:
return self == TimeUnit.SECOND \
or self == TimeUnit.MINUTE \
diff --git a/modules/util/enum/TimestepDistribution.py b/modules/util/enum/TimestepDistribution.py
index 55d8efd02..dde233583 100644
--- a/modules/util/enum/TimestepDistribution.py
+++ b/modules/util/enum/TimestepDistribution.py
@@ -1,13 +1,10 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class TimestepDistribution(Enum):
+class TimestepDistribution(BaseEnum):
UNIFORM = 'UNIFORM'
SIGMOID = 'SIGMOID'
LOGIT_NORMAL = 'LOGIT_NORMAL'
HEAVY_TAIL = 'HEAVY_TAIL'
COS_MAP = 'COS_MAP'
INVERTED_PARABOLA = 'INVERTED_PARABOLA'
-
- def __str__(self):
- return self.value
diff --git a/modules/util/enum/ToolType.py b/modules/util/enum/ToolType.py
new file mode 100644
index 000000000..c6ca9959a
--- /dev/null
+++ b/modules/util/enum/ToolType.py
@@ -0,0 +1,9 @@
+from modules.util.enum.BaseEnum import BaseEnum
+
+
+class ToolType(BaseEnum):
+ SEPARATOR = "separator"
+ BUTTON = "button"
+ CHECKABLE_BUTTON = "checkable_button"
+ SPINBOX = "spinbox"
+ DOUBLE_SPINBOX = "double_spinbox"
diff --git a/modules/util/enum/TrainingMethod.py b/modules/util/enum/TrainingMethod.py
index 403a1ec10..88605ad0d 100644
--- a/modules/util/enum/TrainingMethod.py
+++ b/modules/util/enum/TrainingMethod.py
@@ -1,11 +1,26 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class TrainingMethod(Enum):
+class TrainingMethod(BaseEnum):
FINE_TUNE = 'FINE_TUNE'
LORA = 'LORA'
EMBEDDING = 'EMBEDDING'
FINE_TUNE_VAE = 'FINE_TUNE_VAE'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ TrainingMethod.FINE_TUNE: "Fine Tune",
+ TrainingMethod.LORA: "LoRA",
+ TrainingMethod.EMBEDDING: "Embedding",
+ TrainingMethod.FINE_TUNE_VAE: "Fine Tune VAE",
+ }[self]
+
+ @staticmethod
+ def is_enabled(value, context=None):
+ if context == "convert_window":
+ return value in [TrainingMethod.FINE_TUNE, TrainingMethod.LORA, TrainingMethod.EMBEDDING]
+ # TODO: restrict the list shown in the main window as well.
+ return True
diff --git a/modules/util/enum/VideoFormat.py b/modules/util/enum/VideoFormat.py
index 67d8e8242..123fba48f 100644
--- a/modules/util/enum/VideoFormat.py
+++ b/modules/util/enum/VideoFormat.py
@@ -1,13 +1,17 @@
-from enum import Enum
+from modules.util.enum.BaseEnum import BaseEnum
-class VideoFormat(Enum):
+class VideoFormat(BaseEnum):
PNG_IMAGE_SEQUENCE = 'PNG_IMAGE_SEQUENCE'
JPG_IMAGE_SEQUENCE = 'JPG_IMAGE_SEQUENCE'
MP4 = 'MP4'
- def __str__(self):
- return self.value
+ def pretty_print(self):
+ return {
+ VideoFormat.PNG_IMAGE_SEQUENCE: "PNG Image Sequence",
+ VideoFormat.JPG_IMAGE_SEQUENCE: "JPG Image Sequence",
+ VideoFormat.MP4: "MP4",
+ }[self]
def extension(self) -> str:
match self:
diff --git a/requirements-global.txt b/requirements-global.txt
index c8e53aaed..7ee47a04d 100644
--- a/requirements-global.txt
+++ b/requirements-global.txt
@@ -50,7 +50,12 @@ adv_optm==1.1.3 # advanced optimizers
scalene==1.5.51
# ui
-customtkinter==5.2.2
+#customtkinter==5.2.2
+pyside6==6.10.0 # 6.4.2
+pyqt6==6.10.0 # 6.4.2
+PyQt6-Qt6==6.10.0 # 6.4.2
+matplotlib
+show-in-file-manager
# cloud
runpod==1.7.10
@@ -60,3 +65,7 @@ fabric==3.2.2
psutil==7.0.0
requests==2.32.3
deepdiff==8.6.1 # outputs an easy-to-read diff for troubleshooting
+
+# dataset tools
+pyoxipng==9.1.0 # Multithreaded PNG optimisation
+pepedpid==0.1.1 # required for DPID resizing https://dl.acm.org/doi/10.1145/2980179.2980239 https://pypi.org/project/pepedpid/
diff --git a/resources/icons/buttons/License.txt b/resources/icons/buttons/License.txt
new file mode 100644
index 000000000..461204b1d
--- /dev/null
+++ b/resources/icons/buttons/License.txt
@@ -0,0 +1,9 @@
+ISC License
+
+This license applies to every icon under ./buttons except `auto-mask.png`, which we, the OT team, created ourselves.
+
+Copyright (c) for portions of Lucide are held by Cole Bemis 2013-2022 as part of Feather (MIT). All other copyright (c) for Lucide are held by Lucide Contributors 2022.
+
+Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
diff --git a/resources/icons/buttons/dark/arrow-left.svg b/resources/icons/buttons/dark/arrow-left.svg
new file mode 100644
index 000000000..c47de7143
--- /dev/null
+++ b/resources/icons/buttons/dark/arrow-left.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/arrow-right.svg b/resources/icons/buttons/dark/arrow-right.svg
new file mode 100644
index 000000000..7f27d096a
--- /dev/null
+++ b/resources/icons/buttons/dark/arrow-right.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/brush.svg b/resources/icons/buttons/dark/brush.svg
new file mode 100644
index 000000000..857daa82c
--- /dev/null
+++ b/resources/icons/buttons/dark/brush.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/file-check-corner.svg b/resources/icons/buttons/dark/file-check-corner.svg
new file mode 100644
index 000000000..56845f8cb
--- /dev/null
+++ b/resources/icons/buttons/dark/file-check-corner.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/file-minus-corner.svg b/resources/icons/buttons/dark/file-minus-corner.svg
new file mode 100644
index 000000000..8ba7ff469
--- /dev/null
+++ b/resources/icons/buttons/dark/file-minus-corner.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/file-scan.svg b/resources/icons/buttons/dark/file-scan.svg
new file mode 100644
index 000000000..1ef3e38a0
--- /dev/null
+++ b/resources/icons/buttons/dark/file-scan.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/file-x-corner.svg b/resources/icons/buttons/dark/file-x-corner.svg
new file mode 100644
index 000000000..0efc12dd8
--- /dev/null
+++ b/resources/icons/buttons/dark/file-x-corner.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/folder-open.svg b/resources/icons/buttons/dark/folder-open.svg
new file mode 100644
index 000000000..9be222df2
--- /dev/null
+++ b/resources/icons/buttons/dark/folder-open.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/house.svg b/resources/icons/buttons/dark/house.svg
new file mode 100644
index 000000000..80b986b62
--- /dev/null
+++ b/resources/icons/buttons/dark/house.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/info.svg b/resources/icons/buttons/dark/info.svg
new file mode 100644
index 000000000..2262515ae
--- /dev/null
+++ b/resources/icons/buttons/dark/info.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/mouse.svg b/resources/icons/buttons/dark/mouse.svg
new file mode 100644
index 000000000..c99319b7d
--- /dev/null
+++ b/resources/icons/buttons/dark/mouse.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/move.svg b/resources/icons/buttons/dark/move.svg
new file mode 100644
index 000000000..8ddf4c455
--- /dev/null
+++ b/resources/icons/buttons/dark/move.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/paint-bucket.svg b/resources/icons/buttons/dark/paint-bucket.svg
new file mode 100644
index 000000000..e0a4d9ae8
--- /dev/null
+++ b/resources/icons/buttons/dark/paint-bucket.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/redo.svg b/resources/icons/buttons/dark/redo.svg
new file mode 100644
index 000000000..9549a0845
--- /dev/null
+++ b/resources/icons/buttons/dark/redo.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/save.svg b/resources/icons/buttons/dark/save.svg
new file mode 100644
index 000000000..1a0767253
--- /dev/null
+++ b/resources/icons/buttons/dark/save.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/search.svg b/resources/icons/buttons/dark/search.svg
new file mode 100644
index 000000000..b35bf55dc
--- /dev/null
+++ b/resources/icons/buttons/dark/search.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/trash-2.svg b/resources/icons/buttons/dark/trash-2.svg
new file mode 100644
index 000000000..c93b5bc4e
--- /dev/null
+++ b/resources/icons/buttons/dark/trash-2.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/dark/undo.svg b/resources/icons/buttons/dark/undo.svg
new file mode 100644
index 000000000..335209871
--- /dev/null
+++ b/resources/icons/buttons/dark/undo.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/arrow-left.svg b/resources/icons/buttons/light/arrow-left.svg
new file mode 100644
index 000000000..5c70ef03f
--- /dev/null
+++ b/resources/icons/buttons/light/arrow-left.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/arrow-right.svg b/resources/icons/buttons/light/arrow-right.svg
new file mode 100644
index 000000000..1e82b5dfd
--- /dev/null
+++ b/resources/icons/buttons/light/arrow-right.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/brush.svg b/resources/icons/buttons/light/brush.svg
new file mode 100644
index 000000000..f1d104d0d
--- /dev/null
+++ b/resources/icons/buttons/light/brush.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/file-check-corner.svg b/resources/icons/buttons/light/file-check-corner.svg
new file mode 100644
index 000000000..7a1cc62a0
--- /dev/null
+++ b/resources/icons/buttons/light/file-check-corner.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/file-minus-corner.svg b/resources/icons/buttons/light/file-minus-corner.svg
new file mode 100644
index 000000000..ce563d574
--- /dev/null
+++ b/resources/icons/buttons/light/file-minus-corner.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/file-scan.svg b/resources/icons/buttons/light/file-scan.svg
new file mode 100644
index 000000000..8280593fb
--- /dev/null
+++ b/resources/icons/buttons/light/file-scan.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/file-x-corner.svg b/resources/icons/buttons/light/file-x-corner.svg
new file mode 100644
index 000000000..49c41c931
--- /dev/null
+++ b/resources/icons/buttons/light/file-x-corner.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/folder-open.svg b/resources/icons/buttons/light/folder-open.svg
new file mode 100644
index 000000000..ef1bfe647
--- /dev/null
+++ b/resources/icons/buttons/light/folder-open.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/house.svg b/resources/icons/buttons/light/house.svg
new file mode 100644
index 000000000..9d71b4973
--- /dev/null
+++ b/resources/icons/buttons/light/house.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/info.svg b/resources/icons/buttons/light/info.svg
new file mode 100644
index 000000000..96b6a525c
--- /dev/null
+++ b/resources/icons/buttons/light/info.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/mouse.svg b/resources/icons/buttons/light/mouse.svg
new file mode 100644
index 000000000..72ac3ae1f
--- /dev/null
+++ b/resources/icons/buttons/light/mouse.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/move.svg b/resources/icons/buttons/light/move.svg
new file mode 100644
index 000000000..ece904956
--- /dev/null
+++ b/resources/icons/buttons/light/move.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/paint-bucket.svg b/resources/icons/buttons/light/paint-bucket.svg
new file mode 100644
index 000000000..46058f728
--- /dev/null
+++ b/resources/icons/buttons/light/paint-bucket.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/redo.svg b/resources/icons/buttons/light/redo.svg
new file mode 100644
index 000000000..cf6d0e9e7
--- /dev/null
+++ b/resources/icons/buttons/light/redo.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/save.svg b/resources/icons/buttons/light/save.svg
new file mode 100644
index 000000000..ce1d45c86
--- /dev/null
+++ b/resources/icons/buttons/light/save.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/search.svg b/resources/icons/buttons/light/search.svg
new file mode 100644
index 000000000..148468042
--- /dev/null
+++ b/resources/icons/buttons/light/search.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/trash-2.svg b/resources/icons/buttons/light/trash-2.svg
new file mode 100644
index 000000000..1a24da1a2
--- /dev/null
+++ b/resources/icons/buttons/light/trash-2.svg
@@ -0,0 +1 @@
+
diff --git a/resources/icons/buttons/light/undo.svg b/resources/icons/buttons/light/undo.svg
new file mode 100644
index 000000000..71f7ec3aa
--- /dev/null
+++ b/resources/icons/buttons/light/undo.svg
@@ -0,0 +1 @@
+
diff --git a/scripts/caption_ui.py b/scripts/caption_ui.py
index aa380cd9a..7fe2c9e53 100644
--- a/scripts/caption_ui.py
+++ b/scripts/caption_ui.py
@@ -2,16 +2,30 @@
script_imports()
-from modules.ui.CaptionUI import CaptionUI
-from modules.util.args.CaptionUIArgs import CaptionUIArgs
+import os
+import sys
+
+from modules.ui.controllers.windows.DatasetController import DatasetController
+from modules.ui.utils.OneTrainerApplication import OnetrainerApplication
+
+from PySide6.QtUiTools import QUiLoader
def main():
- args = CaptionUIArgs.parse_args()
+ os.environ["QT_QPA_PLATFORM"] = "xcb" # Suppress Wayland warnings on NVidia drivers.
+ # TODO: scalene (modules.ui.models.StateModel) changes locale on import, change QT6 locale to suppress warning here?
+
+ app = OnetrainerApplication(sys.argv)
+ loader = QUiLoader()
+
+ onetrainer = DatasetController(loader)
+
+ # Invalidate ui elements after the controllers are set up, but before showing them.
+ app.stateChanged.emit()
+ onetrainer.ui.show()
- ui = CaptionUI(None, args.dir, args.include_subdirectories)
- ui.mainloop()
+ sys.exit(app.exec())
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
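
The same Qt bootstrap (platform override, OnetrainerApplication, QUiLoader, controller construction, stateChanged emit, show, exec) repeats verbatim across all six entry scripts in this diff. If the duplication becomes a burden, it could be factored into a shared launcher; a hypothetical sketch, not part of this diff:

import os
import sys

from modules.ui.utils.OneTrainerApplication import OnetrainerApplication

from PySide6.QtUiTools import QUiLoader


def run_controller(controller_cls):
    # Hypothetical shared launcher for the entry scripts in this diff.
    os.environ["QT_QPA_PLATFORM"] = "xcb"  # Suppress Wayland warnings on NVidia drivers.

    app = OnetrainerApplication(sys.argv)
    controller = controller_cls(QUiLoader())

    # Invalidate ui elements after the controller is set up, but before showing.
    app.stateChanged.emit()
    controller.ui.show()

    sys.exit(app.exec())
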
diff --git a/scripts/convert_model_ui.py b/scripts/convert_model_ui.py
index 3d024042d..d109c7fcb 100644
--- a/scripts/convert_model_ui.py
+++ b/scripts/convert_model_ui.py
@@ -2,13 +2,30 @@
script_imports()
-from modules.ui.ConvertModelUI import ConvertModelUI
+import os
+import sys
+
+from modules.ui.controllers.windows.ConvertController import ConvertController
+from modules.ui.utils.OneTrainerApplication import OnetrainerApplication
+
+from PySide6.QtUiTools import QUiLoader
def main():
- ui = ConvertModelUI(None)
- ui.mainloop()
+ os.environ["QT_QPA_PLATFORM"] = "xcb" # Suppress Wayland warnings on NVidia drivers.
+ # TODO: scalene (modules.ui.models.StateModel) changes locale on import, change QT6 locale to suppress warning here?
+
+ app = OnetrainerApplication(sys.argv)
+ loader = QUiLoader()
+
+ onetrainer = ConvertController(loader)
+
+ # Invalidate ui elements after the controllers are set up, but before showing them.
+ app.stateChanged.emit()
+ onetrainer.ui.show()
+
+ sys.exit(app.exec())
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/scripts/generate_captions.py b/scripts/generate_captions.py
index 7653430e1..27d4c816c 100644
--- a/scripts/generate_captions.py
+++ b/scripts/generate_captions.py
@@ -2,35 +2,29 @@
script_imports()
-from modules.module.Blip2Model import Blip2Model
-from modules.module.BlipModel import BlipModel
-from modules.module.WDModel import WDModel
-from modules.util.args.GenerateCaptionsArgs import GenerateCaptionsArgs
-from modules.util.enum.GenerateCaptionsModel import GenerateCaptionsModel
+import os
+import sys
-import torch
+from modules.ui.controllers.windows.CaptionController import CaptionController
+from modules.ui.utils.OneTrainerApplication import OnetrainerApplication
+
+from PySide6.QtUiTools import QUiLoader
def main():
- args = GenerateCaptionsArgs.parse_args()
-
- model = None
- if args.model == GenerateCaptionsModel.BLIP:
- model = BlipModel(torch.device(args.device), args.dtype.torch_dtype())
- elif args.model == GenerateCaptionsModel.BLIP2:
- model = Blip2Model(torch.device(args.device), args.dtype.torch_dtype())
- elif args.model == GenerateCaptionsModel.WD14_VIT_2:
- model = WDModel(torch.device(args.device), args.dtype.torch_dtype())
-
- model.caption_folder(
- sample_dir=args.sample_dir,
- initial_caption=args.initial_caption,
- caption_prefix=args.caption_prefix,
- caption_postfix=args.caption_postfix,
- mode=args.mode,
- error_callback=lambda filename: print("Error while processing image " + filename),
- include_subdirectories=args.include_subdirectories
- )
+ os.environ["QT_QPA_PLATFORM"] = "xcb" # Suppress Wayland warnings on NVidia drivers.
+ # TODO: scalene (modules.ui.models.StateModel) changes locale on import, change QT6 locale to suppress warning here?
+
+ app = OnetrainerApplication(sys.argv)
+ loader = QUiLoader()
+
+ onetrainer = CaptionController(loader)
+
+ # Invalidate ui elements after the controllers are set up, but before showing them.
+ app.stateChanged.emit()
+ onetrainer.ui.show()
+
+ sys.exit(app.exec())
if __name__ == "__main__":
diff --git a/scripts/generate_masks.py b/scripts/generate_masks.py
index 1c001c0d8..66a72ffa6 100644
--- a/scripts/generate_masks.py
+++ b/scripts/generate_masks.py
@@ -2,40 +2,29 @@
script_imports()
-from modules.module.ClipSegModel import ClipSegModel
-from modules.module.MaskByColor import MaskByColor
-from modules.module.RembgHumanModel import RembgHumanModel
-from modules.module.RembgModel import RembgModel
-from modules.util.args.GenerateMasksArgs import GenerateMasksArgs
-from modules.util.enum.GenerateMasksModel import GenerateMasksModel
+import os
+import sys
-import torch
+from modules.ui.controllers.windows.MaskController import MaskController
+from modules.ui.utils.OneTrainerApplication import OnetrainerApplication
+
+from PySide6.QtUiTools import QUiLoader
def main():
- args = GenerateMasksArgs.parse_args()
-
- model = None
- if args.model == GenerateMasksModel.CLIPSEG:
- model = ClipSegModel(torch.device(args.device), args.dtype.torch_dtype())
- elif args.model == GenerateMasksModel.REMBG:
- model = RembgModel(torch.device(args.device), args.dtype.torch_dtype())
- elif args.model == GenerateMasksModel.REMBG_HUMAN:
- model = RembgHumanModel(torch.device(args.device), args.dtype.torch_dtype())
- elif args.model == GenerateMasksModel.COLOR:
- model = MaskByColor(torch.device(args.device), args.dtype.torch_dtype())
-
- model.mask_folder(
- sample_dir=args.sample_dir,
- prompts=args.prompts,
- mode=args.mode,
- threshold=args.threshold,
- smooth_pixels=args.smooth_pixels,
- expand_pixels=args.expand_pixels,
- alpha=args.alpha,
- error_callback=lambda filename: print("Error while processing image " + filename),
- include_subdirectories=args.include_subdirectories
- )
+ os.environ["QT_QPA_PLATFORM"] = "xcb" # Suppress Wayland warnings on NVidia drivers.
+ # TODO: scalene (modules.ui.models.StateModel) changes locale on import, change QT6 locale to suppress warning here?
+
+ app = OnetrainerApplication(sys.argv)
+ loader = QUiLoader()
+
+ onetrainer = MaskController(loader)
+
+ # Invalidate ui elements after the controllers are set up, but before showing them.
+ app.stateChanged.emit()
+ onetrainer.ui.show()
+
+ sys.exit(app.exec())
if __name__ == "__main__":
diff --git a/scripts/train_ui.py b/scripts/train_ui.py
index 46ee8f1e6..30f1b4183 100644
--- a/scripts/train_ui.py
+++ b/scripts/train_ui.py
@@ -2,13 +2,30 @@
script_imports()
-from modules.ui.TrainUI import TrainUI
+
+import os
+import sys
+
+from modules.ui.controllers.OneTrainerController import OnetrainerController
+from modules.ui.utils.OneTrainerApplication import OnetrainerApplication
+
+from PySide6.QtUiTools import QUiLoader
def main():
- ui = TrainUI()
- ui.mainloop()
+ os.environ["QT_QPA_PLATFORM"] = "xcb" # Suppress Wayland warnings on NVidia drivers.
+ # TODO: scalene (modules.ui.models.StateModel) changes locale on import, change QT6 locale to suppress warning here?
+
+ app = OnetrainerApplication(sys.argv)
+ loader = QUiLoader()
+
+ onetrainer = OnetrainerController(loader)
+
+ # Invalidate ui elements after the controllers are set up, but before showing them.
+ app.stateChanged.emit()
+ onetrainer.ui.show()
+ sys.exit(app.exec())
if __name__ == '__main__':
main()
diff --git a/scripts/video_tool_ui.py b/scripts/video_tool_ui.py
index 99707506f..3d89bca82 100644
--- a/scripts/video_tool_ui.py
+++ b/scripts/video_tool_ui.py
@@ -2,13 +2,29 @@
script_imports()
-from modules.ui.VideoToolUI import VideoToolUI
+import os
+import sys
+
+from modules.ui.controllers.windows.VideoController import VideoController
+from modules.ui.utils.OneTrainerApplication import OnetrainerApplication
+
+from PySide6.QtUiTools import QUiLoader
def main():
- ui = VideoToolUI(None)
- ui.mainloop()
+ os.environ["QT_QPA_PLATFORM"] = "xcb" # Suppress Wayland warnings on NVidia drivers.
+ # TODO: scalene (modules.ui.models.StateModel) changes locale on import, change QT6 locale to suppress warning here?
+
+ app = OnetrainerApplication(sys.argv)
+ loader = QUiLoader()
+
+ onetrainer = VideoController(loader)
+
+ # Invalidate ui elements after the controllers are set up, but before showing them.
+ app.stateChanged.emit()
+ onetrainer.ui.show()
+ sys.exit(app.exec())
if __name__ == '__main__':
main()