diff --git a/stanza/models/mwt_expander.py b/stanza/models/mwt_expander.py index f6274211e..8f68cfca6 100644 --- a/stanza/models/mwt_expander.py +++ b/stanza/models/mwt_expander.py @@ -139,6 +139,9 @@ def train(args): logger.warning("Training data available, but dev data has no MWTs. Only training a dict based MWT") args['dict_only'] = True + if args['force_exact_pieces'] and not mwts_composed_of_words(train_doc): + raise ValueError("Cannot train model with --force_exact_pieces, as the MWT in this dataset are not entirely composed of their subwords") + if args['force_exact_pieces'] is None and mwts_composed_of_words(train_doc): # the force_exact_pieces mechanism trains a separate version of the MWT expander in the Trainer # (the training loop here does not need to change) @@ -147,6 +150,8 @@ def train(args): # this behavior can be turned off at training time with --no_force_exact_pieces logger.info("Train MWTs entirely composed of their subwords. Training the MWT to match that paradigm as closely as possible") args['force_exact_pieces'] = True + + if args['force_exact_pieces']: logger.info("Reconverting to BinaryDataLoader") train_batch = BinaryDataLoader(train_doc, args['batch_size'], args, evaluation=False) vocab = train_batch.vocab