Skip to content

Commit

Permalink
Add the encoder down-blocks.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Aug 3, 2024
1 parent b0b5855 commit 0269fe6
Showing 1 changed file with 76 additions and 9 deletions.
85 changes: 76 additions & 9 deletions candle-transformers/src/models/flux/autoencoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,36 @@ impl candle::Module for ResnetBlock {
}
}

/// Spatial downsampling layer of the autoencoder's encoder.
///
/// Holds the single convolution that is applied (after zero-padding) in the
/// `candle::Module` implementation for this type.
#[derive(Debug, Clone)]
struct Downsample {
/// Convolution built with `stride: 2` and equal input/output channel counts.
conv: Conv2d,
}

impl Downsample {
    /// Creates the downsampling layer for `in_c` channels.
    ///
    /// The convolution keeps the channel count unchanged (`in_c` in, `in_c`
    /// out), is built with a size argument of 3 and a stride of 2; every other
    /// `Conv2dConfig` field is left at its default. The weights are looked up
    /// under the `"conv"` prefix of `vb`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `conv2d`.
    fn new(in_c: usize, vb: VarBuilder) -> Result<Self> {
        let strided_cfg = candle_nn::Conv2dConfig {
            stride: 2,
            ..Default::default()
        };
        conv2d(in_c, in_c, 3, strided_cfg, vb.pp("conv")).map(|conv| Self { conv })
    }
}

impl candle::Module for Downsample {
    /// Zero-pads the input and runs it through the strided convolution.
    ///
    /// One zero element is appended at the end (none at the start) of each of
    /// the last two dimensions, in that order, before the convolution is
    /// applied.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.pad_with_zeros(D::Minus1, 0, 1)?
            .pad_with_zeros(D::Minus2, 0, 1)?
            .apply(&self.conv)
    }
}

/// One resolution level of the encoder: a sequence of resnet blocks followed
/// by an optional downsampling step.
#[derive(Debug, Clone)]
struct DownBlock {
/// Resnet blocks applied in order by the encoder's forward pass.
block: Vec<ResnetBlock>,
/// Downsampling layer applied after the resnet blocks when present; the
/// encoder constructor leaves it as `None` for the last level.
downsample: Option<Downsample>,
}

#[derive(Debug, Clone)]
pub struct Encoder {
conv_in: Conv2d,
Expand All @@ -129,6 +159,7 @@ pub struct Encoder {
mid_block_2: ResnetBlock,
norm_out: GroupNorm,
conv_out: Conv2d,
down: Vec<DownBlock>,
}

impl Encoder {
Expand All @@ -137,10 +168,38 @@ impl Encoder {
padding: 1,
..Default::default()
};
let block_in = cfg.ch;
let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid_block_1"))?;
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid_attn_1"))?;
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid_block_2"))?;
let mut block_in = cfg.ch;

let mut down = Vec::with_capacity(cfg.ch_mult.len());
let vb_d = vb.pp("down");
for (i_level, ch_mult) in cfg.ch_mult.iter().enumerate() {
let mut block = Vec::with_capacity(cfg.num_res_blocks);
let vb_d = vb_d.pp(i_level);
let vb_b = vb_d.pp("block");
let in_ch_mult = if i_level == 0 {
1
} else {
cfg.ch_mult[i_level - 1]
};
block_in = cfg.ch * in_ch_mult;
let block_out = cfg.ch * ch_mult;
for i_block in 0..cfg.num_res_blocks {
let b = ResnetBlock::new(block_in, block_out, vb_b.pp(i_block))?;
block.push(b);
block_in = block_out;
}
let downsample = if i_level != cfg.ch_mult.len() - 1 {
Some(Downsample::new(block_in, vb_d.pp("downsample"))?)
} else {
None
};
let block = DownBlock { block, downsample };
down.push(block)
}

let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
let conv_in = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
let conv_out = conv2d(cfg.in_channels, block_in, 3, conv_cfg, vb.pp("conv_out"))?;
let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
Expand All @@ -151,14 +210,22 @@ impl Encoder {
mid_block_2,
norm_out,
conv_out,
down,
})
}
}

impl candle_nn::Module for Encoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let h = xs.apply(&self.conv_in)?;
// TODO: blocks for downsampling
let mut h = xs.apply(&self.conv_in)?;
for block in self.down.iter() {
for b in block.block.iter() {
h = h.apply(b)?
}
if let Some(ds) = block.downsample.as_ref() {
h = h.apply(ds)?
}
}
h.apply(&self.mid_block_1)?
.apply(&self.mid_attn_1)?
.apply(&self.mid_block_2)?
Expand All @@ -185,9 +252,9 @@ impl Decoder {
..Default::default()
};
let block_in = cfg.ch * cfg.ch_mult.last().unwrap_or(&1);
let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid_block_1"))?;
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid_attn_1"))?;
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid_block_2"))?;
let mid_block_1 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_1"))?;
let mid_attn_1 = AttnBlock::new(block_in, vb.pp("mid.attn_1"))?;
let mid_block_2 = ResnetBlock::new(block_in, block_in, vb.pp("mid.block_2"))?;
let conv_in = conv2d(cfg.z_channels, block_in, 3, conv_cfg, vb.pp("conv_in"))?;
let norm_out = group_norm(32, block_in, 1e-6, vb.pp("norm_out"))?;
let conv_out = conv2d(block_in, cfg.out_ch, 3, conv_cfg, vb.pp("conv_out"))?;
Expand Down

0 comments on commit 0269fe6

Please sign in to comment.