Fix the attachment of the lemma classifier to the pipeline
AngledLuffa committed Dec 24, 2024
1 parent ad17b27 commit 9578401
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions stanza/models/lemma/trainer.py
@@ -292,7 +292,7 @@ def save(self, filename, skip_modules=True):
         torch.save(params, filename, _use_new_zipfile_serialization=False)
         logger.info("Model saved to {}".format(filename))

-    def load(self, filename, args, foundation_cache):
+    def load(self, filename, args, foundation_cache, lemma_classifier_args=None):
         try:
             checkpoint = torch.load(filename, lambda storage, loc: storage, weights_only=True)
         except BaseException:
@@ -313,4 +313,4 @@ def load(self, filename, args, foundation_cache):
         self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
         self.contextual_lemmatizers = []
         for contextual in checkpoint.get('contextual', []):
-            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual))
+            self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual, args=lemma_classifier_args))
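For illustration, a minimal sketch of loading a lemmatizer directly with the new lemma_classifier_args parameter. The file names and the CPU device are hypothetical, and the Trainer constructor is assumed to forward the extra argument to load() as shown above:

    from stanza.models.lemma.trainer import Trainer

    # charlm paths used by the base lemmatizer (hypothetical file names)
    args = {'charlm_forward_file': 'en_forward_charlm.pt',
            'charlm_backward_file': 'en_backward_charlm.pt'}

    # the contextual LemmaClassifier additionally needs the word vector pretrain
    lemma_classifier_args = dict(args)
    lemma_classifier_args['wordvec_pretrain_file'] = 'en_pretrain.pt'

    # any LemmaClassifier stored under the checkpoint's 'contextual' entries is
    # now rebuilt with these paths instead of losing them
    trainer = Trainer(args=args, model_file='en_lemmatizer.pt', device='cpu',
                      lemma_classifier_args=lemma_classifier_args)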
5 changes: 4 additions & 1 deletion stanza/pipeline/lemma_processor.py
@@ -48,11 +48,14 @@ def _set_up_model(self, config, pipeline, device):
         # since a long running program will remember everything
         # (unless we go back and make it smarter)
         # we make this an option, not the default
+        # TODO: need to update the cache to skip the contextual lemmatizer
         self.store_results = config.get('store_results', False)
         self._use_identity = False
         args = {'charlm_forward_file': config.get('forward_charlm_path', None),
                 'charlm_backward_file': config.get('backward_charlm_path', None)}
-        self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache)
+        lemma_classifier_args = dict(args)
+        lemma_classifier_args['wordvec_pretrain_file'] = config.get('pretrain_path', None)
+        self._trainer = Trainer(args=args, model_file=config['model_path'], device=device, foundation_cache=pipeline.foundation_cache, lemma_classifier_args=lemma_classifier_args)

     def _set_up_requires(self):
         self._pretagged = self._config.get('pretagged', None)
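At the pipeline level, the practical effect is that a lemma model whose checkpoint bundles a contextual LemmaClassifier now gets its pretrain and charlm paths when the processor is set up. A rough usage sketch (whether a given language package actually ships a contextual classifier depends on the model):

    import stanza

    stanza.download('en')                       # fetch the default English models
    nlp = stanza.Pipeline('en')                 # default processors include lemma
    doc = nlp("The dogs were barking loudly.")
    print([word.lemma for sent in doc.sentences for word in sent.words])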
