Fix the attachment of the lemma classifier to the pipeline
Add pretrains to the lemmatizers with contextual lemmas in the resources
AngledLuffa committed Dec 24, 2024
1 parent 6f56d83 commit 7dafdfb
Showing 3 changed files with 25 additions and 6 deletions.
8 changes: 4 additions & 4 deletions stanza/models/lemma/trainer.py
@@ -32,10 +32,10 @@ def unpack_batch(batch, device):
 
 class Trainer(object):
     """ A trainer for training models. """
-    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None):
+    def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, device=None, foundation_cache=None, lemma_classifier_args=None):
         if model_file is not None:
             # load everything from file
-            self.load(model_file, args, foundation_cache)
+            self.load(model_file, args, foundation_cache, lemma_classifier_args)
         else:
             # build model from scratch
             self.args = args
@@ -292,7 +292,7 @@ def save(self, filename, skip_modules=True):
         torch.save(params, filename, _use_new_zipfile_serialization=False)
         logger.info("Model saved to {}".format(filename))
 
-    def load(self, filename, args, foundation_cache):
+    def load(self, filename, args, foundation_cache, lemma_classifier_args=None):
         try:
             checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
@@ -313,4 +313,4 @@ def load(self, filename, args, foundation_cache):
         self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
         self.contextual_lemmatizers = []
         for contextual in checkpoint.get('contextual', []):
-            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual))
+            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual, args=lemma_classifier_args))
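For reference, a minimal sketch of how the new argument reaches the contextual classifiers when a saved lemmatizer is loaded directly through this Trainer. The file paths are hypothetical placeholders, and the override behavior is assumed from the way load() forwards the dict to LemmaClassifier.from_checkpoint.

from stanza.models.lemma.trainer import Trainer

# Hypothetical paths, for illustration only
lemma_classifier_args = {
    'charlm_forward_file': '/path/to/forward_charlm.pt',
    'charlm_backward_file': '/path/to/backward_charlm.pt',
    'wordvec_pretrain_file': '/path/to/pretrain.pt',
}
# load() passes lemma_classifier_args on to LemmaClassifier.from_checkpoint
# for each contextual lemmatizer stored in the checkpoint
trainer = Trainer(model_file='/path/to/lemmatizer.pt',
                  lemma_classifier_args=lemma_classifier_args)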
5 changes: 4 additions & 1 deletion stanza/pipeline/lemma_processor.py
@@ -48,11 +48,14 @@ def _set_up_model(self, config, pipeline, device):
         # since a long running program will remember everything
         # (unless we go back and make it smarter)
         # we make this an option, not the default
+        # TODO: need to update the cache to skip the contextual lemmatizer
         self.store_results = config.get('store_results', False)
         self._use_identity = False
         args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                 'charlm_backward_file': config.get('backward_charlm_path', None)}
-        self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)
+        lemma_classifier_args = dict(args)
+        lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)
+        self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)
 
     def _set_up_requires(self):
         self._pretagged = self._config.get('pretagged', None)
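From the pipeline side nothing changes in how the processor is invoked; the fix means the pretrain path from the downloaded resources now reaches any contextual lemma classifier attached to the lemmatizer. A rough usage sketch, with an illustrative language choice rather than one this commit specifically targets:

import stanza

# lemma requires tokenize and pos; the processor's config supplies
# pretrain_path, forward_charlm_path and backward_charlm_path from the
# downloaded resources, which _set_up_model now forwards to the Trainer
nlp = stanza.Pipeline('en', processors='tokenize,pos,lemma')
doc = nlp('The engineers tested the pipeline')
print([word.lemma for sent in doc.sentences for word in sent.words])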
18 changes: 17 additions & 1 deletion stanza/resources/prepare_resources.py
@@ -22,6 +22,7 @@
 from stanza.models.common.constant import lcode2lang, two_to_three_letters, three_to_two_letters
 from stanza.resources.default_packages import PACKAGES, TRANSFORMERS, TRANSFORMER_NICKNAMES
 from stanza.resources.default_packages import *
+from stanza.utils.datasets.prepare_lemma_classifier import DATASET_MAPPING as LEMMA_CLASSIFIER_DATASETS
 from stanza.utils.get_tqdm import get_tqdm
 
 tqdm = get_tqdm()
@@ -179,14 +180,29 @@ def get_pos_dependencies(lang, package):
 
     return dependencies
 
+def get_lemma_pretrain_package(lang, package):
+    package, uses_pretrain, uses_charlm = split_package(package)
+    if not uses_pretrain:
+        return None
+    if not uses_charlm:
+        # currently the contextual lemma classifier is only active
+        # for the charlm lemmatizers
+        return None
+    if "%s_%s" % (lang, package) not in LEMMA_CLASSIFIER_DATASETS:
+        return None
+    return get_pretrain_package(lang, package, {}, default_pretrains)
+
 def get_lemma_charlm_package(lang, package):
     return get_charlm_package(lang, package, lemma_charlms, default_charlms)
 
 def get_lemma_dependencies(lang, package):
     dependencies = []
 
-    charlm_package = get_lemma_charlm_package(lang, package)
+    pretrain_package = get_lemma_pretrain_package(lang, package)
+    if pretrain_package is not None:
+        dependencies.append({'model': 'pretrain', 'package': pretrain_package})
+
+    charlm_package = get_lemma_charlm_package(lang, package)
     if charlm_package is not None:
         dependencies.append({'model': 'forward_charlm', 'package': charlm_package})
         dependencies.append({'model': 'backward_charlm', 'package': charlm_package})
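The practical effect on resource preparation is that a charlm lemmatizer whose dataset appears in LEMMA_CLASSIFIER_DATASETS now declares a pretrain dependency alongside its charlm dependencies. A sketch under stated assumptions; the language and package names are hypothetical and the actual values depend on default_pretrains, lemma_charlms, and the dataset mapping at build time:

from stanza.resources.prepare_resources import get_lemma_dependencies

# Hypothetical example: a *_charlm lemma package for a language with a
# contextual lemma classifier dataset would now also pull in a pretrain
deps = get_lemma_dependencies('en', 'combined_charlm')
# e.g. [{'model': 'pretrain', 'package': 'combined'},
#       {'model': 'forward_charlm', 'package': '1billion'},
#       {'model': 'backward_charlm', 'package': '1billion'}]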
