diff --git a/src/transformers/pipelines/mask_generation.py b/src/transformers/pipelines/mask_generation.py index 920ce9d0a39c..ffbe5868f7a7 100644 --- a/src/transformers/pipelines/mask_generation.py +++ b/src/transformers/pipelines/mask_generation.py @@ -231,7 +231,7 @@ def preprocess( for i in range(0, n_points, points_per_batch): batched_points = grid_points[:, i : i + points_per_batch, :, :] labels = input_labels[:, i : i + points_per_batch] - is_last = i == n_points - points_per_batch + is_last = i + points_per_batch >= n_points yield { "input_points": batched_points, "input_labels": labels, diff --git a/tests/pipelines/test_pipelines_mask_generation.py b/tests/pipelines/test_pipelines_mask_generation.py index 07ce9c5d91ca..4d25bb9c0cc1 100644 --- a/tests/pipelines/test_pipelines_mask_generation.py +++ b/tests/pipelines/test_pipelines_mask_generation.py @@ -93,6 +93,16 @@ def get_test_pipeline( def run_pipeline_test(self, mask_generator, examples): pass + def test_preprocess_is_last(self): + mask_generator = pipeline("mask-generation", model="hf-internal-testing/tiny-random-SamModel") + mask_generator.image_processor.pad_size = {"height": 24, "width": 24} + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + for points_per_batch in (100, 64): + with self.subTest(points_per_batch=points_per_batch): + batches = list(mask_generator.preprocess(image, points_per_batch=points_per_batch)) + self.assertTrue(batches[-1]["is_last"]) + self.assertFalse(any(b["is_last"] for b in batches[:-1])) + @slow @require_torch def test_small_model_pt(self):