Celery Task Testing with pytest: Mocking, Retries, and Chaining
Celery powers background task processing in Python applications. It's powerful but notoriously tricky to test because of its asynchronous, distributed nature. The good news: Celery ships with testing infrastructure that makes most scenarios straightforward. The tricky parts — retries, chains, countdowns — require some extra care.
Core Testing Modes
CELERY_TASK_ALWAYS_EAGER
The most important configuration for testing: task_always_eager makes tasks execute synchronously in the calling process.
# conftest.py
import pytest
from celery import Celery
@pytest.fixture(autouse=True)
def celery_config():
"""Make tasks run synchronously in tests"""
return {
'task_always_eager': True,
'task_eager_propagates': True, # Raise exceptions instead of capturing
'broker_url': 'memory://',
'result_backend': 'cache+memory://',
}Or with pytest-celery:
# conftest.py
@pytest.fixture
def celery_app():
from myapp.celery import app
app.conf.update(
task_always_eager=True,
task_eager_propagates=True,
)
return appUsing the @pytest.mark.celery Decorator
# pytest.ini
[pytest]
addopts = --strict-markers
# conftest.py
@pytest.fixture(scope='session')
def celery_config():
return {
'broker_url': 'memory://',
'result_backend': 'cache+memory://',
}import pytest
@pytest.mark.celery(task_always_eager=True)
def test_my_task():
result = my_task.delay(1, 2)
assert result.get() == 3Defining Tasks to Test
# tasks.py
from celery import shared_task, current_task
from celery.exceptions import MaxRetriesExceededError
import logging
logger = logging.getLogger(__name__)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def process_payment(self, order_id: int, amount: float):
"""Process a payment — retries on network errors"""
from .models import Order, Payment
order = Order.objects.get(id=order_id)
if order.is_paid():
logger.info(f"Order {order_id} already paid, skipping")
return {'status': 'skipped', 'order_id': order_id}
try:
payment = PaymentGateway.charge(order.customer_email, amount)
order.mark_paid(payment.transaction_id)
return {'status': 'success', 'transaction_id': payment.transaction_id}
except NetworkError as exc:
raise self.retry(exc=exc, countdown=30 * (self.request.retries + 1))
except InsufficientFundsError:
order.mark_payment_failed('insufficient_funds')
raise # Don't retry — it won't succeed
@shared_task
def send_confirmation_email(order_id: int):
from .models import Order
order = Order.objects.get(id=order_id)
EmailService.send_order_confirmation(order)
return {'email_sent': True, 'order_id': order_id}Unit Testing Tasks
# tests/test_tasks.py
import pytest
from unittest.mock import patch, MagicMock
from celery.exceptions import MaxRetriesExceededError, Retry
from myapp.tasks import process_payment, send_confirmation_email
class TestProcessPayment:
@pytest.fixture
def order(self, db):
return Order.objects.create(
customer_email='alice@example.com',
total=99.99,
status='pending'
)
def test_successful_payment(self, order):
with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
mock_charge.return_value = MagicMock(transaction_id='txn_abc123')
result = process_payment(order.id, 99.99)
assert result['status'] == 'success'
assert result['transaction_id'] == 'txn_abc123'
order.refresh_from_db()
assert order.status == 'paid'
def test_skips_already_paid_order(self, order):
order.status = 'paid'
order.save()
with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
result = process_payment(order.id, 99.99)
assert result['status'] == 'skipped'
mock_charge.assert_not_called()
def test_retries_on_network_error(self, order, celery_app):
celery_app.conf.task_always_eager = False # Disable eager for retry test
with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
mock_charge.side_effect = NetworkError("Connection timeout")
with pytest.raises(Retry):
process_payment.apply(args=[order.id, 99.99])
def test_does_not_retry_on_insufficient_funds(self, order):
with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
mock_charge.side_effect = InsufficientFundsError("Card declined")
with pytest.raises(InsufficientFundsError):
process_payment(order.id, 99.99)
order.refresh_from_db()
assert order.status == 'payment_failed'
assert order.failure_reason == 'insufficient_funds'Testing Retries
Retry testing requires bypassing task_always_eager because eager mode doesn't simulate the retry mechanism:
class TestRetryBehavior:
def test_retry_countdown_increases_with_attempts(self):
"""Verify retry delay increases: 30s, 60s, 90s"""
task = process_payment
# Access the retry_in configuration
# For manual retry with countdown:
with patch.object(process_payment, 'retry') as mock_retry:
mock_retry.side_effect = Retry()
with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
mock_charge.side_effect = NetworkError("timeout")
# Simulate first attempt (retries=0)
with patch.object(process_payment.request, 'retries', 0):
with pytest.raises(Retry):
process_payment(1, 99.99)
# First retry: countdown should be 30
call_kwargs = mock_retry.call_args.kwargs
assert call_kwargs['countdown'] == 30
def test_max_retries_exhausted(self):
"""When max retries are exceeded, exception propagates"""
with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
mock_charge.side_effect = NetworkError("Server down")
# Apply task with max_retries=0 to force immediate exhaustion
with pytest.raises(NetworkError):
process_payment.apply(
args=[1, 99.99],
kwargs={},
retries=3 # Already at max
)Testing retry with apply options
def test_task_raises_after_max_retries(order):
"""Simulate what happens when retries are exhausted"""
def fail_always(*args, **kwargs):
raise NetworkError("Always fails")
with patch('myapp.tasks.PaymentGateway.charge', side_effect=fail_always):
# Override max_retries to 0 for testing
with pytest.raises(NetworkError):
process_payment.apply(
args=[order.id, 99.99],
throw=True,
retries=process_payment.max_retries # Already exhausted
)Testing Task Chains
Chains are one of the most common Celery patterns — one task's output feeds the next:
# tasks.py
from celery import chain
@shared_task
def generate_report(report_id: int) -> dict:
report = Report.objects.get(id=report_id)
data = DataCollector.collect(report)
report.update(raw_data=data, status='collected')
return {'report_id': report_id, 'row_count': len(data)}
@shared_task
def format_report(result: dict) -> dict:
report_id = result['report_id']
report = Report.objects.get(id=report_id)
formatted = Formatter.format(report.raw_data)
report.update(formatted_data=formatted, status='formatted')
return {'report_id': report_id}
@shared_task
def export_report(result: dict) -> str:
report_id = result['report_id']
report = Report.objects.get(id=report_id)
url = S3Exporter.export(report.formatted_data)
report.update(export_url=url, status='exported')
return url
def run_report_pipeline(report_id: int):
return chain(
generate_report.s(report_id),
format_report.s(),
export_report.s()
)()class TestReportChain:
@pytest.fixture(autouse=True)
def eager_mode(self, settings):
from django.test.utils import override_settings
with override_settings(
CELERY_TASK_ALWAYS_EAGER=True,
CELERY_TASK_EAGER_PROPAGATES=True
):
yield
def test_full_pipeline(self, report):
with patch('myapp.tasks.S3Exporter.export') as mock_export:
mock_export.return_value = 'https://s3.example.com/report.pdf'
result = run_report_pipeline(report.id)
report.refresh_from_db()
assert report.status == 'exported'
assert report.export_url == 'https://s3.example.com/report.pdf'
def test_chain_stops_on_first_failure(self, report):
with patch('myapp.tasks.DataCollector.collect') as mock_collect:
mock_collect.side_effect = ValueError("Invalid data source")
with pytest.raises(ValueError):
run_report_pipeline(report.id)
# format_report and export_report should NOT have been called
report.refresh_from_db()
assert report.status == 'pending' # Never progressedTesting Groups
Groups run tasks in parallel:
from celery import group
@shared_task
def process_batch_item(item_id: int) -> dict:
item = BatchItem.objects.get(id=item_id)
result = Processor.process(item)
return {'item_id': item_id, 'success': True, 'result': result}
def process_batch(batch_id: int):
batch = Batch.objects.get(id=batch_id)
return group(
process_batch_item.s(item.id)
for item in batch.items.all()
)()class TestBatchProcessing:
def test_processes_all_items_in_group(self, batch_with_items):
results = process_batch(batch_with_items.id)
# In eager mode, group returns a list of results
completed = results.get()
assert len(completed) == batch_with_items.items.count()
assert all(r['success'] for r in completed)
def test_partial_failure_in_group(self, batch_with_items):
"""Groups collect all results even if some fail"""
def fail_for_item_1(item_id):
if item_id == batch_with_items.items.first().id:
raise ValueError("Item 1 processing failed")
return {'item_id': item_id, 'success': True}
with patch('myapp.tasks.Processor.process', side_effect=fail_for_item_1):
# In eager mode with propagate=False, exceptions are stored in results
results = process_batch(batch_with_items.id)
completed = results.get(propagate=False)
successes = [r for r in completed if not isinstance(r, Exception)]
failures = [r for r in completed if isinstance(r, Exception)]
assert len(failures) == 1
assert len(successes) == batch_with_items.items.count() - 1Testing Task State and Result Backend
class TestTaskStates:
def test_task_result_is_stored(self, order):
result = process_payment.apply(args=[order.id, 99.99])
# Retrieve from result backend
stored = process_payment.AsyncResult(result.id)
assert stored.status == 'SUCCESS'
assert stored.result['order_id'] == order.id
def test_failed_task_stores_exception(self, order):
with patch('myapp.tasks.PaymentGateway.charge') as mock_charge:
mock_charge.side_effect = ValueError("Card invalid")
result = process_payment.apply(args=[order.id, 99.99], throw=False)
assert result.status == 'FAILURE'
assert isinstance(result.result, ValueError)Testing Periodic Tasks (beat)
# tasks.py
from celery.schedules import crontab
@shared_task
def cleanup_expired_sessions():
count = Session.objects.filter(
expires_at__lt=timezone.now()
).delete()[0]
return {'deleted_count': count}# Test the task logic directly — beat scheduling is config, not code to test
class TestCleanupTask:
def test_deletes_expired_sessions(self, db):
# Create mix of expired and valid sessions
Session.objects.create(expires_at=timezone.now() - timedelta(hours=1))
Session.objects.create(expires_at=timezone.now() - timedelta(days=7))
active = Session.objects.create(expires_at=timezone.now() + timedelta(hours=1))
result = cleanup_expired_sessions()
assert result['deleted_count'] == 2
assert Session.objects.filter(id=active.id).exists()
def test_is_registered_in_beat_schedule(self):
from django.conf import settings
schedule = settings.CELERY_BEAT_SCHEDULE
assert 'cleanup-expired-sessions' in schedule
assert schedule['cleanup-expired-sessions']['task'] == 'myapp.tasks.cleanup_expired_sessions'Useful Fixtures for Celery Testing
# conftest.py
import pytest
from unittest.mock import patch
@pytest.fixture
def eager_celery(settings):
"""Force synchronous task execution"""
settings.CELERY_TASK_ALWAYS_EAGER = True
settings.CELERY_TASK_EAGER_PROPAGATES = True
@pytest.fixture
def mock_celery_task():
"""Mock a task to prevent actual execution, capture calls"""
def factory(task_path):
with patch(task_path) as mock:
mock.delay = MagicMock()
mock.apply_async = MagicMock()
yield mock
return factory
@pytest.fixture
def capture_tasks(mocker):
"""Capture all task calls without executing them"""
tasks_sent = []
original_send = celery_app.send_task
def capture(*args, **kwargs):
tasks_sent.append({'args': args, 'kwargs': kwargs})
return MagicMock(id='mock-task-id')
mocker.patch.object(celery_app, 'send_task', side_effect=capture)
return tasks_sentCelery task testing becomes manageable once you accept that task_always_eager is your best friend for most cases. Reserve the complexity — mocking retry mechanisms, testing chains with partial failures — for the jobs where the edge cases actually matter.