Skip to content

Commit

Permalink
fix: tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekseyKorshuk committed Mar 24, 2022
1 parent b374dc7 commit 39102a4
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import string
import unittest
from abc import abstractmethod
from functools import lru_cache
from unittest import skipIf
from functools import lru_cache

from transformers import (
FEATURE_EXTRACTOR_MAPPING,
Expand All @@ -23,7 +23,6 @@
from transformers.pipelines.base import _pad
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -153,8 +152,8 @@ def test(self):
if isinstance(model.config, (RobertaConfig, IBertConfig)):
tokenizer.model_max_length = model.config.max_position_embeddings - 2
elif (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings > 0
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings > 0
):
tokenizer.model_max_length = model.config.max_position_embeddings
# Rust Panic exception are NOT Exception subclass
Expand Down Expand Up @@ -591,4 +590,4 @@ def add(number, extra=0):
dataset = PipelinePackIterator(dummy_dataset, add, {"extra": 2}, loader_batch_size=3)

outputs = [item for item in dataset]
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])
self.assertEqual(outputs, [[{"id": 2}, {"id": 3}, {"id": 4}, {"id": 5}]])

0 comments on commit 39102a4

Please sign in to comment.