Python单元测试框架与Mock对象使用指南
单元测试是保证代码质量的重要手段。本文将介绍Python单元测试框架和Mock对象的使用。 单元测试是保证代码质量的重要手段。Python单元测试框架与Mock对象使用指南
单元测试核心概念
单元测试核心实现
"""
Python单元测试框架与Mock对象使用
包含unittest、pytest、mock等
"""
import unittest
from unittest.mock import Mock, patch, MagicMock, call
from typing import List, Dict, Optional
import requests
# ============ 被测试的代码 ============
class Calculator:
"""计算器类"""
def add(self, a: float, b: float) -> float:
return a + b
def subtract(self, a: float, b: float) -> float:
return a - b
def multiply(self, a: float, b: float) -> float:
return a * b
def divide(self, a: float, b: float) -> float:
if b == 0:
raise ValueError("除数不能为零")
return a / b
class UserService:
"""用户服务"""
def __init__(self, db_connection):
self.db = db_connection
def get_user(self, user_id: int) -> Optional[Dict]:
return self.db.query(f"SELECT * FROM users WHERE id = {user_id}")
def create_user(self, username: str, email: str) -> Dict:
if not username or not email:
raise ValueError("用户名和邮箱不能为空")
return self.db.insert("users", {"username": username, "email": email})
class PaymentGateway:
"""支付网关"""
def __init__(self, api_key: str):
self.api_key = api_key
self.base_url = "https://api.payment.com"
def charge(self, amount: float, card_number: str) -> Dict:
response = requests.post(
f"{self.base_url}/charge",
headers={"Authorization": f"Bearer {self.api_key}"},
json={"amount": amount, "card": card_number}
)
return response.json()
# ============ unittest测试 ============
class TestCalculator(unittest.TestCase):
"""计算器测试"""
def setUp(self):
"""测试前准备"""
self.calc = Calculator()
def test_add(self):
"""测试加法"""
self.assertEqual(self.calc.add(2, 3), 5)
self.assertEqual(self.calc.add(-1, 1), 0)
self.assertEqual(self.calc.add(0.1, 0.2), 0.3, places=7)
def test_subtract(self):
"""测试减法"""
self.assertEqual(self.calc.subtract(5, 3), 2)
self.assertEqual(self.calc.subtract(0, 5), -5)
def test_multiply(self):
"""测试乘法"""
self.assertEqual(self.calc.multiply(3, 4), 12)
self.assertEqual(self.calc.multiply(-2, 3), -6)
def test_divide(self):
"""测试除法"""
self.assertEqual(self.calc.divide(6, 2), 3)
self.assertEqual(self.calc.divide(5, 2), 2.5)
def test_divide_by_zero(self):
"""测试除零异常"""
with self.assertRaises(ValueError) as context:
self.calc.divide(5, 0)
self.assertEqual(str(context.exception), "除数不能为零")
class TestUserService(unittest.TestCase):
"""用户服务测试"""
def setUp(self):
"""测试前准备"""
# 创建Mock对象
self.mock_db = Mock()
self.user_service = UserService(self.mock_db)
def test_get_user(self):
"""测试获取用户"""
# 设置Mock返回值
expected_user = {"id": 1, "username": "alice"}
self.mock_db.query.return_value = expected_user
# 执行测试
user = self.user_service.get_user(1)
# 验证结果
self.assertEqual(user, expected_user)
self.mock_db.query.assert_called_once_with("SELECT * FROM users WHERE id = 1")
def test_create_user(self):
"""测试创建用户"""
expected_result = {"id": 1, "username": "bob", "email": "bob@example.com"}
self.mock_db.insert.return_value = expected_result
result = self.user_service.create_user("bob", "bob@example.com")
self.assertEqual(result, expected_result)
self.mock_db.insert.assert_called_once_with(
"users",
{"username": "bob", "email": "bob@example.com"}
)
def test_create_user_invalid_input(self):
"""测试创建用户无效输入"""
with self.assertRaises(ValueError):
self.user_service.create_user("", "email@example.com")
class TestPaymentGateway(unittest.TestCase):
"""支付网关测试"""
@patch('requests.post')
def test_charge(self, mock_post):
"""测试支付"""
# 设置Mock响应
mock_response = Mock()
mock_response.json.return_value = {"status": "success", "transaction_id": "12345"}
mock_post.return_value = mock_response
# 执行测试
gateway = PaymentGateway("test_api_key")
result = gateway.charge(100.0, "4111111111111111")
# 验证结果
self.assertEqual(result["status"], "success")
mock_post.assert_called_once()
# ============ Mock高级用法 ============
class TestMockAdvanced(unittest.TestCase):
"""Mock高级用法测试"""
def test_mock_return_value(self):
"""测试Mock返回值"""
mock = Mock()
mock.return_value = 42
result = mock()
self.assertEqual(result, 42)
def test_mock_side_effect(self):
"""测试Mock副作用"""
mock = Mock()
mock.side_effect = [1, 2, 3]
self.assertEqual(mock(), 1)
self.assertEqual(mock(), 2)
self.assertEqual(mock(), 3)
def test_mock_side_effect_exception(self):
"""测试Mock抛出异常"""
mock = Mock()
mock.side_effect = ValueError("错误")
with self.assertRaises(ValueError):
mock()
def test_mock_call_count(self):
"""测试Mock调用次数"""
mock = Mock()
mock()
mock()
mock()
self.assertEqual(mock.call_count, 3)
def test_mock_call_args(self):
"""测试Mock调用参数"""
mock = Mock()
mock(1, 2, key='value')
self.assertEqual(mock.call_args, call(1, 2, key='value'))
def test_mock_multiple_calls(self):
"""测试Mock多次调用"""
mock = Mock()
mock(1)
mock(2)
mock(3)
expected_calls = [call(1), call(2), call(3)]
mock.assert_has_calls(expected_calls)
# ============ 测试夹具 ============
class TestWithFixtures(unittest.TestCase):
"""使用测试夹具"""
@classmethod
def setUpClass(cls):
"""类级别设置"""
print("设置测试类")
cls.shared_resource = "shared"
@classmethod
def tearDownClass(cls):
"""类级别清理"""
print("清理测试类")
def setUp(self):
"""方法级别设置"""
self.test_data = [1, 2, 3]
def tearDown(self):
"""方法级别清理"""
self.test_data = None
def test_with_fixture(self):
"""使用夹具的测试"""
self.assertEqual(len(self.test_data), 3)
self.assertEqual(self.shared_resource, "shared")
def main():
"""主函数"""
print("="*60)
print("Python单元测试框架与Mock对象使用")
print("="*60)
# 运行测试
loader = unittest.TestLoader()
suite = unittest.TestSuite()
# 添加测试
suite.addTests(loader.loadTestsFromTestCase(TestCalculator))
suite.addTests(loader.loadTestsFromTestCase(TestUserService))
suite.addTests(loader.loadTestsFromTestCase(TestPaymentGateway))
suite.addTests(loader.loadTestsFromTestCase(TestMockAdvanced))
suite.addTests(loader.loadTestsFromTestCase(TestWithFixtures))
# 运行
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
print("\n" + "="*60)
print("单元测试总结")
print("="*60)
print("1. unittest: 标准库测试框架")
print("2. Mock: 模拟对象和依赖")
print("3. patch: 临时替换对象")
print("4. Fixture: 测试准备和清理")
print("5. 断言: 验证测试结果")
print("="*60)
if __name__ == "__main__":
main()单元测试架构图
关键要点