Celery Task Testing with pytest: Mocking, Retries, and Chaining

Celery Task Testing with pytest: Mocking, Retries, and Chaining

Celery powers background task processing in Python applications. It's powerful but notoriously tricky to test because of its asynchronous, distributed nature. The good news: Celery ships with testing infrastructure that makes most scenarios straightforward. The tricky parts — retries, chains, countdown — require some extra care.

Core Testing Modes

CELERY_TASK_ALWAYS_EAGER

The most important configuration for testing: task_always_eager makes tasks execute synchronously in the calling process.

# conftest.py
import pytest
from celery import Celery

@pytest.fixture(autouse=True)
def celery_config():
    """Configure Celery to execute every task synchronously during tests."""
    config = {
        'broker_url': 'memory://',
        'result_backend': 'cache+memory://',
    }
    config['task_always_eager'] = True
    # Re-raise task exceptions in the caller instead of capturing them.
    config['task_eager_propagates'] = True
    return config

Or with pytest-celery:

# conftest.py
@pytest.fixture
def celery_app():
    """Return the project's Celery app, switched into eager (synchronous) mode."""
    from myapp.celery import app

    eager_settings = dict(
        task_always_eager=True,
        task_eager_propagates=True,
    )
    app.conf.update(**eager_settings)
    return app

Using the @pytest.mark.celery Decorator

# pytest.ini
[pytest]
addopts = --strict-markers

# conftest.py
@pytest.fixture(scope='session')
def celery_config():
    """Session-wide broker/backend config — everything stays in memory."""
    return dict(
        broker_url='memory://',
        result_backend='cache+memory://',
    )
import pytest

@pytest.mark.celery(task_always_eager=True)
def test_my_task():
    """my_task(1, 2) should produce 3; eager mode makes .get() immediate."""
    async_result = my_task.delay(1, 2)
    summed = async_result.get()
    assert summed == 3

Defining Tasks to Test

# tasks.py
from celery import shared_task, current_task
from celery.exceptions import MaxRetriesExceededError
import logging

logger = logging.getLogger(__name__)


@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_payment(self, order_id: int, amount: float) -> dict:
    """Charge `amount` for the given order.

    Retries up to 3 times on NetworkError with a growing countdown;
    fails fast on InsufficientFundsError because retrying cannot help.

    Returns a status dict: {'status': 'skipped'|'success', ...}.
    """
    from .models import Order, Payment
    
    order = Order.objects.get(id=order_id)
    
    # Idempotency guard: a re-delivered message must not double-charge.
    if order.is_paid():
        logger.info(f"Order {order_id} already paid, skipping")
        return {'status': 'skipped', 'order_id': order_id}
    
    try:
        payment = PaymentGateway.charge(order.customer_email, amount)
        order.mark_paid(payment.transaction_id)
        return {'status': 'success', 'transaction_id': payment.transaction_id}
    
    except NetworkError as exc:
        # Linear backoff: 30s, 60s, 90s — self.request.retries counts
        # the attempts already made (0 on the first run).
        raise self.retry(exc=exc, countdown=30 * (self.request.retries + 1))
    
    except InsufficientFundsError:
        order.mark_payment_failed('insufficient_funds')
        raise  # Don't retry — it won't succeed


@shared_task
def send_confirmation_email(order_id: int):
    """Send the order-confirmation email for the given order id."""
    from .models import Order

    confirmed_order = Order.objects.get(id=order_id)
    EmailService.send_order_confirmation(confirmed_order)
    return {'email_sent': True, 'order_id': order_id}

Unit Testing Tasks

# tests/test_tasks.py
import pytest
from unittest.mock import patch, MagicMock
from celery.exceptions import MaxRetriesExceededError, Retry
from myapp.tasks import process_payment, send_confirmation_email


