Skip to content

Commit

Permalink
Aligned ModelForTokenClassification and ModelForSequenceClassification APIs (guillaume-be#323)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Jan 15, 2023
1 parent 445b76f commit f12e8ef
Show file tree
Hide file tree
Showing 24 changed files with 199 additions and 113 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ All notable changes to this project will be documented in this file. The format
- Allow mixing local and remote resources in pipelines.
- Upgraded to `torch` 1.13 (via `tch` 0.9.0).
- (BREAKING) Made the `max_length` argument for generation methods and pipelines optional.
- (BREAKING) Changed the return type of the `new` constructors of `ModelForSequenceClassification` and `ModelForTokenClassification` models to `Result<Self, RustBertError>`, allowing callers to handle the error raised when no labels (`id2label`) are provided in the configuration instead of panicking.

## Fixed
- Fixed configuration check for RoBERTa models for sentence classification.
Expand Down
2 changes: 1 addition & 1 deletion examples/natural_language_inference_deberta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ fn main() -> anyhow::Result<()> {
false,
)?;
let config = DebertaConfig::from_file(config_path);
let model = DebertaForSequenceClassification::new(vs.root(), &config);
let model = DebertaForSequenceClassification::new(vs.root(), &config)?;
vs.load(weights_path)?;

// Define input
Expand Down
17 changes: 12 additions & 5 deletions src/albert/albert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,9 +505,12 @@ impl AlbertForSequenceClassification {
/// let p = nn::VarStore::new(device);
/// let config = AlbertConfig::from_file(config_path);
/// let albert: AlbertForSequenceClassification =
/// AlbertForSequenceClassification::new(&p.root(), &config);
/// AlbertForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &AlbertConfig) -> AlbertForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &AlbertConfig,
) -> Result<AlbertForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
Expand All @@ -519,7 +522,11 @@ impl AlbertForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
Expand All @@ -528,11 +535,11 @@ impl AlbertForSequenceClassification {
Default::default(),
);

AlbertForSequenceClassification {
Ok(AlbertForSequenceClassification {
albert,
dropout,
classifier,
}
})
}

/// Forward pass through the model
Expand Down
25 changes: 16 additions & 9 deletions src/bart/bart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,15 +695,19 @@ pub struct BartClassificationHead {
}

impl BartClassificationHead {
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartClassificationHead
pub fn new<'p, P>(p: P, config: &BartConfig) -> Result<BartClassificationHead, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let dense = nn::linear(
p / "dense",
Expand All @@ -719,11 +723,11 @@ impl BartClassificationHead {
Default::default(),
);

BartClassificationHead {
Ok(BartClassificationHead {
dense,
dropout,
out_proj,
}
})
}

pub fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
Expand Down Expand Up @@ -768,22 +772,25 @@ impl BartForSequenceClassification {
/// let p = nn::VarStore::new(device);
/// let config = BartConfig::from_file(config_path);
/// let bart: BartForSequenceClassification =
/// BartForSequenceClassification::new(&p.root() / "bart", &config);
/// BartForSequenceClassification::new(&p.root() / "bart", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &BartConfig) -> BartForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &BartConfig,
) -> Result<BartForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
let p = p.borrow();

let base_model = BartModel::new(p / "model", config);
let classification_head = BartClassificationHead::new(p / "classification_head", config);
let classification_head = BartClassificationHead::new(p / "classification_head", config)?;
let eos_token_id = config.eos_token_id.unwrap_or(3);
BartForSequenceClassification {
Ok(BartForSequenceClassification {
base_model,
classification_head,
eos_token_id,
}
})
}

/// Forward pass through the model
Expand Down
17 changes: 12 additions & 5 deletions src/bert/bert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -682,9 +682,12 @@ impl BertForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = BertConfig::from_file(config_path);
/// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config);
/// let bert = BertForSequenceClassification::new(&p.root() / "bert", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &BertConfig) -> BertForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &BertConfig,
) -> Result<BertForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
Expand All @@ -695,7 +698,11 @@ impl BertForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
Expand All @@ -704,11 +711,11 @@ impl BertForSequenceClassification {
Default::default(),
);

BertForSequenceClassification {
Ok(BertForSequenceClassification {
bert,
dropout,
classifier,
}
})
}

