Remove the unused pragma and properly apply the bias. (#1147)
LaurentMazare authored Oct 22, 2023
1 parent 3115fe4 commit 5b32c2a
Showing 3 changed files with 15 additions and 22 deletions.
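Note on the fix: with #![allow(unused)] removed, the compiler surfaces fields that were stored but never read. The consequential one is the bias tensor in TextLMPredictionHead: the decoder was previously built with linear_no_bias, so the loaded bias never reached the logits. The change below folds the bias into the decoder Linear instead. A minimal sketch of the numerical difference (not part of the commit; toy shapes and values, using candle_nn::Linear directly):

    use candle::{Device, Module, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Toy decoder: vocab_size = 3, hidden_size = 2.
        let w = Tensor::new(&[[1f32, 0.], [0., 1.], [1., 1.]], &dev)?;
        let b = Tensor::new(&[0.5f32, -0.5, 2.0], &dev)?;
        let hidden = Tensor::new(&[[1f32, 2.]], &dev)?;

        // Old behaviour: decoder without bias, the loaded bias left unused.
        let no_bias = candle_nn::Linear::new(w.clone(), None);
        // New behaviour: bias folded into the Linear, which is what
        // Linear::from_weights(weight, Some(bias)) now does for the decoder.
        let with_bias = candle_nn::Linear::new(w, Some(b));

        println!("{}", no_bias.forward(&hidden)?); // logits [1, 2, 3]
        println!("{}", with_bias.forward(&hidden)?); // logits [1.5, 1.5, 5]
        Ok(())
    }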
7 changes: 0 additions & 7 deletions candle-transformers/src/models/blip.rs
@@ -1,4 +1,3 @@
-#![allow(unused)]
 use super::blip_text;
 use super::with_tracing::{conv2d, linear, Conv2d, Linear};
 use candle::{Module, Result, Tensor, D};
@@ -65,7 +64,6 @@ struct VisionEmbeddings {
     class_embedding: Tensor,
     patch_embedding: Conv2d,
     position_embedding: Tensor,
-    num_positions: usize,
 }
 
 impl VisionEmbeddings {
@@ -91,7 +89,6 @@ impl VisionEmbeddings {
             class_embedding,
             patch_embedding,
             position_embedding,
-            num_positions,
         })
     }
 }
@@ -117,8 +114,6 @@ struct Attention {
     qkv: Linear,
     projection: Linear,
     scale: f64,
-    embed_dim: usize,
-    head_dim: usize,
     num_heads: usize,
 }
 
@@ -134,8 +129,6 @@ impl Attention {
             qkv,
             projection,
             scale,
-            embed_dim,
-            head_dim,
             num_heads,
         })
     }
22 changes: 7 additions & 15 deletions candle-transformers/src/models/blip_text.rs
@@ -1,5 +1,4 @@
-#![allow(unused)]
-use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
+use super::with_tracing::{linear, Embedding, Linear};
 use candle::{Module, Result, Tensor, D};
 use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
@@ -63,7 +62,6 @@ struct TextSelfAttention {
     query: Linear,
     key: Linear,
     value: Linear,
-    all_head_size: usize,
     attention_head_size: usize,
     num_attention_heads: usize,
     attention_scale: f64,
@@ -87,7 +85,6 @@ impl TextSelfAttention {
             query,
             key,
             value,
-            all_head_size,
             attention_head_size,
             num_attention_heads,
             attention_scale,
@@ -301,12 +298,12 @@ impl TextEncoder {
 }
 
 #[derive(Debug, Clone)]
-struct TextPooler {
+pub struct TextPooler {
     dense: Linear,
 }
 
 impl TextPooler {
-    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
         Ok(Self { dense })
     }
@@ -352,19 +349,15 @@ impl Module for TextPredictionHeadTransform {
 struct TextLMPredictionHead {
     transform: TextPredictionHeadTransform,
     decoder: Linear,
-    bias: Tensor,
 }
 
 impl TextLMPredictionHead {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
-        let decoder = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+        let weight = vb.get((cfg.vocab_size, cfg.hidden_size), "decoder.weight")?;
         let bias = vb.get(cfg.vocab_size, "bias")?;
-        Ok(Self {
-            transform,
-            decoder,
-            bias,
-        })
+        let decoder = Linear::from_weights(weight, Some(bias));
+        Ok(Self { transform, decoder })
     }
 }
 
@@ -396,7 +389,7 @@ impl Module for TextOnlyMLMHead {
 struct TextModel {
     embeddings: TextEmbeddings,
     encoder: TextEncoder,
-    pooler: Option<TextPooler>,
+    // We do not need the pooler for caption generation
 }
 
 impl TextModel {
@@ -406,7 +399,6 @@ impl TextModel {
         Ok(Self {
             embeddings,
             encoder,
-            pooler: None,
         })
     }
 
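With the bias now stored inside decoder, the prediction head needs no separate bias add at inference time. A sketch of that path, assuming the Module impl for TextLMPredictionHead (untouched by this diff) simply chains transform and decoder:

    impl Module for TextLMPredictionHead {
        fn forward(&self, xs: &Tensor) -> Result<Tensor> {
            let xs = self.transform.forward(xs)?;
            // The decoder Linear applies xs · Wᵀ + bias in a single step.
            self.decoder.forward(&xs)
        }
    }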
8 changes: 8 additions & 0 deletions candle-transformers/src/models/with_tracing.rs
@@ -32,6 +32,14 @@ pub struct Linear {
     span: tracing::Span,
 }
 
+impl Linear {
+    pub fn from_weights(weights: Tensor, bias: Option<Tensor>) -> Self {
+        let inner = candle_nn::Linear::new(weights, bias);
+        let span = tracing::span!(tracing::Level::TRACE, "linear");
+        Self { inner, span }
+    }
+}
+
 pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
     let inner = candle_nn::linear(d1, d2, vb)?;
     let span = tracing::span!(tracing::Level::TRACE, "linear");
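The new constructor only wraps candle_nn::Linear::new in a tracing span, so it can build a traced Linear from tensors that are already in memory. A standalone usage sketch (hypothetical shapes; assumes the with_tracing module is reachable from the caller, as it is for the model files in this crate):

    use candle::{DType, Device, Module, Result, Tensor};
    use candle_transformers::models::with_tracing::Linear;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Stand-ins for (vocab_size, hidden_size) = (4, 3).
        let weight = Tensor::randn(0f32, 1.0, (4, 3), &dev)?;
        let bias = Tensor::zeros(4, DType::F32, &dev)?;
        // Build the traced layer directly from tensors, as blip_text.rs now
        // does for its decoder head.
        let decoder = Linear::from_weights(weight, Some(bias));
        let hidden = Tensor::randn(0f32, 1.0, (2, 3), &dev)?;
        let logits = decoder.forward(&hidden)?;
        assert_eq!(logits.dims(), &[2, 4]);
        Ok(())
    }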
