Change/bert encoder public #2658

Merged · 7 commits · Dec 4, 2024
candle-transformers/src/models/bert.rs: 51 changes (30 additions & 21 deletions)
@@ -22,6 +22,7 @@ pub enum HiddenAct {
Relu,
}

+#[derive(Clone)]
struct HiddenActLayer {
act: HiddenAct,
span: tracing::Span,
@@ -46,32 +46,32 @@ impl HiddenActLayer {

#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
-enum PositionEmbeddingType {
+pub enum PositionEmbeddingType {
#[default]
Absolute,
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
-vocab_size: usize,
-hidden_size: usize,
-num_hidden_layers: usize,
-num_attention_heads: usize,
-intermediate_size: usize,
+pub vocab_size: usize,
+pub hidden_size: usize,
+pub num_hidden_layers: usize,
+pub num_attention_heads: usize,
+pub intermediate_size: usize,
pub hidden_act: HiddenAct,
-hidden_dropout_prob: f64,
-max_position_embeddings: usize,
-type_vocab_size: usize,
-initializer_range: f64,
-layer_norm_eps: f64,
-pad_token_id: usize,
+pub hidden_dropout_prob: f64,
+pub max_position_embeddings: usize,
+pub type_vocab_size: usize,
+pub initializer_range: f64,
+pub layer_norm_eps: f64,
+pub pad_token_id: usize,
#[serde(default)]
-position_embedding_type: PositionEmbeddingType,
+pub position_embedding_type: PositionEmbeddingType,
#[serde(default)]
-use_cache: bool,
-classifier_dropout: Option<f64>,
-model_type: Option<String>,
+pub use_cache: bool,
+pub classifier_dropout: Option<f64>,
+pub model_type: Option<String>,
}

impl Default for Config {
@@ -121,6 +122,7 @@ impl Config {
}
}

+#[derive(Clone)]
struct Dropout {
#[allow(dead_code)]
pr: f64,
@@ -199,6 +201,7 @@ impl BertEmbeddings {
}
}

+#[derive(Clone)]
struct BertSelfAttention {
query: Linear,
key: Linear,
@@ -266,6 +269,7 @@ impl BertSelfAttention {
}
}

+#[derive(Clone)]
struct BertSelfOutput {
dense: Linear,
layer_norm: LayerNorm,
@@ -299,6 +303,7 @@ impl BertSelfOutput {
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
+#[derive(Clone)]
struct BertAttention {
self_attention: BertSelfAttention,
self_output: BertSelfOutput,
@@ -325,6 +330,7 @@ impl BertAttention {
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
+#[derive(Clone)]
struct BertIntermediate {
dense: Linear,
intermediate_act: HiddenActLayer,
@@ -352,6 +358,7 @@ impl Module for BertIntermediate {
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
+#[derive(Clone)]
struct BertOutput {
dense: Linear,
layer_norm: LayerNorm,
@@ -385,7 +392,8 @@ impl BertOutput {
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
-struct BertLayer {
+#[derive(Clone)]
+pub struct BertLayer {
attention: BertAttention,
intermediate: BertIntermediate,
output: BertOutput,
@@ -420,21 +428,22 @@ impl BertLayer {
}

// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
-struct BertEncoder {
-layers: Vec<BertLayer>,
+#[derive(Clone)]
+pub struct BertEncoder {
+pub layers: Vec<BertLayer>,
span: tracing::Span,
}

impl BertEncoder {
-fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
+pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.num_hidden_layers)
.map(|index| BertLayer::load(vb.pp(format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span })
}

-fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
+pub fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
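With Config's fields, BertLayer, and BertEncoder now public, downstream code can build and drive the encoder stack directly instead of going through BertModel. The sketch below is illustrative only and not part of the diff: the file names, the "bert.encoder" weight prefix, and the all-zeros additive attention mask are assumptions that depend on the checkpoint being loaded.

// Minimal sketch, assuming a BERT-style safetensors checkpoint and a
// Hugging Face config.json; paths and the weight prefix are hypothetical.
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertEncoder, Config};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Config derives Deserialize and its fields are now public, so it can be
    // loaded from JSON and tweaked in place.
    let mut config: Config = serde_json::from_str(&std::fs::read_to_string("config.json")?)?;
    config.hidden_dropout_prob = 0.0;

    // Point the VarBuilder at the encoder's prefix; "bert.encoder" is an
    // assumption, some checkpoints use "encoder" directly.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
    };
    let encoder = BertEncoder::load(vb.pp("bert.encoder"), &config)?;

    // Dummy hidden states plus an all-zeros additive attention mask (no
    // masking), shaped to broadcast over the attention scores.
    let (batch, seq_len) = (1usize, 8usize);
    let hidden_states = Tensor::zeros((batch, seq_len, config.hidden_size), DType::F32, &device)?;
    let attention_mask = Tensor::zeros((batch, 1usize, 1usize, seq_len), DType::F32, &device)?;

    let output = encoder.forward(&hidden_states, &attention_mask)?;
    println!("encoder output shape: {:?}", output.dims());
    Ok(())
}

Since the intermediate blocks also gain #[derive(Clone)], a loaded encoder, or individual BertLayers reached through the now-public layers field, can be cloned into other model structs.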