diff --git a/forte_wrapper/allennlp/allennlp_processors.py b/forte_wrapper/allennlp/allennlp_processors.py index e59f347..55a2eb9 100644 --- a/forte_wrapper/allennlp/allennlp_processors.py +++ b/forte_wrapper/allennlp/allennlp_processors.py @@ -15,7 +15,7 @@ import itertools import logging import more_itertools -from typing import Dict, List +from typing import Any, Dict, List from allennlp.predictors import Predictor @@ -138,16 +138,19 @@ def _process(self, input_pack: DataPack): # handle existing entries self._process_existing_entries(input_pack) - batch_size = self.configs['infer_batch_size'] + batch_size: int = self.configs['infer_batch_size'] if batch_size <= 0: batches = iter([input_pack.get(Sentence)]) else: batches = more_itertools.ichunked(input_pack.get(Sentence), batch_size) for sentences in batches: - inputs = [{"sentence": s.text} for s in sentences] - results = {k: p.predict_batch_json(inputs) - for k, p in self.predictor.items()} + inputs: List[Dict[str, str]] = [{"sentence": s.text} + for s in sentences] + results: Dict[str, List[Dict[str, Any]]] = { + k: p.predict_batch_json(inputs) + for k, p in self.predictor.items() + } for i, sentence in enumerate(sentences): result: Dict[str, List[str]] = {} for key in self.predictor.keys():