class TestProcessPayment:
    """Unit tests for process_payment with the payment gateway patched out."""
    
    @pytest.fixture
    def order(self, db):
        # NOTE(review): Order (and NetworkError/InsufficientFundsError below)
        # are not imported in this snippet — presumably imported at the top
        # of the real test module; verify against the project.
        return Order.objects.create(
            customer_email='alice@example.com',
            total=99.99,
            status='pending'
        )
    
    def test_successful_payment(self, order):
        """Happy path: the charge succeeds and the order is marked paid."""
        with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
            mock_charge.return_value = MagicMock(transaction_id='txn_abc123')
            
            # Calling the task like a plain function — no broker involved.
            result = process_payment(order.id, 99.99)
            
            assert result['status'] == 'success'
            assert result['transaction_id'] == 'txn_abc123'
            
            order.refresh_from_db()
            assert order.status == 'paid'
    
    def test_skips_already_paid_order(self, order):
        """The idempotency guard must short-circuit before any charge."""
        order.status = 'paid'
        order.save()
        
        with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
            result = process_payment(order.id, 99.99)
            
            assert result['status'] == 'skipped'
            mock_charge.assert_not_called()
    
    def test_retries_on_network_error(self, order, celery_app):
        """A NetworkError triggers self.retry, surfacing as celery.exceptions.Retry."""
        celery_app.conf.task_always_eager = False  # Disable eager for retry test
        
        with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
            mock_charge.side_effect = NetworkError("Connection timeout")
            
            # apply() runs the task body locally; the retry() call raises Retry.
            with pytest.raises(Retry):
                process_payment.apply(args=[order.id, 99.99])
    
    def test_does_not_retry_on_insufficient_funds(self, order):
        """Business failures re-raise immediately and flag the order."""
        with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
            mock_charge.side_effect = InsufficientFundsError("Card declined")
            
            with pytest.raises(InsufficientFundsError):
                process_payment(order.id, 99.99)
            
            order.refresh_from_db()
            assert order.status == 'payment_failed'
            assert order.failure_reason == 'insufficient_funds'

Testing Retries

Retry testing requires bypassing task_always_eager: in eager mode, retry() re-runs the task immediately in-process, so countdowns and the broker-side re-queueing you actually care about are never exercised:

class TestRetryBehavior:
    """Exercises the retry/backoff logic without a real broker."""
    
    def test_retry_countdown_increases_with_attempts(self):
        """Verify retry delay increases: 30s, 60s, 90s"""
        task = process_payment
        
        # Access the retry_in configuration
        # For manual retry with countdown:
        with patch.object(process_payment, 'retry') as mock_retry:
            mock_retry.side_effect = Retry()
            
            with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
                mock_charge.side_effect = NetworkError("timeout")
                
                # Simulate first attempt (retries=0)
                # NOTE(review): task.request is a thread-local property on
                # Celery tasks; patch.object on it may not behave as expected
                # — confirm this works on the Celery version in use.
                with patch.object(process_payment.request, 'retries', 0):
                    with pytest.raises(Retry):
                        process_payment(1, 99.99)
                
                # First retry: countdown should be 30
                call_kwargs = mock_retry.call_args.kwargs
                assert call_kwargs['countdown'] == 30
    
    def test_max_retries_exhausted(self):
        """When max retries are exceeded, exception propagates"""
        with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
            mock_charge.side_effect = NetworkError("Server down")
            
            # Start the run already at retries=3 (== max_retries), so the
            # next failure exhausts the budget and the original NetworkError
            # propagates instead of a Retry.
            with pytest.raises(NetworkError):
                process_payment.apply(
                    args=[1, 99.99],
                    kwargs={},
                    retries=3  # Already at max
                )

Testing retry with apply options

def test_task_raises_after_max_retries(order):
    """Simulate what happens when retries are exhausted"""
    
    def fail_always(*args, **kwargs):
        raise NetworkError("Always fails")
    
    with patch('myapp.tasks.PaymentGateway.charge', side_effect=fail_always):
        # Start the run already at the task's retry ceiling so the first
        # failure exhausts the budget and NetworkError propagates.
        with pytest.raises(NetworkError):
            process_payment.apply(
                args=[order.id, 99.99],
                throw=True,
                retries=process_payment.max_retries  # Already exhausted
            )

Testing Task Chains

Chains are one of the most common Celery patterns — one task's output feeds the next:

# tasks.py
from celery import chain

@shared_task
def generate_report(report_id: int) -> dict:
    """Pipeline step 1: collect raw data for the report."""
    report = Report.objects.get(id=report_id)
    rows = DataCollector.collect(report)
    report.update(raw_data=rows, status='collected')
    return {'report_id': report_id, 'row_count': len(rows)}


@shared_task
def format_report(result: dict) -> dict:
    """Pipeline step 2: format the raw data produced by generate_report."""
    rid = result['report_id']
    report = Report.objects.get(id=rid)
    pretty = Formatter.format(report.raw_data)
    report.update(formatted_data=pretty, status='formatted')
    return {'report_id': rid}


