Commit 8e7aaf5
feat: latest changes
AshishKumar4 committed Aug 7, 2024
1 parent 6e58e0e commit 8e7aaf5
Showing 4 changed files with 367 additions and 126 deletions.
334 changes: 286 additions & 48 deletions Diffusion flax linen.ipynb

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions flaxdiff/models/autoencoder/autoencoder.py
@@ -6,13 +6,14 @@
from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle


-class AutoEncoder(nn.Module):
+class AutoEncoder():
    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        raise NotImplementedError

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        raise NotImplementedError

-    @nn.compact
-    def __call__(self, *args, **kwargs) -> Any:
-        raise NotImplementedError
+    def __call__(self, x: jnp.ndarray):
+        latents = self.encode(x)
+        reconstructions = self.decode(latents)
+        return reconstructions
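
With nn.Module dropped from its bases, AutoEncoder now acts as a plain Python interface whose __call__ simply chains encode and decode. A minimal illustrative subclass (hypothetical, not part of this commit) could look like:

import jax.numpy as jnp

# Hypothetical example only: a trivial autoencoder whose encode/decode are identity maps.
class IdentityAutoEncoder(AutoEncoder):
    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return x  # a real model would produce a latent code here

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        return z  # a real model would reconstruct the input from the latent code here

# reconstructions = IdentityAutoEncoder()(jnp.ones((1, 8, 8, 3)))
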
76 changes: 76 additions & 0 deletions flaxdiff/models/common.py
@@ -251,3 +251,79 @@ def l2norm(t, axis=1, eps=1e-12):
denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
out = t/denom
return (out)


class ResidualBlock(nn.Module):
conv_type:str
features:int
kernel_size:tuple=(3, 3)
strides:tuple=(1, 1)
padding:str="SAME"
activation:Callable=jax.nn.swish
direction:str=None
res:int=2
norm_groups:int=8
kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None

@nn.compact
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
residual = x
# out = nn.GroupNorm(self.norm_groups)(x)
out = nn.RMSNorm()(x)
out = self.activation(out)

out = ConvLayer(
self.conv_type,
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
name="conv1",
dtype=self.dtype,
precision=self.precision
)(out)

temb = nn.DenseGeneral(
features=self.features,
name="temb_projection",
dtype=self.dtype,
precision=self.precision)(temb)
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
# scale, shift = jnp.split(temb, 2, axis=-1)
# out = out * (1 + scale) + shift
out = out + temb

# out = nn.GroupNorm(self.norm_groups)(out)
out = nn.RMSNorm()(out)
out = self.activation(out)

out = ConvLayer(
self.conv_type,
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
name="conv2",
dtype=self.dtype,
precision=self.precision
)(out)

if residual.shape != out.shape:
residual = ConvLayer(
self.conv_type,
features=self.features,
kernel_size=(1, 1),
strides=1,
kernel_init=self.kernel_init,
name="residual_conv",
dtype=self.dtype,
precision=self.precision
)(residual)
out = out + residual

out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

return out
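
As a quick usage sketch of the new ResidualBlock (illustrative only: it assumes "conv" is a conv_type string accepted by ConvLayer, NHWC inputs, and a 256-wide time embedding, none of which are spelled out in this diff):

import jax
import jax.numpy as jnp
from flaxdiff.models.common import ResidualBlock  # lives in common.py after this commit

block = ResidualBlock(conv_type="conv", features=64)  # conv_type value is an assumption
x = jnp.ones((1, 32, 32, 32))                         # NHWC feature map
temb = jnp.ones((1, 256))                             # time embedding; projected to `features` inside the block
params = block.init(jax.random.PRNGKey(0), x, temb)
out = block.apply(params, x, temb)                    # shape (1, 32, 32, 64); the 1x1 residual conv absorbs the channel change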

74 changes: 0 additions & 74 deletions flaxdiff/models/simple_unet.py
@@ -7,80 +7,6 @@
from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
from .attention import TransformerBlock

class ResidualBlock(nn.Module):
conv_type:str
features:int
kernel_size:tuple=(3, 3)
strides:tuple=(1, 1)
padding:str="SAME"
activation:Callable=jax.nn.swish
direction:str=None
res:int=2
norm_groups:int=8
kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None

@nn.compact
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
residual = x
# out = nn.GroupNorm(self.norm_groups)(x)
out = nn.RMSNorm()(x)
out = self.activation(out)

out = ConvLayer(
self.conv_type,
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
name="conv1",
dtype=self.dtype,
precision=self.precision
)(out)

temb = nn.DenseGeneral(
features=self.features,
name="temb_projection",
dtype=self.dtype,
precision=self.precision)(temb)
temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
# scale, shift = jnp.split(temb, 2, axis=-1)
# out = out * (1 + scale) + shift
out = out + temb

# out = nn.GroupNorm(self.norm_groups)(out)
out = nn.RMSNorm()(out)
out = self.activation(out)

out = ConvLayer(
self.conv_type,
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
name="conv2",
dtype=self.dtype,
precision=self.precision
)(out)

if residual.shape != out.shape:
residual = ConvLayer(
self.conv_type,
features=self.features,
kernel_size=(1, 1),
strides=1,
kernel_init=self.kernel_init,
name="residual_conv",
dtype=self.dtype,
precision=self.precision
)(residual)
out = out + residual

out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

return out

class Unet(nn.Module):
output_channels:int=3
emb_features:int=64*4,
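
The deletion above mirrors the addition in flaxdiff/models/common.py: ResidualBlock moves wholesale into the shared common module. Code that previously imported it from simple_unet would presumably switch to something like the following (illustrative; the updated import lines are not shown in this diff):

# old location (before this commit): from flaxdiff.models.simple_unet import ResidualBlock
from flaxdiff.models.common import ResidualBlock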
