Python单元测试框架与Mock对象使用指南

单元测试是保证代码质量的重要手段。本文将介绍Python单元测试框架和Mock对象的使用。

单元测试核心概念

  • unittest:标准库测试框架
  • pytest:第三方测试框架
  • Mock:模拟对象
  • Fixture:测试夹具

单元测试核心实现

"""
Python单元测试框架与Mock对象使用
包含unittest、pytest、mock等
"""

import unittest
from unittest.mock import Mock, patch, MagicMock, call
from typing import List, Dict, Optional
import requests


# ============ 被测试的代码 ============
class Calculator:
    """计算器类"""
    
    def add(self, a: float, b: float) -> float:
        return a + b
    
    def subtract(self, a: float, b: float) -> float:
        return a - b
    
    def multiply(self, a: float, b: float) -> float:
        return a * b
    
    def divide(self, a: float, b: float) -> float:
        if b == 0:
            raise ValueError("除数不能为零")
        return a / b


class UserService:
    """用户服务"""
    
    def __init__(self, db_connection):
        self.db = db_connection
    
    def get_user(self, user_id: int) -> Optional[Dict]:
        return self.db.query(f"SELECT * FROM users WHERE id = {user_id}")
    
    def create_user(self, username: str, email: str) -> Dict:
        if not username or not email:
            raise ValueError("用户名和邮箱不能为空")
        return self.db.insert("users", {"username": username, "email": email})


class PaymentGateway:
    """支付网关"""
    
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.base_url = "https://api.payment.com"
    
    def charge(self, amount: float, card_number: str) -> Dict:
        response = requests.post(
            f"{self.base_url}/charge",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={"amount": amount, "card": card_number}
        )
        return response.json()


# ============ unittest测试 ============
class TestCalculator(unittest.TestCase):
    """计算器测试"""
    
    def setUp(self):
        """测试前准备"""
        self.calc = Calculator()
    
    def test_add(self):
        """测试加法"""
        self.assertEqual(self.calc.add(2, 3), 5)
        self.assertEqual(self.calc.add(-1, 1), 0)
        self.assertEqual(self.calc.add(0.1, 0.2), 0.3, places=7)
    
    def test_subtract(self):
        """测试减法"""
        self.assertEqual(self.calc.subtract(5, 3), 2)
        self.assertEqual(self.calc.subtract(0, 5), -5)
    
    def test_multiply(self):
        """测试乘法"""
        self.assertEqual(self.calc.multiply(3, 4), 12)
        self.assertEqual(self.calc.multiply(-2, 3), -6)
    
    def test_divide(self):
        """测试除法"""
        self.assertEqual(self.calc.divide(6, 2), 3)
        self.assertEqual(self.calc.divide(5, 2), 2.5)
    
    def test_divide_by_zero(self):
        """测试除零异常"""
        with self.assertRaises(ValueError) as context:
            self.calc.divide(5, 0)
        self.assertEqual(str(context.exception), "除数不能为零")


class TestUserService(unittest.TestCase):
    """用户服务测试"""
    
    def setUp(self):
        """测试前准备"""
        # 创建Mock对象
        self.mock_db = Mock()
        self.user_service = UserService(self.mock_db)
    
    def test_get_user(self):
        """测试获取用户"""
        # 设置Mock返回值
        expected_user = {"id": 1, "username": "alice"}
        self.mock_db.query.return_value = expected_user
        
        # 执行测试
        user = self.user_service.get_user(1)
        
        # 验证结果
        self.assertEqual(user, expected_user)
        self.mock_db.query.assert_called_once_with("SELECT * FROM users WHERE id = 1")
    
    def test_create_user(self):
        """测试创建用户"""
        expected_result = {"id": 1, "username": "bob", "email": "bob@example.com"}
        self.mock_db.insert.return_value = expected_result
        
        result = self.user_service.create_user("bob", "bob@example.com")
        
        self.assertEqual(result, expected_result)
        self.mock_db.insert.assert_called_once_with(
            "users",
            {"username": "bob", "email": "bob@example.com"}
        )
    
    def test_create_user_invalid_input(self):
        """测试创建用户无效输入"""
        with self.assertRaises(ValueError):
            self.user_service.create_user("", "email@example.com")


