fix: refactored kernel_init calls so that scale can be overridden from outside
AshishKumar4 committed Sep 9, 2024
1 parent c076c31 commit 3c22222
Showing 3 changed files with 29 additions and 42 deletions.
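In short: module fields that used to hold a ready-made initializer (kernel_init(1.0)) now hold a factory (partial(kernel_init, scale=1.0)) that is invoked inside __call__, so the scale can be overridden by whoever constructs the module. A minimal sketch of the pattern follows; kernel_init here is a stand-in (assumed to wrap a Flax variance_scaling initializer), not the exact flaxdiff implementation.

# Illustrative sketch only: kernel_init below is a stand-in factory (assumed to
# wrap a Flax variance_scaling initializer); the real one ships with flaxdiff.
from functools import partial
from typing import Any, Callable

import jax.numpy as jnp
import flax.linen as nn

def kernel_init(scale: float = 1.0, dtype: Any = jnp.float32) -> Callable:
    # Hypothetical factory: returns an initializer whose variance is set by `scale`.
    return nn.initializers.variance_scaling(
        scale=max(scale, 1e-10), mode="fan_avg",
        distribution="truncated_normal", dtype=dtype)

class Layer(nn.Module):
    features: int = 8
    # Before this commit: kernel_init: Callable = kernel_init(1.0)
    # The initializer was built at class-definition time, so its scale was frozen.
    # Now the field stores the factory and the module calls it where needed:
    kernel_init: Callable = partial(kernel_init, scale=1.0)

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features, kernel_init=self.kernel_init())(x)

# The scale can now be overridden from outside:
layer = Layer(kernel_init=partial(kernel_init, scale=0.02))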
45 changes: 16 additions & 29 deletions evaluate.ipynb
@@ -2672,7 +2672,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
@@ -2695,7 +2695,7 @@
" embedding_dim: int\n",
" dtype: Any = jnp.float32\n",
" precision: Any = jax.lax.Precision.HIGH\n",
" kernel_init: Callable = kernel_init(1.0)\n",
" kernel_init: Callable = partial(kernel_init, 1.0)\n",
"\n",
" @nn.compact\n",
" def __call__(self, x):\n",
@@ -2706,7 +2706,7 @@
" kernel_size=(self.patch_size, self.patch_size), \n",
" strides=(self.patch_size, self.patch_size),\n",
" dtype=self.dtype,\n",
" kernel_init=self.kernel_init,\n",
" kernel_init=self.kernel_init(),\n",
" precision=self.precision)(x)\n",
" x = jnp.reshape(x, (batch, -1, self.embedding_dim))\n",
" return x\n",
@@ -2739,7 +2739,7 @@
" norm_groups:int=8\n",
" dtype: Optional[Dtype] = None\n",
" precision: PrecisionLike = None\n",
" kernel_init: Callable = partial(kernel_init)\n",
" kernel_init: Callable = partial(kernel_init, scale=1.0)\n",
" add_residualblock_output: bool = False\n",
"\n",
" def setup(self):\n",
@@ -2758,10 +2758,10 @@
"\n",
" # Patch embedding\n",
" x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features, \n",
" dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)\n",
" dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)\n",
" num_patches = x.shape[1]\n",
" \n",
" context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0), \n",
" context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
" dtype=self.dtype, precision=self.precision)(textcontext)\n",
" num_text_tokens = textcontext.shape[1]\n",
" \n",
@@ -2784,32 +2784,32 @@
" dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
" use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
" only_pure_attention=False,\n",
" kernel_init=self.kernel_init(1.0))(x)\n",
" kernel_init=self.kernel_init())(x)\n",
" skips.append(x)\n",
" \n",
" # Middle block\n",
" x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
" dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
" use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
" only_pure_attention=False,\n",
" kernel_init=self.kernel_init(1.0))(x)\n",
" kernel_init=self.kernel_init())(x)\n",
" \n",
" # # Out blocks\n",
" for i in range(self.num_layers // 2):\n",
" x = jnp.concatenate([x, skips.pop()], axis=-1)\n",
" x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0), \n",
" x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), \n",
" dtype=self.dtype, precision=self.precision)(x)\n",
" x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, \n",
" dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, \n",
" use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax, \n",
" only_pure_attention=False,\n",
" kernel_init=self.kernel_init(1.0))(x)\n",
" kernel_init=self.kernel_init())(x)\n",
" \n",
" # print(f'Shape of x after transformer blocks: {x.shape}')\n",
" x = self.norm()(x)\n",
" \n",
" patch_dim = self.patch_size ** 2 * self.output_channels\n",
" x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)\n",
" x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)\n",
" x = x[:, 1 + num_text_tokens:, :]\n",
" x = unpatchify(x, channels=self.output_channels)\n",
" \n",
@@ -2823,7 +2823,7 @@
" kernel_size=(3, 3),\n",
" strides=(1, 1),\n",
" # activation=jax.nn.mish\n",
" kernel_init=self.kernel_init(0.0),\n",
" kernel_init=self.kernel_init(scale=0.0),\n",
" dtype=self.dtype,\n",
" precision=self.precision\n",
" )(x)\n",
@@ -2837,7 +2837,7 @@
" kernel_size=(3, 3),\n",
" strides=(1, 1),\n",
" # activation=jax.nn.mish\n",
" kernel_init=self.kernel_init(0.0),\n",
" kernel_init=self.kernel_init(scale=0.0),\n",
" dtype=self.dtype,\n",
" precision=self.precision\n",
" )(x)\n",
@@ -2846,23 +2846,9 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 42,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "kernel_init() missing 1 required positional argument: 'scale'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[37], line 12\u001b[0m\n\u001b[1;32m 3\u001b[0m textcontext \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mones((\u001b[38;5;241m8\u001b[39m, \u001b[38;5;241m77\u001b[39m, \u001b[38;5;241m768\u001b[39m), dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16) \n\u001b[1;32m 4\u001b[0m vit \u001b[38;5;241m=\u001b[39m UViT(patch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m, \n\u001b[1;32m 5\u001b[0m emb_features\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m768\u001b[39m, \n\u001b[1;32m 6\u001b[0m num_layers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m12\u001b[39m, \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 10\u001b[0m norm_groups\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m,\n\u001b[1;32m 11\u001b[0m dtype\u001b[38;5;241m=\u001b[39mjnp\u001b[38;5;241m.\u001b[39mbfloat16)\n\u001b[0;32m---> 12\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[43mvit\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrandom\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPRNGKey\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtextcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;129m@jax\u001b[39m\u001b[38;5;241m.\u001b[39mjit\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mapply\u001b[39m(params, x, temb, textcontext):\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m vit\u001b[38;5;241m.\u001b[39mapply(params, x, temb, textcontext)\n",
" \u001b[0;31m[... skipping hidden 9 frame]\u001b[0m\n",
"Cell \u001b[0;32mIn[36], line 122\u001b[0m, in \u001b[0;36mUViT.__call__\u001b[0;34m(self, x, temb, textcontext)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;241m2\u001b[39m):\n\u001b[1;32m 121\u001b[0m x \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39mconcatenate([x, skips\u001b[38;5;241m.\u001b[39mpop()], axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m--> 122\u001b[0m x \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mDenseGeneral(features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39memb_features, kernel_init\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mkernel_init\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, \n\u001b[1;32m 123\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype, precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision)(x)\n\u001b[1;32m 124\u001b[0m x \u001b[38;5;241m=\u001b[39m TransformerBlock(heads\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_heads, dim_head\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39memb_features \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_heads, \n\u001b[1;32m 125\u001b[0m dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdtype, precision\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision, use_projection\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_projection, \n\u001b[1;32m 126\u001b[0m use_flash_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_flash_attention, use_self_and_cross\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_self_and_cross, force_fp32_for_softmax\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforce_fp32_for_softmax, \n\u001b[1;32m 127\u001b[0m only_pure_attention\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 128\u001b[0m kernel_init\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkernel_init(\u001b[38;5;241m1.0\u001b[39m))(x)\n\u001b[1;32m 130\u001b[0m \u001b[38;5;66;03m# print(f'Shape of x after transformer blocks: {x.shape}')\u001b[39;00m\n",
"\u001b[0;31mTypeError\u001b[0m: kernel_init() missing 1 required positional argument: 'scale'"
]
}
],
"outputs": [],
"source": [
"x = jnp.ones((8, 128, 128, 3), dtype=jnp.bfloat16)\n",
"temb = jnp.ones((8,), dtype=jnp.bfloat16)\n",
@@ -2874,6 +2860,7 @@
" dropout_rate=0.1, \n",
" add_residualblock_output=True,\n",
" norm_groups=0,\n",
" kernel_init=partial(kernel_init, scale=1.0),\n",
" dtype=jnp.bfloat16)\n",
"params = vit.init(jax.random.PRNGKey(0), x, temb, textcontext)\n",
"\n",
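With the cell above fixed, the notebook supplies the factory explicitly and can pick any scale. A usage sketch reconstructed from the diff (assumptions noted in the comments):

# Usage sketch of the updated cell; assumes UViT and kernel_init are already in
# scope (defined/imported earlier in the notebook). num_heads=12 is assumed,
# since that argument is not visible in the truncated diff.
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.ones((8, 128, 128, 3), dtype=jnp.bfloat16)
temb = jnp.ones((8,), dtype=jnp.bfloat16)
textcontext = jnp.ones((8, 77, 768), dtype=jnp.bfloat16)

vit = UViT(patch_size=16,
           emb_features=768,
           num_layers=12,
           num_heads=12,
           dropout_rate=0.1,
           add_residualblock_output=True,
           norm_groups=0,
           kernel_init=partial(kernel_init, scale=1.0),  # scale chosen by the caller
           dtype=jnp.bfloat16)
params = vit.init(jax.random.PRNGKey(0), x, temb, textcontext)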
24 changes: 12 additions & 12 deletions flaxdiff/models/simple_vit.py
@@ -23,7 +23,7 @@ class PatchEmbedding(nn.Module):
embedding_dim: int
dtype: Any = jnp.float32
precision: Any = jax.lax.Precision.HIGH
kernel_init: Callable = kernel_init(1.0)
kernel_init: Callable = partial(kernel_init, 1.0)

@nn.compact
def __call__(self, x):
@@ -34,7 +34,7 @@ def __call__(self, x):
kernel_size=(self.patch_size, self.patch_size),
strides=(self.patch_size, self.patch_size),
dtype=self.dtype,
kernel_init=self.kernel_init,
kernel_init=self.kernel_init(),
precision=self.precision)(x)
x = jnp.reshape(x, (batch, -1, self.embedding_dim))
return x
@@ -67,7 +67,7 @@ class UViT(nn.Module):
norm_groups:int=8
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
kernel_init: Callable = partial(kernel_init)
kernel_init: Callable = partial(kernel_init, scale=1.0)
add_residualblock_output: bool = False

def setup(self):
Expand All @@ -86,10 +86,10 @@ def __call__(self, x, temb, textcontext=None):

# Patch embedding
x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
num_patches = x.shape[1]

context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
dtype=self.dtype, precision=self.precision)(textcontext)
num_text_tokens = textcontext.shape[1]

@@ -112,32 +112,32 @@ def __call__(self, x, temb, textcontext=None):
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
only_pure_attention=False,
kernel_init=self.kernel_init(1.0))(x)
kernel_init=self.kernel_init())(x)
skips.append(x)

# Middle block
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
only_pure_attention=False,
kernel_init=self.kernel_init(1.0))(x)
kernel_init=self.kernel_init())(x)

# # Out blocks
for i in range(self.num_layers // 2):
x = jnp.concatenate([x, skips.pop()], axis=-1)
x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
dtype=self.dtype, precision=self.precision)(x)
x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
only_pure_attention=False,
kernel_init=self.kernel_init(1.0))(x)
kernel_init=self.kernel_init())(x)

# print(f'Shape of x after transformer blocks: {x.shape}')
x = self.norm()(x)

patch_dim = self.patch_size ** 2 * self.output_channels
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
x = x[:, 1 + num_text_tokens:, :]
x = unpatchify(x, channels=self.output_channels)

@@ -151,7 +151,7 @@ def __call__(self, x, temb, textcontext=None):
kernel_size=(3, 3),
strides=(1, 1),
# activation=jax.nn.mish
kernel_init=self.kernel_init(0.0),
kernel_init=self.kernel_init(scale=0.0),
dtype=self.dtype,
precision=self.precision
)(x)
@@ -165,7 +165,7 @@ def __call__(self, x, temb, textcontext=None):
kernel_size=(3, 3),
strides=(1, 1),
# activation=jax.nn.mish
kernel_init=self.kernel_init(0.0),
kernel_init=self.kernel_init(scale=0.0),
dtype=self.dtype,
precision=self.precision
)(x)
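One consequence of the new default, partial(kernel_init, scale=1.0), is that call-site overrides must be passed by keyword, which is why the output convolutions use self.kernel_init(scale=0.0) rather than self.kernel_init(0.0). A tiny sketch of the semantics (stand-in factory, not the flaxdiff one):

# Keyword-override semantics of functools.partial, with a stand-in factory.
from functools import partial

def kernel_init(scale=1.0):
    return f"initializer(scale={scale})"

factory = partial(kernel_init, scale=1.0)

factory()           # 'initializer(scale=1.0)': what self.kernel_init() yields by default
factory(scale=0.0)  # 'initializer(scale=0.0)': the zero-scale init used for the final convs
# factory(0.0)      # TypeError: got multiple values for argument 'scale'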
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
version='0.1.30',
version='0.1.31',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
