from itertools import combinations
import time
import threading
from collections import defaultdict
import math
from multiprocessing import Pool

# 配置参数（优化阈值）
TARGET = 149665  # 目标值
BASE_VALUES = [38.5,44,61,70.5,75.5,93] # 基础系数列表
FLUCTUATION = 1.0  # 系数波动范围
MAX_SOLUTIONS = 3  # 每个组合的最大解数量
SOLVER_TIMEOUT = 180  # 求解超时时间(秒)
THREE_VAR_THRESHOLD = 259000  # 使用三个变量的阈值
PRODUCT_RANGE_THRESHOLD = 129000  # 乘积范围限制阈值
HIGH_TARGET_THRESHOLD = 259000  # 更高目标值阈值
SHOW_PROGRESS = True  # 是否显示进度
MAX_SOLUTIONS_PER_COMB = 100  # 每个组合的最大解数量，用于提前终止
USE_MULTIPROCESSING = True  # 是否使用多进程加速

def is_valid_product(p):
    """确保单个乘积不超过129000"""
    return p <= 129000  # 严格限制所有乘积不超过129000

def find_single_variable_solutions(values):
    """查找单个数的解（a*x = TARGET）"""
    solutions = []
    for a in values:
        quotient = TARGET / a
        if quotient != int(quotient):
            continue
        x = int(quotient)
        if 1 <= x <= 10000 and is_valid_product(a * x):
            solutions.append((a, x))
            if len(solutions) >= MAX_SOLUTIONS:
                break
    return solutions

