Skip to content

Commit

Permalink
feat: new release
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 9, 2024
1 parent 6d2dc14 commit 5984fab
Show file tree
Hide file tree
Showing 10 changed files with 1,089 additions and 364 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ good models
.env
tensorboard
wandb
gs:
gcs_mount
datacache
*.deb
Expand Down
1,043 changes: 831 additions & 212 deletions evaluate.ipynb

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions flaxdiff/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def __call__(self, x, context=None):
value = self.value(context)

hidden_states = nn.dot_product_attention(
query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
query, key, value, dtype=self.dtype, broadcast_dropout=False,
dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
deterministic=True
)
proj = self.proj_attn(hidden_states)
proj = proj.reshape(orig_x_shape)
Expand Down Expand Up @@ -187,7 +189,7 @@ def setup(self):

def __call__(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=3)
hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=-1)
return hidden_linear * nn.gelu(hidden_gelu)

class FlaxFeedForward(nn.Module):
Expand Down Expand Up @@ -291,14 +293,14 @@ class TransformerBlock(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_projection: bool = False
use_flash_attention:bool = True
use_self_and_cross:bool = False
use_flash_attention:bool = False
use_self_and_cross:bool = True
only_pure_attention:bool = False

@nn.compact
def __call__(self, x, context=None):
inner_dim = self.heads * self.dim_head
B, H, W, C = x.shape
C = x.shape[-1]
normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
if self.use_projection == True:
if self.use_linear_attention:
Expand Down
14 changes: 12 additions & 2 deletions flaxdiff/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from flax.typing import Dtype, PrecisionLike
from typing import Dict, Callable, Sequence, Any, Union
import einops
from functools import partial

# Kernel initializer to use
def kernel_init(scale, dtype=jnp.float32):
Expand Down Expand Up @@ -266,11 +267,20 @@ class ResidualBlock(nn.Module):
kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None

def setup(self):
if self.norm_groups > 0:
norm = partial(nn.GroupNorm, self.norm_groups)
else:
norm = partial(nn.RMSNorm, 1e-5)

self.norm1 = norm()
self.norm2 = norm()

@nn.compact
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
residual = x
out = nn.GroupNorm(self.norm_groups)(x)
out = self.norm1(x)
# out = nn.RMSNorm()(x)
out = self.activation(out)

Expand All @@ -295,7 +305,7 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
# out = out * (1 + scale) + shift
out = out + temb

out = nn.GroupNorm(self.norm_groups)(out)
out = self.norm2(out)
# out = nn.RMSNorm()(out)
out = self.activation(out)

Expand Down
24 changes: 17 additions & 7 deletions flaxdiff/models/simple_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import einops
from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection, ResidualBlock
from .attention import TransformerBlock
from functools import partial

class Unet(nn.Module):
output_channels:int=3
Expand All @@ -19,6 +20,15 @@ class Unet(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None

def setup(self):
if self.norm_groups > 0:
norm = partial(nn.GroupNorm, self.norm_groups)
else:
norm = partial(nn.RMSNorm, 1e-5)

# self.last_up_norm = norm()
self.conv_out_norm = norm()

@nn.compact
def __call__(self, x, temb, textcontext):
# print("embedding features", self.emb_features)
Expand Down Expand Up @@ -69,7 +79,7 @@ def __call__(self, x, temb, textcontext):
use_projection=attention_config.get("use_projection", False),
use_self_and_cross=attention_config.get("use_self_and_cross", True),
precision=attention_config.get("precision", self.precision),
only_pure_attention=True,
only_pure_attention=attention_config.get("only_pure_attention", True),
name=f"down_{i}_attention_{j}")(x, textcontext)
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
downs.append(x)
Expand Down Expand Up @@ -107,8 +117,8 @@ def __call__(self, x, temb, textcontext):
use_linear_attention=False,
use_projection=middle_attention.get("use_projection", False),
use_self_and_cross=False,
precision=attention_config.get("precision", self.precision),
only_pure_attention=True,
precision=middle_attention.get("precision", self.precision),
only_pure_attention=middle_attention.get("only_pure_attention", True),
name=f"middle_attention_{j}")(x, textcontext)
x = ResidualBlock(
middle_conv_type,
Expand Down Expand Up @@ -150,7 +160,7 @@ def __call__(self, x, temb, textcontext):
use_projection=attention_config.get("use_projection", False),
use_self_and_cross=attention_config.get("use_self_and_cross", True),
precision=attention_config.get("precision", self.precision),
only_pure_attention=True,
only_pure_attention=attention_config.get("only_pure_attention", True),
name=f"up_{i}_attention_{j}")(x, textcontext)
# print("Upscaling ", i, x.shape)
if i != len(feature_depths) - 1:
Expand All @@ -163,13 +173,13 @@ def __call__(self, x, temb, textcontext):
precision=self.precision
)(x)

# x = nn.GroupNorm(8)(x)
# x = self.last_up_norm(x)
x = ConvLayer(
conv_type,
features=self.feature_depths[0],
kernel_size=(3, 3),
strides=(1, 1),
kernel_init=kernel_init(0.0),
kernel_init=kernel_init(1.0),
dtype=self.dtype,
precision=self.precision
)(x)
Expand All @@ -189,7 +199,7 @@ def __call__(self, x, temb, textcontext):
precision=self.precision
)(x, temb)

x = nn.GroupNorm(self.norm_groups)(x)
x = self.conv_out_norm(x)
x = self.activation(x)