@shared_task
def export_report(result: dict) -> str:
    """Pipeline step 3: push the formatted report to S3 and record its URL."""
    rid = result['report_id']
    report = Report.objects.get(id=rid)
    export_url = S3Exporter.export(report.formatted_data)
    report.update(export_url=export_url, status='exported')
    return export_url


def run_report_pipeline(report_id: int):
    """Kick off the collect -> format -> export chain for one report."""
    pipeline = chain(
        generate_report.s(report_id),
        format_report.s(),
        export_report.s(),
    )
    return pipeline()
class TestReportChain:
    """End-to-end tests for the collect -> format -> export chain."""
    
    @pytest.fixture(autouse=True)
    def eager_mode(self, settings):
        """Run every task in this class synchronously.

        Uses pytest-django's `settings` fixture directly — it restores the
        original values after each test, so no yield/teardown is needed.
        (The original took the `settings` fixture but never used it, and
        hand-rolled override_settings instead.)
        """
        settings.CELERY_TASK_ALWAYS_EAGER = True
        settings.CELERY_TASK_EAGER_PROPAGATES = True
    
    def test_full_pipeline(self, report):
        """All three steps run in order and the report ends 'exported'."""
        with patch('myapp.tasks.S3Exporter.export') as mock_export:
            mock_export.return_value = 'https://s3.example.com/report.pdf'
            
            result = run_report_pipeline(report.id)
            
            report.refresh_from_db()
            assert report.status == 'exported'
            assert report.export_url == 'https://s3.example.com/report.pdf'
    
    def test_chain_stops_on_first_failure(self, report):
        """With eager_propagates, the first task's error aborts the chain."""
        with patch('myapp.tasks.DataCollector.collect') as mock_collect:
            mock_collect.side_effect = ValueError("Invalid data source")
            
            with pytest.raises(ValueError):
                run_report_pipeline(report.id)
            
            # format_report and export_report should NOT have been called
            report.refresh_from_db()
            assert report.status == 'pending'  # Never progressed

Testing Groups

Groups run tasks in parallel:

from celery import group

@shared_task
def process_batch_item(item_id: int) -> dict:
    """Process one batch item and return a small success record."""
    item = BatchItem.objects.get(id=item_id)
    outcome = Processor.process(item)
    return {'item_id': item_id, 'success': True, 'result': outcome}


def process_batch(batch_id: int):
    """Fan out one process_batch_item task per item in the batch."""
    batch = Batch.objects.get(id=batch_id)
    signatures = [process_batch_item.s(item.id) for item in batch.items.all()]
    return group(signatures)()
class TestBatchProcessing:
    """Tests for the parallel batch-processing group."""
    
    def test_processes_all_items_in_group(self, batch_with_items):
        """Every item yields a success record."""
        results = process_batch(batch_with_items.id)
        
        # In eager mode, group returns a list of results
        completed = results.get()
        
        assert len(completed) == batch_with_items.items.count()
        assert all(r['success'] for r in completed)
    
    def test_partial_failure_in_group(self, batch_with_items):
        """Groups collect all results even if some fail"""
        first_id = batch_with_items.items.first().id
        
        # Processor.process receives the BatchItem instance (see
        # process_batch_item), so the side_effect must accept an item —
        # the original took an `item_id` and compared the item object to
        # an int id, so the failure branch could never trigger.
        def fail_for_first_item(item):
            if item.id == first_id:
                raise ValueError("Item 1 processing failed")
            return {'item_id': item.id, 'success': True}
        
        with patch('myapp.tasks.Processor.process', side_effect=fail_for_first_item):
            # In eager mode with propagate=False, exceptions are stored in results
            # NOTE(review): this assumes task_eager_propagates is off for this
            # test; with propagation on, the ValueError raises at dispatch.
            results = process_batch(batch_with_items.id)
            completed = results.get(propagate=False)
            
            successes = [r for r in completed if not isinstance(r, Exception)]
            failures = [r for r in completed if isinstance(r, Exception)]
            
            assert len(failures) == 1
            assert len(successes) == batch_with_items.items.count() - 1

Testing Task State and Result Backend

