使用混合TabNet架构结合堆叠集成学习进行心脏病预测
《Frontiers in Physiology》:Heart disease prediction using hybrid TabNet architecture with stacked ensemble learning
【字体:
大
中
小
】
时间:2025年11月06日
来源:Frontiers in Physiology 3.4
编辑推荐:
心血管疾病预测的堆叠集成模型研究,提出结合TabNet(专为表格数据设计的深度学习模型)和XGBoost(高效树模型)的元学习框架,通过Logistic回归或支持向量机整合两者输出。实验表明,该模型在Kaggle和UCI数据集上显著优于单一模型,提升准确率11%-4.3%,减少假阴性(如Kaggle数据集假阴性减少73.8%),同时保持高可解释性。
心血管疾病(CVDs)仍然是全球范围内的首要死亡原因,而早期检测对于及时干预和改善患者预后至关重要。然而,当前的预测工具常常受到患者数据中噪声、异构性以及准确度有限的困扰。为了解决这一挑战,我们提出了一种堆叠集成框架,该框架结合了TabNet,一种专为结构化表格数据设计的深度学习模型,以及XGBoost,一种在结构化数据上表现出强大鲁棒性的树基方法。通过将这些模型的输出整合为Logistic Regression(LR)或Support Vector Machine(SVM)作为元学习器,我们构建了一个兼顾准确性和可解释性的系统。在Kaggle和UCI CVD数据集上的测试表明,我们的集成模型在多个性能指标上均优于基线模型,包括准确率、F1-score、精确率、召回率、ROC-AUC、PR-AUC和Matthews相关系数(MCC)。这些结果表明,将深度学习与树基模型相结合,为改进心血管疾病风险预测提供了一种切实可行的方法,有助于临床医生做出更可靠的早期CVD检测决策。
### 1. 引言
心血管疾病是全球范围内对人类健康构成重大威胁的疾病之一。根据世界卫生组织的数据,每年大约有1790万人死于心血管疾病,这一数字凸显了其在公共卫生领域的重要性。近年来,随着机器学习技术的快速发展,许多研究开始探索利用这些算法预测和检测多种疾病,包括偏头痛、抑郁症、阿尔茨海默病以及心血管疾病。深度学习模型在图像和音频分类任务中已经展现出强大的性能,而递归神经网络(RNNs)和基于Transformer的模型在自然语言处理领域取得了显著成果。然而,表格数据在许多现实世界的应用中仍然未得到深度学习方法的充分支持。表格数据通常包含数值和分类特征,是人工智能驱动决策的重要基础。相比之下,基于树的机器学习模型如决策树(DTs)在表格数据的应用中依然占据主导地位。这些模型的优势在于其能够高效地捕捉与表格数据中常见的超平面结构一致的决策边界,并且在计算效率和训练速度上表现优异。然而,像多层感知机(MLPs)这样的深度学习方法往往参数过多,缺乏归纳偏置,从而难以有效捕捉表格数据中的复杂模式。因此,探索适用于表格数据的深度学习方法仍然是一个重要的研究方向。
### 2. 背景与文献综述
本研究的背景建立在心血管疾病预测和集成学习的现有文献之上。文献综述是学术研究中的基础性工作,其目的是系统地评估已有研究,识别当前研究在学术讨论中的位置,并突出其创新性和相关性。通过综述已有文献,我们不仅能够为当前研究提供理论支持,还能指出心血管疾病预测领域中尚未解决的关键挑战。
心血管疾病的主要症状包括胸痛,这种不适感可能表现为压力、紧绷、搏动、沉重或挤压感。除了胸痛,心脏相关疼痛还可能扩散至肩部、手臂、颈部、喉咙、下颌或上背部。值得注意的是,女性在50岁以上比男性更容易患心血管疾病。另一方面,男性往往在较年轻时就受到此类疾病的困扰。其他常见症状包括过度出汗、呼吸困难、胸闷、头晕、疲劳、心跳加快、恶心、肩臂疼痛、胸压感以及在某些情况下出现强烈焦虑或不规则心律。
心血管疾病涵盖了影响心脏多个组成部分的一系列疾病,包括心肌、瓣膜、节律和血管结构。主要疾病包括冠状动脉疾病(CAD)、心力衰竭、心律失常和瓣膜功能障碍。临床表现可能包括胸痛、呼吸困难、疲劳和不规则心律。全球范围内,心血管疾病仍然是首要的死亡原因,若未得到有效管理,将显著影响患者的生活质量。
冠状动脉疾病,也称为缺血性心脏病或心肌梗死,是心血管疾病中最常见和最严重的类型之一。其诊断和治疗管理在低收入和中等收入国家中面临诸多挑战。这些挑战源于缺乏先进的诊断工具以及训练有素的医疗人员的短缺,这导致了准确预测和及时干预的困难。
冠状动脉疾病(CHD)通过阻碍动脉血流,从而降低身体组织的氧气和营养供应。这种障碍通常由动脉粥样硬化引起,即动脉壁上脂质丰富的斑块和钙的积累。世界卫生组织指出,心血管疾病,尤其是心脏病和中风,是全球死亡的主要原因。CHD的风险因素包括年龄、性别、遗传易感性、肥胖、糖尿病、心理压力和不良饮食习惯。
心脏在维持全身循环中发挥着至关重要的作用;血液灌注不足可能导致大脑等重要器官的功能受损,而完全的心脏衰竭最终会导致死亡。广义的心脏病包括影响心脏肌肉和血管网络的任何病理变化。冠状动脉疾病作为心血管疾病的主要亚型,占全球死亡的约20%,主要由于心肌梗死和脑血管事件。
机器学习(ML)作为一种变革性的方法,已被广泛应用于从大规模和复杂数据集中提取可操作的见解。它利用预测算法来预测健康结果,并使用描述性模型来揭示数据中的潜在模式。各种机器学习技术,如MLP、决策树(DT)、K-最近邻(KNN)、支持向量机(SVM)和朴素贝叶斯(NB),在解释大规模医疗数据方面展现了显著的成效。
在一项研究中,Shah等人(2020)引入了一种新颖的基于机器学习的框架,结合了量子神经网络以实现心血管疾病的早期检测。该模型在689名有症状的患者数据上进行训练,并在Framingham数据集上进行验证,显著优于传统的Framingham风险评分(FRS),准确率达到98.57%,而FRS的准确率仅为19.22%。这一显著的提升凸显了该模型在辅助临床实现精确诊断和有效治疗策略方面的潜力。
类似地,Bhatt等人(2023)在Cleveland Heart Disease数据集上应用了多种监督学习算法,该数据集包含303条记录和17个属性。在测试的模型中,KNN达到了最高的预测准确率,为90.8%,这进一步强调了算法选择在优化诊断结果中的重要性。
在另一项研究中,Shah等人(2020)利用机器学习方法如逻辑回归、单变量特征选择和主成分分析(PCA)来评估心血管疾病风险,特别是代谢相关脂肪肝病(MAFLD)患者。高胆固醇水平、动脉斑块积累和糖尿病持续时间被确定为关键预测因素。该模型有效地对高风险(85.11%)和低风险(79.17%)个体进行了分类,取得了0.87的曲线下面积(AUC)分数,展示了基于常规临床数据的强大分类能力。
在一项比较研究中,Hasan和Bao(2021)评估了特征选择技术——过滤、包装和嵌入方法——在增强CVD预测中的效率。他们应用了基于布尔的框架来识别最佳特征子集,并对包括随机森林、SVM、KNN、朴素贝叶斯和XGBoost在内的多个分类器进行了基准测试。XGBoost与包装方法的结合实现了最佳性能,准确率达到73.74%,优于SVM(73.18%)和人工神经网络(ANN)(73.20%)。
然而,现有研究的一个主要挑战是依赖相对较小的数据集,这往往导致过拟合和泛化能力受限。因此,本研究采用了一组包含70,000条患者记录的数据集,包含11个特征,以提高模型的稳健性和减少过拟合。一项关于大规模数据集的心脏病预测模型的全面比较分析详见表1,进一步证明了数据驱动的机器学习方法在临床实践中的有效性。
### 3. 数据集与提出的方法
本节介绍了研究中使用的数据集,即UCI CVD数据集和Kaggle CVD数据集,并概述了提出的方法。后半部分详细描述了心血管疾病(CVD)检测的预处理和模型架构。
#### 3.1 数据集
Kaggle CVD数据集包含70,000条患者记录,每条记录包含16个特征属性和一个二元目标变量,表示是否患有心血管疾病。特征属性包括患者ID、年龄、性别、身高、体重、吸烟状态、舒张压和收缩压、胆固醇水平、血糖水平以及体力活动。该数据集的特征属性和两个样本实例详见表2。
表2列出了心血管疾病数据集的特征属性和两个样本实例。
UCI CVD数据集则来源于UCI机器学习仓库,包含920条实例。其中,561条属于正类(患有CVD),其余为负类(未患有CVD)。该数据集来自四个不同的来源,分别是匈牙利、克利夫兰、长滩VA和瑞士。UCI数据集的特征属性和样本实例详见表3。
表3展示了UCI数据集的特征属性和样本实例。
#### 3.2 提出的方法
在完成必要的预处理步骤后,提出的方法的主架构如图1所示。该模型基于堆叠集成方法,已被证明在分类和预测任务中优于单个学习模型。这主要归因于集成多个基学习器,其中每个模型的弱点都可以被其他模型的优势所补偿,从而提高整体预测准确性和鲁棒性。
堆叠是一种方法,其中n个机器学习算法可以独立训练并逐层堆叠。训练后,每个模型的输出将用于下一层,其中另一个机器学习模型作为元学习器,基于前一层的输出预测最终输出。在本研究中,提出的方法采用了TabNet和XGBoost两个互补的模型作为基学习器。TabNet是一种专为表格数据设计的深度学习模型,而XGBoost则在先前的研究中表现出卓越的性能。基学习器的性能通过标准指标如精确率、召回率、F1-score、准确率和混淆矩阵进行测量。TabNet和XGBoost的输出通过堆叠方法与元学习器集成,以产生最终的CVD预测。我们比较了SVM和LR这两种传统模型作为元学习器,通过各种超参数调优来评估其性能。
#### 3.3 预测集成与输出
预测集成与输出是提出模型的核心组成部分。TabNet和XGBoost的输出通过堆叠集成方法进行集成,以提高预测性能。具体而言,我们评估了逻辑回归(LR)和支持向量机(SVM)作为元学习器的性能。LR作为一种有效的二分类基线模型,因其可解释性和对数据中线性关系建模的能力而被广泛使用。SVM则通过其基于边缘的优化方法捕捉互补的决策边界。通过比较这两种元学习器,我们评估了集成模型的鲁棒性和泛化能力。结果表明,两种模型都表现出色,其中LR在概率指标上的校准略好,而SVM在准确率和F1-score方面表现相似,这凸显了提出堆叠集成框架在利用不同聚合策略方面的灵活性。
尽管提出堆叠集成方法相比单模型方法引入了额外的计算开销,但其复杂性在现代硬件上仍然是可处理的。训练是在配备24GB显存的NVIDIA RTX 4090 GPU和配备64GB内存的Intel i9处理器上进行的。在Kaggle数据集上,TabNet平均需要约2.3小时才能收敛,而在UCI数据集上则需要约1.1小时,使用早停和批处理学习。元学习器(逻辑回归)的训练时间可以忽略不计(少于2分钟)。整体流程的计算复杂度为O(n × (f × d)),其中n是样本数量,f是特征数量,d是集成深度(基学习器数量)。尽管训练时间有所增加,但堆叠框架在预测性能上的显著提升,使得准确率与计算成本之间的权衡在临床应用中是合理的。
### 4. 实验与结果
#### 4.1 超参数调优
表4展示了我们在实验中应用的超参数设置及其搜索范围。为了选择这些参数,我们进行了系统的网格搜索,以在10折分层交叉验证中选择最佳的平均接收者操作特征-面积下曲线(ROC-AUC)。
对于TabNet,除了表4中列出的超参数搜索范围外,我们还应用了多种正则化策略,以确保可重复性和防止过拟合。具体而言,我们对权重应用了L2正则化,系数为1e-5,并在全连接层上应用了0.2的Dropout率。每块后均使用了批量归一化(Batch Normalization)以稳定训练。基于验证损失的早停策略(耐心为20个周期)也被采用,以避免过训练。
为了防止数据泄露并确保无偏评估,我们采用了分层10折交叉验证。对于每折,基学习器(TabNet和XGBoost)在9折上进行训练,预测结果在保留的1折上记录为折外预测。这些预测用于训练逻辑回归元学习器。测试折在训练过程中未被元学习器所接触。我们使用了一个固定的随机种子42以确保结果的可重复性。所有报告的指标(准确率、F1-score、ROC-AUC、Matthews相关系数(MCC))均基于10折交叉验证的平均值及其标准差。
Kaggle数据集相对平衡(50/50),而UCI数据集表现出适度的类别不平衡(约61/39)。为了确保稳健的评估,我们报告了多个性能指标:准确率、F1-score、精确率、召回率、ROC-AUC、精确率-召回率面积下曲线(PR-AUC)和MCC。这些指标如MCC和PR-AUC对类别不平衡影响较小,因此能更可靠地评估模型性能,特别是对于少数类别。
#### 4.2 Kaggle CVD数据集上的性能
如表5所示,提出的模型(TabNet与XGBoost结合SVM作为元学习器)在Kaggle CVD数据集上达到了最高的准确率80.70%和F1-score 77.52%。这一性能显著优于所有单独的基学习器。在单独模型中,TabNet表现最佳,准确率为77.40%,F1-score为76.82%,其次是XGBoost。
传统模型如逻辑回归(LR)和支持向量机(SVM)表现相对较差,准确率分别为71.00%和70.00%,F1-score分别为69.90%和68.39%。这些结果突显了它们在捕捉表格临床数据中复杂模式方面的局限性。
集成策略的有效性在提出的模型中得到了进一步验证,该模型使用LR作为元学习器,也表现出优越的性能(准确率:80.20%,F1-score:78.42%),优于单独学习器。总体而言,这些发现表明,整合不同模型类型可以弥补个体模型的不足,从而增强预测能力。
#### 4.3 UCI CVD数据集上的性能
表6展示了提出的模型在UCI CVD数据集上的结果。总体而言,所有模型在该数据集上的表现优于Kaggle数据集,这表明UCI数据集可能更加结构化或噪声较少。
提出的集成模型(TabNet与XGBoost结合LR作为元学习器)达到了95.20%的准确率和91.92%的F1-score。使用SVM作为元学习器的集成模型也表现出色,准确率为94.30%,F1-score为91.14%。
在单独模型中,TabNet再次表现出最佳性能,准确率为90.90%,F1-score为86.39%,其次是XGBoost和LSTM。集成模型在F1-score上的改进确认了堆叠不仅提高了准确率,还提供了更平衡的精确率与召回率之间的权衡。
#### 4.4 比较分析
在Kaggle数据集上,提出的模型比最佳单独模型(TabNet)的准确率提高了超过11%。在UCI数据集上,准确率的提升约为4.3%,F1-score也有显著提高。集成方法在噪声较大或结构较弱的数据集上显示出更显著的影响,这在Kaggle数据集上尤为明显。这些结果表明,提出的堆叠集成方法在多样性和复杂性数据集上既有效又稳健,利用预训练的深度学习模型和树基机器学习算法的互补优势,从而提高了预测性能。
#### 4.5 元学习器的贡献
我们的实验结果在表7中提供了实证证据,表明提出的堆叠集成方法在两个数据集上均优于基线方法。这一改进可以归因于两个关键因素:首先,逻辑回归学习器以数据驱动的方式为基学习器选择最佳权重,而非简单平均或启发式权重;其次,逻辑回归能够建模TabNet和XGBoost输出之间的相互作用,捕捉无法通过线性平均完全利用的互补决策边界。这些能力在类别不平衡的情况下提高了泛化能力,使得模型在临床应用中更具实用价值。
#### 4.6 堆叠集成的统计显著性
表8展示了提出的堆叠集成方法与基学习器的统计显著性分析。为了验证提出的堆叠集成方法相对于单独基学习器的优越性,我们进行了统计显著性测试,包括使用Bootstrap置信区间和McNemar检验。
Bootstrap置信区间:我们计算了1000个Bootstrap样本的ROC-AUC的95%置信区间。在UCI数据集上,堆叠集成方法的ROC-AUC为0.96 [0.95–0.97],而TabNet的ROC-AUC为0.92 [0.91–0.93]。在Kaggle数据集上,堆叠集成方法的ROC-AUC为0.91 [0.90–0.92],而TabNet的ROC-AUC为0.84 [0.83–0.85]。
McNemar检验:我们将堆叠集成模型的分类输出(正确与错误预测)与最佳基学习器进行了比较。得到的p值为0.018(UCI)和0.022(Kaggle),确认了模型在准确率和F1-score上的显著改进。
这些结果表明,提出的堆叠集成方法在两个数据集上均显著优于单独的基学习器,从而验证了其在临床应用中的价值。
#### 4.7 混淆矩阵分析
为了进一步验证模型的性能,我们对Kaggle CVD数据集进行了混淆矩阵分析。混淆矩阵用于评估模型的行为,分析真实和错误预测的分布。Kaggle CVD数据集是平衡的,包含70,000条患者记录,其中一半属于正类(患有CVD),另一半属于负类(未患有CVD)。
混淆矩阵显示,模型正确识别了27,282名患有CVD的患者(真阳性)和27,705名未患有CVD的患者(真阴性)。然而,7,718名实际患有CVD的患者被错误分类为未患有CVD(假阴性),而7,295名健康的患者被错误预测为患有CVD(假阳性)。
尽管模型在准确率、精确率、召回率和F1-score上表现出色,但假阴性的数量仍然令人担忧,因为未能检测到实际患有CVD的患者可能会影响干预时机,增加患者风险。尽管如此,该模型的性能与准确率的权衡被认为是合理的。
混淆矩阵分析突显了模型在识别真实和错误预测方面的优势,同时也为未来的模型调优和临床决策支持应用提供了有价值的反馈。
#### 4.8 混淆矩阵分析(UCI CVD数据集)
为了进一步评估模型的性能,我们构建了UCI CVD数据集的混淆矩阵。该数据集包含920条患者记录,由四个子集(Cleveland、Hungary、Switzerland和Long Beach VA)汇总而成,被二元化为两个类别:561名被诊断为患有心血管疾病(正类)的患者和359名未患有该疾病的患者(负类)。
模型正确预测了518名正类患者和311名负类患者,其中43名正类患者被错误分类为负类(假阴性),而49名健康患者被错误预测为正类(假阳性)。
这一混淆矩阵确认了模型的强分类能力,其中精确率(91.45%)和召回率(92.40%)均较高。假阴性率较低在临床应用中尤为重要,因为这表明模型在识别真正患有心血管疾病患者方面更少出现遗漏。整体F1-score为91.92%,准确率为95.20%,进一步证明了模型的有效性。
混淆矩阵揭示了模型的强判别能力,尽管Kaggle数据集的假阴性数量较高(7,718例),可能由于其数据的异构性和类别不平衡。在临床实践中,假阴性是至关重要的,因为它们代表了未被检测到的高风险患者。为了缓解这一问题,可以调整决策阈值,以优先考虑较高的召回率(灵敏度),从而确保潜在的CVD病例不会被遗漏。在筛查场景中,这种权衡是可以接受的,因为假阳性比漏诊的后果要小,使模型更适合于临床早期风险检测。
#### 4.9 局限性
尽管我们的研究展示了强大的预测性能,但仍需承认一些局限性。首先,评估仅在相对较小的基准数据集(UCI和Kaggle)上进行,这可能限制了结果的稳健性。因此,我们的框架的性能可能具有数据集特定性,需要在更大的、多机构的数据集上进一步验证。其次,所使用的数据集不能代表现实世界中的电子健康记录(EHRs),这些记录通常包含更嘈杂、不完整和异构的数据。因此,模型在常规临床环境中的实际适用性仍不确定。最后,由于两个数据集缺乏对多样族群、年龄组和合并人群的充分代表性,我们的发现对更广泛的患者群体的泛化性尚未经过测试。这些局限性突显了未来工作需要在现实世界、人口学多样化的EHR数据上进行外部验证,以更好地建立临床效用和公平性。
### 5. 结论与未来研究方向
本研究提出了一种结合TabNet和XGBoost,并以逻辑回归(LR)或支持向量机(SVM)作为元学习器的混合堆叠集成框架,用于心血管疾病(CVD)预测。与依赖于图像或序列数据的现有混合模型不同,我们的方法直接利用了专为表格数据设计的TabNet,并将其与XGBoost结合,以在深度表示学习和结构化特征鲁棒性之间取得平衡。LR元学习器进一步确保了预测的稳定集成,从而弥补了个体模型的弱点。
在Kaggle和UCI CVD数据集上的全面实验表明,提出的混合模型在多个指标上均优于传统机器学习和深度学习基线模型,包括准确率、F1-score、ROC-AUC、PR-AUC和MCC。更重要的是,该模型减少了假阴性,这是临床中至关重要的改进,因为漏诊可能会延迟干预,增加患者风险。这直接解决了以往TabNet或集成模型在大规模、异构CVD数据集上的系统验证不足的问题。
本研究的发现具有重要的意义。首先,它们提供了实证证据,表明在集成框架中,受到Transformer启发的模型可以实现对表格医疗数据的最先进性能。其次,结果突显了可解释性管道的价值,因为TabNet的稀疏注意力机制通过特征级透明度支持临床医生的信任。第三,该方法通过更早、更可靠地识别高风险患者,为临床决策支持提供了实用价值。
未来的研究可以进一步增强可解释性,例如通过可视化TabNet的注意力掩码或SHAP值,以提高临床可用性。引入多模态数据,如影像、电子健康记录和基因图谱,可能扩展预测能力,并提供更全面的患者健康视图。此外,联邦学习可能实现跨机构的隐私保护部署,从而提高模型在多样化人群中的泛化能力。
总之,本研究通过展示TabNet–XGBoost堆叠集成方法在心血管疾病预测中的稳健性、可解释性和临床相关性,填补了方法学和临床领域的空白。这不仅推动了医疗人工智能的发展,也为将集成学习可靠地整合到现实世界的心血管风险评估中铺平了道路。
生物通微信公众号
生物通新浪微博
今日动态 |
人才市场 |
新技术专栏 |
中国科学人 |
云展台 |
BioHot |
云讲堂直播 |
会展中心 |
特价专栏 |
技术快讯 |
免费试用
版权所有 生物通
Copyright© eBiotrade.com, All Rights Reserved
联系信箱:
粤ICP备09063491号