class TestPaymentGateway(unittest.TestCase):
    """支付网关测试"""
    
    @patch('requests.post')
    def test_charge(self, mock_post):
        """测试支付"""
        # 设置Mock响应
        mock_response = Mock()
        mock_response.json.return_value = {"status": "success", "transaction_id": "12345"}
        mock_post.return_value = mock_response
        
        # 执行测试
        gateway = PaymentGateway("test_api_key")
        result = gateway.charge(100.0, "4111111111111111")
        
        # 验证结果
        self.assertEqual(result["status"], "success")
        mock_post.assert_called_once()


# ============ Mock高级用法 ============
class TestMockAdvanced(unittest.TestCase):
    """Mock高级用法测试"""
    
    def test_mock_return_value(self):
        """测试Mock返回值"""
        mock = Mock()
        mock.return_value = 42
        
        result = mock()
        self.assertEqual(result, 42)
    
    def test_mock_side_effect(self):
        """测试Mock副作用"""
        mock = Mock()
        mock.side_effect = [1, 2, 3]
        
        self.assertEqual(mock(), 1)
        self.assertEqual(mock(), 2)
        self.assertEqual(mock(), 3)
    
    def test_mock_side_effect_exception(self):
        """测试Mock抛出异常"""
        mock = Mock()
        mock.side_effect = ValueError("错误")
        
        with self.assertRaises(ValueError):
            mock()
    
    def test_mock_call_count(self):
        """测试Mock调用次数"""
        mock = Mock()
        
        mock()
        mock()
        mock()
        
        self.assertEqual(mock.call_count, 3)
    
    def test_mock_call_args(self):
        """测试Mock调用参数"""
        mock = Mock()
        
        mock(1, 2, key='value')
        
        self.assertEqual(mock.call_args, call(1, 2, key='value'))
    
    def test_mock_multiple_calls(self):
        """测试Mock多次调用"""
        mock = Mock()
        
        mock(1)
        mock(2)
        mock(3)
        
        expected_calls = [call(1), call(2), call(3)]
        mock.assert_has_calls(expected_calls)


# ============ 测试夹具 ============
class TestWithFixtures(unittest.TestCase):
    """使用测试夹具"""
    
    @classmethod
    def setUpClass(cls):
        """类级别设置"""
        print("设置测试类")
        cls.shared_resource = "shared"
    
    @classmethod
    def tearDownClass(cls):
        """类级别清理"""
        print("清理测试类")
    
    def setUp(self):
        """方法级别设置"""
        self.test_data = [1, 2, 3]
    
    def tearDown(self):
        """方法级别清理"""
        self.test_data = None
    
    def test_with_fixture(self):
        """使用夹具的测试"""
        self.assertEqual(len(self.test_data), 3)
        self.assertEqual(self.shared_resource, "shared")


def main():
    """主函数"""
    print("="*60)
    print("Python单元测试框架与Mock对象使用")
    print("="*60)
    
    # 运行测试
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    
    # 添加测试
    suite.addTests(loader.loadTestsFromTestCase(TestCalculator))
    suite.addTests(loader.loadTestsFromTestCase(TestUserService))
    suite.addTests(loader.loadTestsFromTestCase(TestPaymentGateway))
    suite.addTests(loader.loadTestsFromTestCase(TestMockAdvanced))
    suite.addTests(loader.loadTestsFromTestCase(TestWithFixtures))
    
    # 运行
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(suite)
    
    print("\n" + "="*60)
    print("单元测试总结")
    print("="*60)
    print("1. unittest: 标准库测试框架")
    print("2. Mock: 模拟对象和依赖")
    print("3. patch: 临时替换对象")
    print("4. Fixture: 测试准备和清理")
    print("5. 断言: 验证测试结果")
    print("="*60)


if __name__ == "__main__":
    main()

单元测试架构图

flowchart TB
    subgraph TestCode["测试代码"]
        T1[测试用例]
        T2[测试夹具]
        T3[断言]
    end
    
    subgraph MockObjects["Mock对象"]
        M1[Mock]
        M2[MagicMock]
        M3[Patch]
    end
    
    subgraph Production["生产代码"]
        P1[被测函数]
        P2[被测类]
    end
    
    TestCode --> MockObjects
    TestCode --> Production

关键要点

  1. unittest:标准库测试框架
  2. Mock:模拟对象和依赖
  3. patch:临时替换对象
  4. Fixture:测试准备和清理
  5. 断言:验证测试结果

单元测试是保证代码质量的重要手段。

标签: none

添加新评论