import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.transform import Rotation as R

# Create an 8x8-inch figure holding a single 3D axes, then hide the axis
# decorations so only the drawn surfaces remain visible.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.set_axis_off()

# Configuration of every body part (enhanced version).
#
# Each entry carries a 'type' ('sphere' or 'cylinder') plus the geometry that
# type needs: spheres use 'pos' (centre) and 'size' (radius); cylinders use
# 'start'/'end' (axis end points) and 'radius'. 'color' is a hex colour string.
#
# NOTE(review): the torso spans z=0.5..1.0 while the arms start at z=1.3 and
# the head sphere's lowest point is z=1.5 -- confirm these gaps are intended.
body_config = dict(
    head=dict(
        type='sphere',
        pos=(0, 0, 1.7),
        size=0.2,
        color='#FFDAB9',  # skin tone
    ),
    torso=dict(
        type='cylinder',
        start=(0, 0, 1.0),
        end=(0, 0, 0.5),
        radius=0.3,
        color='#87CEEB',  # sky blue
    ),
    arm_left=dict(
        type='cylinder',
        start=(-0.3, 0, 1.3),
        end=(-0.8, 0, 1.0),
        radius=0.1,
        color='#87CEEB',
    ),
    arm_right=dict(
        type='cylinder',
        start=(0.3, 0, 1.3),
        end=(0.8, 0, 1.0),
        radius=0.1,
        color='#87CEEB',
    ),
    leg_left=dict(
        type='cylinder',
        start=(-0.15, 0, 0.5),
        end=(-0.3, 0, 0.0),
        radius=0.15,
        color='#4682B4',  # dark blue
    ),
    leg_right=dict(
        type='cylinder',
        start=(0.15, 0, 0.5),
        end=(0.3, 0, 0.0),
        radius=0.15,
        color='#4682B4',
    ),
)

class BodyPart:
    """Base class for a drawable body part.

    Stores the part's configuration mapping on ``self.config``; concrete
    subclasses read their geometry and colour from it when drawing.
    """

    def __init__(self, config: dict) -> None:
        """Keep a reference to *config* (it is not copied)."""
        self.config = config

    def draw(self, ax) -> None:
        """Render this part onto *ax*; subclasses must override this method.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

class Sphere(BodyPart):
    """Sphere part (used for the head).

    Reads ``pos`` (centre), ``size`` (radius) and ``color`` from
    ``self.config``.
    """

    def draw(self, ax):
        """Plot the sphere as a 30x30 parametric surface on *ax*."""
        cx, cy, cz = self.config['pos'][0], self.config['pos'][1], self.config['pos'][2]
        radius = self.config['size']

        # Azimuth varies along the rows, polar angle along the columns.
        azimuth = np.linspace(0, 2 * np.pi, 30)[:, np.newaxis]
        polar = np.linspace(0, np.pi, 30)[np.newaxis, :]

        # Unit-sphere coordinates, built by broadcasting column x row.
        unit_x = np.cos(azimuth) * np.sin(polar)
        unit_y = np.sin(azimuth) * np.sin(polar)
        unit_z = np.ones((azimuth.size, 1)) * np.cos(polar)

        ax.plot_surface(
            cx + radius * unit_x,
            cy + radius * unit_y,
            cz + radius * unit_z,
            color=self.config['color'],
            alpha=0.8,
        )

class Cylinder(BodyPart):
    """Cylinder part (used for the limbs and the torso).

    Reads ``start`` and ``end`` (the two end points of the cylinder axis),
    ``radius`` and ``color`` from ``self.config``.
    """

    def draw(self, ax):
        """Plot the open cylinder surface running from ``start`` to ``end``.

        The surface is first generated along the +z axis from the origin,
        then rotated so +z points along ``end - start`` and finally shifted
        to ``start``. A degenerate cylinder (``start == end``) has no axis
        direction and is skipped without drawing anything.
        """
        start = np.asarray(self.config['start'], dtype=float)
        end = np.asarray(self.config['end'], dtype=float)
        vec = end - start
        length = np.linalg.norm(vec)
        if length == 0:
            # No direction to align with; avoids a division by zero below.
            return

        # Base mesh of a cylinder standing on the origin along +z.
        theta = np.linspace(0, 2 * np.pi, 30)
        z = np.linspace(0, 1, 30)
        theta_grid, z_grid = np.meshgrid(theta, z)

        # Cross-section coordinates and height along the axis.
        x_grid = self.config['radius'] * np.cos(theta_grid)
        y_grid = self.config['radius'] * np.sin(theta_grid)
        z_grid = length * z_grid

        # Rotation taking +z onto the cylinder direction. A rotation vector
        # must be a UNIT axis scaled by the angle; the raw cross product has
        # norm sin(angle), so it is normalised before scaling.
        target_dir = vec / length
        initial_dir = np.array([0.0, 0.0, 1.0])
        cross = np.cross(initial_dir, target_dir)
        sin_angle = np.linalg.norm(cross)
        cos_angle = np.dot(initial_dir, target_dir)
        if sin_angle > 1e-12:
            # arctan2 stays well-defined where arccos would hit rounding
            # errors slightly outside [-1, 1].
            rotation_vector = cross / sin_angle * np.arctan2(sin_angle, cos_angle)
        elif cos_angle > 0:
            # Already aligned with +z: identity rotation.
            rotation_vector = np.zeros(3)
        else:
            # Pointing along -z: the cross product vanishes, so pick the
            # x-axis explicitly for the required half turn.
            rotation_vector = np.array([np.pi, 0.0, 0.0])
        rotation = R.from_rotvec(rotation_vector)

        # Apply the rotation, then translate to the start point.
        points = np.vstack([x_grid.ravel(), y_grid.ravel(), z_grid.ravel()])
        rotated_points = rotation.apply(points.T).T
        translated_points = rotated_points + start.reshape(-1, 1)

        # Restore the 2D grid shape expected by plot_surface.
        x = translated_points[0].reshape(x_grid.shape)
        y = translated_points[1].reshape(y_grid.shape)
        z = translated_points[2].reshape(z_grid.shape)

        ax.plot_surface(x, y, z, color=self.config['color'], alpha=0.8)

# Build one drawable object per configured part; entries whose 'type' is
# neither 'sphere' nor 'cylinder' are ignored.
body_parts = [
    (Sphere if cfg['type'] == 'sphere' else Cylinder)(cfg)
    for cfg in body_config.values()
    if cfg['type'] in ('sphere', 'cylinder')
]

# Draw the initial figure.
for part in body_parts:
    part.draw(ax)

# Set the axis limits (and, per the original note, the initial view angle;
# the remainder of this section is not visible in this excerpt).
ax.set_xlim(-1, 1)
ax.set_ylim