/// Forward pass through the model
Expand Down
17 changes: 12 additions & 5 deletions src/deberta/deberta_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -732,9 +732,12 @@ impl DebertaForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DebertaConfig::from_file(config_path);
/// let model = DebertaForSequenceClassification::new(&p.root(), &config);
/// let model = DebertaForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DebertaConfig) -> DebertaForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &DebertaConfig,
) -> Result<DebertaForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
Expand All @@ -751,7 +754,11 @@ impl DebertaForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;

let classifier = nn::linear(
Expand All @@ -761,12 +768,12 @@ impl DebertaForSequenceClassification {
Default::default(),
);

DebertaForSequenceClassification {
Ok(DebertaForSequenceClassification {
deberta,
pooler,
classifier,
dropout,
}
})
}

/// Forward pass through the model
Expand Down
17 changes: 12 additions & 5 deletions src/deberta_v2/deberta_v2_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,12 @@ impl DebertaV2ForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DebertaV2Config::from_file(config_path);
/// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config);
/// let model = DebertaV2ForSequenceClassification::new(&p.root(), &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DebertaV2Config) -> DebertaV2ForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &DebertaV2Config,
) -> Result<DebertaV2ForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
Expand All @@ -613,7 +616,11 @@ impl DebertaV2ForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;

let classifier = nn::linear(
Expand All @@ -623,12 +630,12 @@ impl DebertaV2ForSequenceClassification {
Default::default(),
);

DebertaV2ForSequenceClassification {
Ok(DebertaV2ForSequenceClassification {
deberta,
pooler,
classifier,
dropout,
}
})
}

/// Forward pass through the model
Expand Down
20 changes: 14 additions & 6 deletions src/distilbert/distilbert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,12 @@ impl DistilBertModelClassifier {
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert: DistilBertModelClassifier =
/// DistilBertModelClassifier::new(&p.root() / "distilbert", &config);
/// DistilBertModelClassifier::new(&p.root() / "distilbert", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &DistilBertConfig) -> DistilBertModelClassifier
pub fn new<'p, P>(
p: P,
config: &DistilBertConfig,
) -> Result<DistilBertModelClassifier, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
Expand All @@ -300,7 +303,11 @@ impl DistilBertModelClassifier {
let num_labels = config
.id2label
.as_ref()
.expect("id2label must be provided for classifiers")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;

let pre_classifier = nn::linear(
Expand All @@ -312,12 +319,12 @@ impl DistilBertModelClassifier {
let classifier = nn::linear(p / "classifier", config.dim, num_labels, Default::default());
let dropout = Dropout::new(config.seq_classif_dropout);

DistilBertModelClassifier {
Ok(DistilBertModelClassifier {
distil_bert_model,
pre_classifier,
classifier,
dropout,
}
})
}

/// Forward pass through the model
Expand Down Expand Up @@ -680,7 +687,8 @@ impl DistilBertForTokenClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = DistilBertConfig::from_file(config_path);
/// let distil_bert = DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap();
/// let distil_bert =
/// DistilBertForTokenClassification::new(&p.root() / "distilbert", &config).unwrap();
/// ```
pub fn new<'p, P>(
p: P,
Expand Down
17 changes: 12 additions & 5 deletions src/fnet/fnet_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,12 @@ impl FNetForSequenceClassification {
/// let device = Device::Cpu;
/// let p = nn::VarStore::new(device);
/// let config = FNetConfig::from_file(config_path);
/// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config);
/// let fnet = FNetForSequenceClassification::new(&p.root() / "fnet", &config).unwrap();
/// ```
pub fn new<'p, P>(p: P, config: &FNetConfig) -> FNetForSequenceClassification
pub fn new<'p, P>(
p: P,
config: &FNetConfig,
) -> Result<FNetForSequenceClassification, RustBertError>
where
P: Borrow<nn::Path<'p>>,
{
Expand All @@ -512,7 +515,11 @@ impl FNetForSequenceClassification {
let num_labels = config
.id2label
.as_ref()
.expect("num_labels not provided in configuration")
.ok_or_else(|| {
RustBertError::InvalidConfigurationError(
"num_labels not provided in configuration".to_string(),
)
})?
.len() as i64;
let classifier = nn::linear(
p / "classifier",
Expand All @@ -521,11 +528,11 @@ impl FNetForSequenceClassification {
Default::default(),
);

FNetForSequenceClassification {
Ok(FNetForSequenceClassification {
fnet,
dropout,
classifier,
}
})
}

/// Forward pass through the model
Expand Down
Loading

0 comments on commit f12e8ef

Please sign in to comment.