2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
- id: check-yaml

- repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.14.4
+  rev: v0.14.5
hooks:
# Run the Ruff linter, but not the formatter.
- id: ruff
2 changes: 1 addition & 1 deletion README.md
@@ -51,7 +51,7 @@ OneTrainer is a one-stop solution for all your Diffusion training needs.
2. Navigate into the cloned directory `cd OneTrainer`
3. Set up a virtual environment `python -m venv venv`
4. Activate the new venv:
-   - Windows: `venv/scripts/activate`
+   - Windows: `venv\scripts\activate`
- Linux and Mac: Depends on your shell, activate the venv accordingly
5. Install the requirements `pip install -r requirements.txt`

8 changes: 8 additions & 0 deletions modules/modelSetup/mixin/ModelSetupDiffusionLossMixin.py
@@ -291,12 +291,16 @@ def _diffusion_losses(
if 'timestep' in data:
v_pred = data.get('prediction_type', '') == 'v_prediction'
match config.loss_weight_fn:
case LossWeight.CONSTANT:
pass
case LossWeight.MIN_SNR_GAMMA:
losses *= self.__min_snr_weight(data['timestep'], config.loss_weight_strength, v_pred, losses.device)
case LossWeight.DEBIASED_ESTIMATION:
losses *= self.__debiased_estimation_weight(data['timestep'], v_pred, losses.device)
case LossWeight.P2:
losses *= self.__p2_loss_weight(data['timestep'], config.loss_weight_strength, v_pred, losses.device)
case _:
raise NotImplementedError(f"Loss weight function {config.loss_weight_fn} not implemented for diffusion models")

return losses

@@ -329,7 +333,11 @@ def _flow_matching_losses(
# Apply timestep based loss weighting.
if 'timestep' in data:
match config.loss_weight_fn:
case LossWeight.CONSTANT:
pass
case LossWeight.SIGMA:
losses *= self.__sigma_loss_weight(data['timestep'], losses.device)
case _:
raise NotImplementedError(f"Loss weight function {config.loss_weight_fn} not implemented for flow matching models")

return losses
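
For context, a minimal sketch of the min-SNR-γ weighting that `__min_snr_weight` presumably applies above (the standard Min-SNR-γ formulation; the function name, signature, and exact handling here are assumptions, not this repo's code):

import torch

def min_snr_weight(alphas_cumprod: torch.Tensor, timesteps: torch.Tensor, gamma: float, v_prediction: bool) -> torch.Tensor:
    # SNR(t) = ᾱ_t / (1 - ᾱ_t) for the sampled timesteps.
    alphas_cumprod_t = alphas_cumprod[timesteps]
    snr = alphas_cumprod_t / (1.0 - alphas_cumprod_t)
    clamped = snr.clamp(max=gamma)
    # Epsilon prediction uses min(SNR, γ) / SNR; v-prediction uses min(SNR, γ) / (SNR + 1).
    return clamped / (snr + 1.0) if v_prediction else clamped / snr
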
16 changes: 8 additions & 8 deletions modules/modelSetup/mixin/ModelSetupNoiseMixin.py
@@ -31,31 +31,31 @@ def _compute_and_cache_offset_noise_psi_schedule(self, betas: Tensor) -> Tensor:
alphas_cumprod = torch.cumprod(alphas, dim=0)

# From paper footnote 4: "we introduce α_0 = 1 for convenience".
-        alphas_with_zero = torch.cat([torch.tensor([1.0], device=betas.device, dtype=betas.dtype), alphas])
+        alphas_cumprod_prev = torch.cat([torch.tensor([1.0], device=betas.device, dtype=betas.dtype), alphas_cumprod[:-1]])

# --- Start of Algorithm 1 ---
gammas = torch.zeros(T, device=betas.device, dtype=betas.dtype)

# Step 1: Set gamma_1 = 1
gammas[0] = 1.0

-        # This sum is `Σ_{i=1 to t-1} γ_i/√α_{i-1}` which we build iteratively.
-        cumulative_sum_term = gammas[0] / torch.sqrt(alphas_with_zero[0])
+        # This sum is `Σ_{i=1 to t-1} γ_i/√ᾱ_{i-1}` which we build iteratively.
+        cumulative_sum_term = gammas[0] / torch.sqrt(alphas_cumprod_prev[0])

# Step 2-4: Loop for t = 2 to T (in code: t = 1 to T-1)
for t in range(1, T):
alpha_t = alphas[t]
-            alpha_tm1 = alphas[t - 1]
+            alpha_cumprod_tm1 = alphas_cumprod_prev[t]

# Denominator from the paper's formula for C_t.
-            c_t_denominator = alpha_t * (1 - alpha_tm1)
-            c_t = (1 - alpha_t) * torch.sqrt(alpha_tm1) / c_t_denominator
+            c_t_denominator = alpha_t * (1 - alpha_cumprod_tm1)
+            c_t = (1 - alpha_t) * torch.sqrt(alpha_cumprod_tm1) / c_t_denominator

# Paper's recursive formula uses the full cumulative sum.
gammas[t] = c_t * cumulative_sum_term

# Update the sum for the next iteration.
-            cumulative_sum_term += gammas[t] / torch.sqrt(alphas_with_zero[t])
+            cumulative_sum_term += gammas[t] / torch.sqrt(alphas_cumprod_prev[t])

# Step 5: Calculate normalization factor psi_T
psi_T_denominator = torch.sqrt(1 - alphas_cumprod[-1])
@@ -66,7 +66,7 @@ def _compute_and_cache_offset_noise_psi_schedule(self, betas: Tensor) -> Tensor:
# --- End of Algorithm 1 ---

# Finally, calculate the psi schedule for all timesteps t using Equation (22)
-        terms = gammas_normalized / torch.sqrt(alphas_with_zero[:-1])
+        terms = gammas_normalized / torch.sqrt(alphas_cumprod_prev)
s_cumulative = torch.cumsum(terms, dim=0)
psi_schedule = s_cumulative / torch.sqrt(1 - alphas_cumprod)
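        # i.e. ψ_t = (Σ_{i=1}^{t} γ̄_i / √ᾱ_{i-1}) / √(1 - ᾱ_t), with γ̄ the normalized gammas.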

173 changes: 172 additions & 1 deletion modules/module/LoRAModule.py
@@ -1,9 +1,11 @@
import copy
import math
from abc import abstractmethod
from collections import defaultdict
from collections.abc import Mapping
from typing import Any

from modules.module.oft_utils import OFTRotationModule
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import PeftType
from modules.util.ModuleFilter import ModuleFilter
@@ -332,6 +334,139 @@ def extract_from_module(self, base_module: nn.Module):
pass


class OFTModule(PeftBase):
oft_R: OFTRotationModule | None
rank: int
oft_block_size: int
coft: bool
coft_eps: float
block_share: bool
dropout_probability: float
adjustment_info: tuple[int, int] | None # for reporting

def __init__(self, prefix: str, orig_module: nn.Module | None, oft_block_size: int, coft: bool, coft_eps: float, block_share: bool, **kwargs):
super().__init__(prefix, orig_module)
self.oft_block_size = oft_block_size
self.rank = 0
self.coft = coft
self.coft_eps = coft_eps
self.block_share = block_share
self.dropout_probability = kwargs.pop('dropout_probability', 0.0)
self.oft_R = None
self.adjustment_info = None


if orig_module is not None:
self.initialize_weights()

def adjust_oft_parameters(self, in_features, params):
"""
Adjust the OFT parameters to be divisible by the in_features dimension.
"""
if params < in_features:
higher_params = params
while higher_params <= in_features and in_features % higher_params != 0:
higher_params += 1
else:
return in_features

lower_params = params
while lower_params > 1 and in_features % lower_params != 0:
lower_params -= 1

if (params - lower_params) <= (higher_params - params):
return lower_params
else:
return higher_params

def initialize_weights(self):
self._initialized = True

if isinstance(self.orig_module, nn.Linear):
in_features = self.orig_module.in_features
elif isinstance(self.orig_module, nn.Conv2d):
if self.orig_module.dilation[0] > 1 or self.orig_module.dilation[1] > 1:
raise ValueError("Conv2d with dilation > 1 is not supported by OFT.")
in_features = self.orig_module.in_channels * self.orig_module.kernel_size[0] * self.orig_module.kernel_size[1]
else:
raise NotImplementedError("Unsupported layer type for OFT")

oft_block_size = self.oft_block_size
if oft_block_size <= 0:
raise ValueError("Rank must be a positive.")

# Adjust oft_block_size to be a divisor of in_features
if in_features % oft_block_size != 0 or oft_block_size > in_features:
old_oft_block_size = oft_block_size
oft_block_size = self.adjust_oft_parameters(in_features, oft_block_size)
self.adjustment_info = (old_oft_block_size, oft_block_size)

# Calculate the number of blocks 'r'
r = in_features // oft_block_size

# Store the final, potentially adjusted values
self.rank = r
self.oft_block_size = oft_block_size

n_elements = self.oft_block_size * (self.oft_block_size - 1) // 2
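        # n_elements is the number of free parameters of a skew-symmetric block_size x block_size matrix (Cayley parametrization).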

self.oft_R = OFTRotationModule(
r=self.rank if not self.block_share else 1,
n_elements=n_elements,
block_size=self.oft_block_size,
in_features=in_features,
coft=self.coft,
coft_eps=self.coft_eps,
block_share=self.block_share,
use_cayley_neumann=True,
num_cayley_neumann_terms=5,
dropout_probability=self.dropout_probability,
)

nn.init.zeros_(self.oft_R.weight)

def forward(self, x, *args, **kwargs):
self.check_initialized()

# For Linear layers, rotating the input is mathematically equivalent to rotating the weights.
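        # (W(Rx) = (WR)x, so rotating the input yields the same result as using the rotated weight W·R.)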
if isinstance(self.orig_module, nn.Linear):
rotated_x = self.oft_R(x)
return self.orig_forward(rotated_x, *args, **kwargs)

# For Conv2d, we must rotate the weights, not the input, to preserve spatial information.
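        # (The rotation acts on the flattened in_channels * kh * kw dimension, which exists on the reshaped weight but not on the raw input tensor.)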
orth_rotate = self.oft_R._cayley_batch(
self.oft_R.weight, self.oft_R.block_size, self.oft_R.use_cayley_neumann, self.oft_R.num_cayley_neumann_terms
)
orth_rotate = self.oft_R.dropout(orth_rotate)

if self.block_share:
orth_rotate = orth_rotate.repeat(self.rank, 1, 1)

weight = self.orig_module.weight
weight_reshaped = weight.reshape(weight.shape[0], self.rank, self.oft_block_size)
rotated_weight_reshaped = torch.einsum("ork,rkc->orc", weight_reshaped, orth_rotate)

rotated_weight = rotated_weight_reshaped.reshape(weight.shape)

return self.op(x, rotated_weight, self.orig_module.bias, **self.layer_kwargs)

def apply_to_module(self):
# TODO
pass

def extract_from_module(self, base_module: nn.Module):
# TODO
pass

def check_initialized(self):
super().check_initialized()
assert self.oft_R is not None

@property
def dropout(self):
return self.oft_R.dropout


class DoRAModule(LoRAModule):
"""Weight-decomposed low rank adaptation.

@@ -432,6 +567,7 @@ def forward(self, x, *args, **kwargs):
DummyLoRAModule = LoRAModule.make_dummy()
DummyDoRAModule = DoRAModule.make_dummy()
DummyLoHaModule = LoHaModule.make_dummy()
DummyOFTModule = OFTModule.make_dummy()


class LoRAModuleWrapper:
@@ -481,6 +617,18 @@ def __init__(
self.dummy_klass = DummyLoHaModule
self.additional_args = [self.rank, self.alpha]
self.additional_kwargs = {}
elif self.peft_type == PeftType.OFT_2:
self.klass = OFTModule
self.dummy_klass = DummyOFTModule
self.additional_args = [
config.oft_block_size,
config.oft_coft,
config.coft_eps,
config.oft_block_share,
]
self.additional_kwargs = {
'dropout_probability': config.dropout_probability,
}

self.lora_modules = self.__create_modules(orig_module, config)

@@ -492,18 +640,37 @@ def __create_modules(self, orig_module: nn.Module | None, config: TrainConfig) -
selected = []
deselected = []
unsuitable = []
oft_adjustments = []

for name, child_module in orig_module.named_modules():
name = name.replace(".checkpoint.", ".")
if not isinstance(child_module, Linear | Conv2d):
unsuitable.append(name)
continue
if len(self.module_filters) == 0 or any(f.matches(name) for f in self.module_filters):
lora_modules[name] = self.klass(self.prefix + "." + name, child_module, *self.additional_args, **self.additional_kwargs)
lora_module = self.klass(self.prefix + "." + name, child_module, *self.additional_args, **self.additional_kwargs)
lora_modules[name] = lora_module
if self.peft_type == PeftType.OFT_2 and lora_module.adjustment_info:
old, new = lora_module.adjustment_info
oft_adjustments.append({'old': old, 'new': new})
selected.append(name)
else:
deselected.append(name)

if oft_adjustments:
summary = defaultdict(int)
for adj in oft_adjustments:
summary[(adj['old'], adj['new'])] += 1

sorted_summary = sorted(summary.items(), key=lambda item: (item[0][0], item[0][1]))

summary_lines = [
f" - {count} layer{'s' if count > 1 else ''} from {old} to {new}"
for (old, new), count in sorted_summary
]
print(f"OFT Block Size automatically adjusted for {len(oft_adjustments)} layers. Changes:")
print("\n".join(summary_lines))

if len(self.module_filters) > 0:
if config.debug_mode:
print(f"Selected layers: {selected}")
Expand Down Expand Up @@ -539,6 +706,10 @@ def _check_rank_matches(self, state_dict: dict[str, Tensor]):
if not state_dict:
return

# For OFT, the comparison is not straightforward, so we skip it.
if self.peft_type == PeftType.OFT_2:
return

if rank_key := next((k for k in state_dict if k.endswith((".lora_down.weight", ".hada_w1_a"))), None):
if (checkpoint_rank := state_dict[rank_key].shape[0]) != self.rank:
raise ValueError(f"Rank mismatch: checkpoint={checkpoint_rank}, config={self.rank}, please correct in the UI.")