class TestTaskStates:
    """Inspect task state and payload via the result backend."""
    
    def test_task_result_is_stored(self, order):
        """A successful run is stored as SUCCESS with its return value."""
        # Mock the gateway so the task deterministically takes its success path
        # (the original left it unmocked).
        with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
            mock_charge.return_value = MagicMock(transaction_id='txn_abc123')
            result = process_payment.apply(args=[order.id, 99.99])
        
        # Retrieve from result backend
        stored = process_payment.AsyncResult(result.id)
        
        assert stored.status == 'SUCCESS'
        # The success payload carries the transaction id — the task's
        # success branch never returns an 'order_id' key (the original
        # asserted on a key that is only present in the 'skipped' result).
        assert stored.result['transaction_id'] == 'txn_abc123'
    
    def test_failed_task_stores_exception(self, order):
        """throw=False captures the exception in the result instead of raising."""
        with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
            # ValueError is not caught by the task, so the run fails outright.
            mock_charge.side_effect = ValueError("Card invalid")
            
            result = process_payment.apply(args=[order.id, 99.99], throw=False)
            
            assert result.status == 'FAILURE'
            assert isinstance(result.result, ValueError)

Testing Periodic Tasks (beat)

# tasks.py
from celery.schedules import crontab

@shared_task
def cleanup_expired_sessions():
    """Delete sessions whose expiry is in the past; report how many went."""
    expired = Session.objects.filter(expires_at__lt=timezone.now())
    deleted, _ = expired.delete()
    return {'deleted_count': deleted}
# Test the task logic directly — beat scheduling is config, not code to test
class TestCleanupTask:
    """Call the beat task directly — the schedule itself is just config."""
    
    def test_deletes_expired_sessions(self, db):
        """Only expired sessions are removed; active ones survive."""
        now = timezone.now()
        for age in (timedelta(hours=1), timedelta(days=7)):
            Session.objects.create(expires_at=now - age)
        active = Session.objects.create(expires_at=now + timedelta(hours=1))
        
        result = cleanup_expired_sessions()
        
        assert result['deleted_count'] == 2
        assert Session.objects.filter(id=active.id).exists()
    
    def test_is_registered_in_beat_schedule(self):
        """The beat schedule maps our key to the fully-qualified task name."""
        from django.conf import settings
        
        schedule = settings.CELERY_BEAT_SCHEDULE
        assert 'cleanup-expired-sessions' in schedule
        assert schedule['cleanup-expired-sessions']['task'] == 'myapp.tasks.cleanup_expired_sessions'

Useful Fixtures for Celery Testing

# conftest.py
import pytest
from unittest.mock import MagicMock, patch


@pytest.fixture
def eager_celery(settings):
    """Force synchronous task execution.

    Relies on pytest-django's `settings` fixture, which restores the
    original values automatically after the test.
    """
    settings.CELERY_TASK_ALWAYS_EAGER = True
    settings.CELERY_TASK_EAGER_PROPAGATES = True


@pytest.fixture
def mock_celery_task():
    """Factory that patches a task path so calls are recorded, not executed.

    The original returned a plain function containing ``with ... yield``,
    which made every call produce an un-started generator — the patch was
    never applied. Wrapping the factory in ``contextlib.contextmanager``
    makes it usable as::

        with mock_celery_task('myapp.tasks.process_payment') as mock:
            ...
            mock.delay.assert_called_once()
    """
    from contextlib import contextmanager
    
    @contextmanager
    def factory(task_path):
        with patch(task_path) as mock:
            # patch() installs a MagicMock, so .delay/.apply_async would be
            # auto-created anyway; explicit assignment kept for clarity.
            mock.delay = MagicMock()
            mock.apply_async = MagicMock()
            yield mock
    return factory


@pytest.fixture
def capture_tasks(mocker):
    """Capture every task dispatched via send_task without executing it.

    Returns a list that fills up as tasks are sent during the test.
    """
    # The original referenced an undefined `celery_app` and bound
    # `original_send` without ever using it.
    from myapp.celery import app as celery_app  # same app module used by the celery_app fixture
    
    tasks_sent = []
    
    def capture(*args, **kwargs):
        tasks_sent.append({'args': args, 'kwargs': kwargs})
        return MagicMock(id='mock-task-id')
    
    mocker.patch.object(celery_app, 'send_task', side_effect=capture)
    return tasks_sent

Celery task testing becomes manageable once you accept that task_always_eager is your best friend for most cases. Reserve the complexity — mocking retry mechanisms, testing chains with partial failures — for the jobs where the edge cases actually matter.

Read more