noise_out = ConvLayer(
Expand Down
29 changes: 13 additions & 16 deletions flaxdiff/models/simple_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp
from flax import linen as nn
from typing import Callable, Any
from .simply_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
from .attention import TransformerBlock

class PatchEmbedding(nn.Module):
Expand Down Expand Up @@ -40,34 +40,35 @@ def __call__(self, x):
class TransformerEncoder(nn.Module):
num_layers: int
num_heads: int
mlp_dim: int
dropout_rate: float = 0.1
dtype: Any = jnp.float32
precision: Any = jax.lax.Precision.HIGH
use_projection: bool = False

@nn.compact
def __call__(self, x, training=True):
def __call__(self, x, context=None):
for _ in range(self.num_layers):
x = TransformerBlock(
heads=self.num_heads,
dim_head=x.shape[-1] // self.num_heads,
mlp_dim=self.mlp_dim,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
precision=self.precision
)(x)
precision=self.precision,
use_self_and_cross=True,
use_projection=self.use_projection,
)(x, context)
return x

class VisionTransformer(nn.Module):
patch_size: int = 16
embedding_dim: int = 768
num_layers: int = 12
num_heads: int = 12
mlp_dim: int = 3072
emb_features: int = 256
dropout_rate: float = 0.1
dtype: Any = jnp.float32
precision: Any = jax.lax.Precision.HIGH
use_projection: bool = False

@nn.compact
def __call__(self, x, temb, textcontext=None):
Expand All @@ -81,27 +82,23 @@ def __call__(self, x, temb, textcontext=None):

# Add positional encoding
x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)

num_patches = x.shape[1]

# Add time embedding
temb = jnp.expand_dims(temb, axis=1)
x = jnp.concatenate([x, temb], axis=1)

# Add text context
if textcontext is not None:
x = jnp.concatenate([x, textcontext], axis=1)

# Transformer encoder
x = TransformerEncoder(
num_layers=self.num_layers,
num_heads=self.num_heads,
mlp_dim=self.mlp_dim,
dropout_rate=self.dropout_rate,
dtype=self.dtype,
precision=self.precision
)(x)
precision=self.precision,
use_projection=self.use_projection
)(x, textcontext)

# Extract the image tokens (exclude time and text embeddings)
num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
x = x[:, :num_patches, :]

# Reshape to image dimensions
Expand Down
51 changes: 41 additions & 10 deletions flaxdiff/trainer/diffusion_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def apply_ema(self, decay: float = 0.999):
)
return self.replace(ema_params=new_ema_params)

from flaxdiff.models.autoencoder.autoencoder import AutoEncoder

class DiffusionTrainer(SimpleTrainer):
noise_schedule: NoiseScheduler
model_output_transform: DiffusionPredictionTransform
Expand All @@ -40,7 +42,7 @@ def __init__(self,
optimizer: optax.GradientTransformation,
noise_schedule: NoiseScheduler,
rngs: jax.random.PRNGKey,
unconditional_prob: float = 0.2,
unconditional_prob: float = 0.12,
name: str = "Diffusion",
model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
autoencoder: AutoEncoder = None,
Expand All @@ -67,7 +69,8 @@ def generate_states(
existing_state: dict = None,
existing_best_state: dict = None,
model: nn.Module = None,
param_transforms: Callable = None
param_transforms: Callable = None,
use_dynamic_scale: bool = False
) -> Tuple[TrainState, TrainState]:
print("Generating states for DiffusionTrainer")
rngs, subkey = jax.random.split(rngs)
Expand All @@ -88,7 +91,8 @@ def generate_states(
ema_params=new_state['ema_params'],
tx=optimizer,
rngs=rngs,
metrics=Metrics.empty()
metrics=Metrics.empty(),
dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
)

if existing_best_state is not None:
Expand Down Expand Up @@ -125,14 +129,14 @@ def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, loc
local_rng_state = RandomMarkovState(subkey)

images = batch['image']
images = jnp.array(images, dtype=jnp.bfloat16)
images = jnp.array(images, dtype=jnp.float32)
# normalize image
images = (images - 127.5) / 127.5

if autoencoder is not None:
# Convert the images to latent space
# local_rng_state, rngs = local_rng_state.get_random_key()
images = autoencoder.encode(images)#, rngs)
local_rng_state, rngs = local_rng_state.get_random_key()
images = autoencoder.encode(images, rngs)

output = text_embedder(
input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
Expand Down Expand Up @@ -163,12 +167,39 @@ def model_loss(params):
loss = nloss
return loss

loss, grads = jax.value_and_grad(model_loss)(train_state.params)

if train_state.dynamic_scale is not None:
# dynamic scale takes care of averaging gradients across replicas
grad_fn = train_state.dynamic_scale.value_and_grad(
model_loss, axis_name="data"
)
dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
train_state = train_state.replace(dynamic_scale=dynamic_scale)
else:
grad_fn = jax.value_and_grad(model_loss)
loss, grads = grad_fn(train_state.params)
if distributed_training:
grads = jax.lax.pmean(grads, "data")

new_state = train_state.apply_gradients(grads=grads)

if train_state.dynamic_scale:
# if is_fin == False the gradients contain Inf/NaNs and optimizer state and
# params should be restored (= skip this step).
select_fn = functools.partial(jnp.where, is_fin)
new_state = train_state.replace(
opt_state=jax.tree_util.tree_map(
select_fn, new_state.opt_state, train_state.opt_state
),
params=jax.tree_util.tree_map(
select_fn, new_state.params, train_state.params
),
)

train_state = new_state.apply_ema(self.ema_decay)

if distributed_training:
grads = jax.lax.pmean(grads, "data")
loss = jax.lax.pmean(loss, "data")
train_state = train_state.apply_gradients(grads=grads)
train_state = train_state.apply_ema(self.ema_decay)
return train_state, loss, rng_state

if distributed_training:
Expand Down
Loading

0 comments on commit 5984fab

Please sign in to comment.