def find_two_variable_solutions(values):
    """优化的双变量求解算法，确保每个乘积不超过129000"""
    solutions = defaultdict(list)
    for i, a in enumerate(values):
        for b in values[i:]:
            seen_xy = set()
            # 调整x的最大值，确保a*x不超过129000
            max_x = min(math.floor((TARGET - b) / a), math.floor(129000 / a))
            min_x = max(1, math.ceil((TARGET - b * 10000) / a))
            
            if max_x < min_x:
                continue
                
            x_count = max_x - min_x + 1
            x_step = max(1, x_count // 1000)
            
            for x in range(min_x, max_x + 1, x_step):
                ax = a * x
                if ax > 129000:  # 额外检查确保不超过129000
                    continue
                    
                remainder = TARGET - ax
                
                if remainder < b:
                    break
                    
                if remainder > b * 10000:
                    continue
                    
                if remainder % b == 0:
                    y = remainder // b
                    by = b * y
                    if 1 <= y <= 10000 and by <= 129000:  # 确保b*y也不超过129000
                        xy_pair = (x, y) if a <= b else (y, x)
                        if xy_pair not in seen_xy:
                            seen_xy.add(xy_pair)
                            solutions[(a, b)].append((a, x, b, y))
                            if len(solutions[(a, b)]) >= MAX_SOLUTIONS_PER_COMB:
                                break
    return solutions

def process_three_var_combination(args):
    """处理三变量组合的辅助函数，用于并行计算"""
    a, b, c, value_ranges, target = args
    solutions = []
    seen_xyz = set()
    
    min_x, max_x = value_ranges[a]
    x_count = max_x - min_x + 1
    x_step = max(1, x_count // 1000)
    
    for x in range(min_x, max_x + 1, x_step):
        ax = a * x
        if not is_valid_product(ax):
            continue
            
        remainder1 = target - ax
        if remainder1 < 0:
            break
            
        max_y = math.floor((remainder1 - c) / b)
        min_y = max(1, math.ceil((remainder1 - c * 10000) / b))
        
        if max_y < min_y:
            continue
            
        y_count = max_y - min_y + 1
        y_step = max(1, y_count // 100)
        
        for y in range(min_y, max_y + 1, y_step):
            by = b * y
            if not is_valid_product(by):
                continue
                
            remainder2 = remainder1 - by
            if remainder2 < 0:
                break
                
            if remainder2 > c * 10000:
                continue
                
            if remainder2 % c == 0:
                z = remainder2 // c
                if 1 <= z <= 10000 and is_valid_product(c * z):
                    xyz_tuple = tuple(sorted([x, y, z]))
                    if xyz_tuple not in seen_xyz:
                        seen_xyz.add(xyz_tuple)
                        solutions.append((a, x, b, y, c, z))
                        if len(solutions) >= MAX_SOLUTIONS_PER_COMB:
                            return solutions
    
    return solutions

def find_three_variable_solutions(values):
    """优化的三变量求解算法，确保每个乘积不超过129000"""
    solutions = defaultdict(list)
    sorted_values = sorted(values)
    
    # 预计算每个系数的有效范围，确保每个乘积不超过129000
    value_ranges = {}
    for a in sorted_values:
        min_x = max(1, math.ceil(1 / a))  # 最小为1
        max_x = min(10000, math.floor(129000 / a))  # 确保a*x <= 129000
        value_ranges[a] = (min_x, max_x)
    
    combinations_list = []
    for i, a in enumerate(sorted_values):
        for j in range(i + 1, len(sorted_values)):
            b = sorted_values[j]
            for k in range(j + 1, len(sorted_values)):
                c = sorted_values[k]
                combinations_list.append((a, b, c, value_ranges, TARGET))
    
    if USE_MULTIPROCESSING:
        with Pool() as pool:
            results = pool.map(process_three_var_combination, combinations_list)
        
        for i, (a, b, c, _, _) in enumerate(combinations_list):
            if results[i]:
                solutions[(a, b, c)] = results[i]
    else:
        total_combinations = len(combinations_list)
        for i, (a, b, c, _, _) in enumerate(combinations_list):
            res = process_three_var_combination((a, b, c, value_ranges, TARGET))
            if res:
                solutions[(a, b, c)] = res
            
            if SHOW_PROGRESS and i % 10 == 0:
                print(f"\r三变量组合进度: {i}/{total_combinations} 组", end='')
    
    if SHOW_PROGRESS and not USE_MULTIPROCESSING:
        print(f"\r三变量组合进度: {total_combinations}/{total_combinations} 组 - 完成")
    
    return solutions

def find_balanced_solutions(solutions, var_count, num=2):
    """从所有解中筛选出最平衡的解"""
    if var_count == 1 or not solutions:
        return solutions
    
    balanced = []
    for sol in solutions:
        vars = sol[1::2]  # 获取解中的变量值
        diff = max(vars) - min(vars)  # 计算变量之间的最大差值
        balanced.append((diff, sol))
    
    # 按差值排序，返回差值最小的解
    return [s for _, s in sorted(balanced, key=lambda x: x[0])[:num]]

def find_original_solutions(solutions, balanced_solutions, num=3):
    """从剩余解中获取原始顺序的解"""
    if not solutions:
        return []
    
    remaining = [s for s in solutions if s not in balanced_solutions]
    return remaining[:num]

def display_solutions(solutions_dict, var_count):
    """优化的解显示函数"""
    if not solutions_dict:
        return
    
    print(f"\n找到 {len(solutions_dict)} 组{var_count}变量解：")
    
    for i, (coeffs, pair_solutions) in enumerate(sorted(solutions_dict.items()), 1):
        balanced = find_balanced_solutions(pair_solutions, var_count)
        original = find_original_solutions(pair_solutions, balanced)
        all_display = balanced + original
        
        if var_count == 1:
            a = coeffs
            print(f"\n{i}. 组合: a={a} ({len(pair_solutions)} 个有效解)")
        elif var_count == 2:
            a, b = coeffs
            print(f"\n{i}. 组合: a={a}, b={b} ({len(pair_solutions)} 个有效解)")
        else:
            a, b, c = coeffs
            print(f"\n{i}. 组合: a={a}, b={b}, c={c} ({len(pair_solutions)} 个有效解)")
        
        for j, sol in enumerate(all_display, 1):
            tag = "[平衡解]" if j <= len(balanced) else "[原始解]"
            
            if var_count == 1:
                a, x = sol
                print(f"  {j}. x={x}, a*x={a*x:.1f}, 总和={a*x:.1f} {tag}")
            elif var_count == 2:
                a, x, b, y = sol
                print(f"  {j}. x={x}, y={y}, a*x={a*x:.1f}, b*y={b*y:.1f}, 总和={a*x + b*y:.1f} {tag}")
            else:
                a, x, b, y, c, z = sol
                print(f"  {j}. x={x}, y={y}, z={z}, "
                      f"a*x={a*x:.1f}, b*y={b*y:.1f}, c*z={c*z:.1f}, "
                      f"总和={a*x + b*y + c*z:.1f} {tag}")

def run_with_timeout(func, args=(), kwargs=None, timeout=SOLVER_TIMEOUT):
    """运行函数并设置超时限制"""
    if kwargs is None:
        kwargs = {}
    
    result = []
    error = []
    
    def wrapper():
        try:
            result.append(func(*args, **kwargs))
        except Exception as e:
            error.append(e)
    
    thread = threading.Thread(target=wrapper)
    thread.daemon = True
    thread.start()
    thread.join(timeout)
    
    if thread.is_alive():
        print(f"警告: {func.__name__} 超时（{timeout}秒），跳过此方法")
        return None
    
    if error:
        raise error[0]
    
    return result[0]

def main():
    print(f"目标值: {TARGET}")
    
    # 生成波动后的系数
    FLUCTUATED_VALUES = [round(v - FLUCTUATION, 1) for v in BASE_VALUES]
    
    # 尝试基础系数
    print(f"\n==== 尝试基础系数 ====")
    
    # 目标值255966 > 259000不成立，会按顺序尝试单、双、三变量解
    base_solutions = {
        'single': run_with_timeout(find_single_variable_solutions, args=(BASE_VALUES,)),
        'two': run_with_timeout(find_two_variable_solutions, args=(BASE_VALUES,)),
        'three': []
    }
    
    has_solution = False
    
    # 显示单变量解
    if base_solutions['single']:
        has_solution = True
        display_solutions({a: [sol] for a, sol in zip(BASE_VALUES, base_solutions['single']) if sol}, 1)
    
    # 显示双变量解
    if base_solutions['two'] and len(base_solutions['two']) > 0:
        has_solution = True
        display_solutions(base_solutions['two'], 2)
    
    # 单变量和双变量都无解时，尝试三变量解
    if not has_solution:
        print(f"\n==== 单变量和双变量无解，尝试三变量解 ====")
        base_solutions['three'] = run_with_timeout(find_three_variable_solutions, args=(BASE_VALUES,))
        
        if base_solutions['three'] and len(base_solutions['three']) > 0:
            has_solution = True
            display_solutions(base_solutions['three'], 3)
    
    if has_solution:
        print(f"\n使用基础系数列表，共找到有效解")
        return
    
    # 如果基础系数没有找到解，尝试波动系数
    print(f"\n==== 尝试波动系数 ====")
    
    fluctuated_solutions = {
        'single': run_with_timeout(find_single_variable_solutions, args=(FLUCTUATED_VALUES,)),
        'two': run_with_timeout(find_two_variable_solutions, args=(FLUCTUATED_VALUES,)),
        'three': []
    }
    
    has_solution = False
    
    # 显示单变量解
    if fluctuated_solutions['single']:
        has_solution = True
        display_solutions({a: [sol] for a, sol in zip(FLUCTUATED_VALUES, fluctuated_solutions['single']) if sol}, 1)
    
    # 显示双变量解
    if fluctuated_solutions['two'] and len(fluctuated_solutions['two']) > 0:
        has_solution = True
        display_solutions(fluctuated_solutions['two'], 2)
    
    # 单变量和双变量都无解时，尝试三变量解
    if not has_solution:
        print(f"\n==== 单变量和双变量无解，尝试三变量解 ====")
        fluctuated_solutions['three'] = run_with_timeout(find_three_variable_solutions, args=(FLUCTUATED_VALUES,))
        
        if fluctuated_solutions['three'] and len(fluctuated_solutions['three']) > 0:
            has_solution = True
            display_solutions(fluctuated_solutions['three'], 3)
    
    if has_solution:
        print(f"\n使用波动系数列表，共找到有效解")
        return
    
    # 如果所有系数集都没有找到解
    print("\n没有找到符合条件的解，即使使用波动后的系数列表。")

if __name__ == "__main__":
    start_time = time.time()
    main()
    print(f"\n总耗时: {time.time() - start_time:.2f}秒")    