本發(fā)明涉屬于聯(lián)邦學習,具體是一種融合雙重對比損失的數(shù)據(jù)類別適應(yīng)性平衡聯(lián)邦學習方法。
背景技術(shù):
1、近年來,隨著人工智能技術(shù)的快速發(fā)展,數(shù)據(jù)隱私和安全性問題日益受到關(guān)注,特別是在醫(yī)療領(lǐng)域,隨意收集和共享數(shù)據(jù)已不可行。為解決這一挑戰(zhàn),聯(lián)邦學習作為一種邊緣智能范式,逐漸受到廣泛關(guān)注。聯(lián)邦學習能夠在不共享原始數(shù)據(jù)的情況下,讓各方共同訓練機器學習模型,打破了數(shù)據(jù)孤島的限制,同時確保了隱私保護和數(shù)據(jù)安全。然而,實際應(yīng)用中,不同客戶端的數(shù)據(jù)集通常是非獨立同分布的,這帶來了新的挑戰(zhàn)。
2、為了應(yīng)對數(shù)據(jù)異構(gòu)性的挑戰(zhàn),已經(jīng)有了一系列方法來改善聯(lián)邦學習的性能。一些方法通過優(yōu)化本地訓練方式或正則化客戶端訓練,來接近全局最優(yōu)目標,但這些方法可能帶來隱私泄露或通信和計算開銷的增加。其他方法則從服務(wù)器聚合的角度出發(fā),如feddisco考慮數(shù)據(jù)集差異和大小,fedlc則通過標簽聚類選擇有益的模型進行聚合,但這些方法可能在嚴重數(shù)據(jù)異構(gòu)性下效果有限,或忽略了潛在有用的信息,影響模型的泛化能力。
3、針對現(xiàn)有研究方案的局限性,以及無法從根本解決客戶端之間數(shù)據(jù)非獨立同分布問題,一些研究開始重點關(guān)注融入數(shù)據(jù)預(yù)處理技術(shù),以進一步提升聯(lián)邦學習的性能。例如,fraug提出在低維特征空間運用數(shù)據(jù)增強來提高模型性能。這類針對數(shù)據(jù)在數(shù)據(jù)預(yù)處理階段的優(yōu)化操作,不僅效率高,而且沒有額外的通信開銷問題,但是在數(shù)據(jù)增強時,該類方案并沒有考慮到數(shù)據(jù)增強可能帶來的過擬合問題,因此模型可能會過度適應(yīng)訓練集的噪聲,從而失去泛化能力。
技術(shù)實現(xiàn)思路
1、針對聯(lián)邦學習中客戶端生成數(shù)據(jù)的多樣性和收集方式的不同導致的客戶端數(shù)據(jù)類別不平衡以及由此引發(fā)的模型性能下降的問題,本發(fā)明提供了一種融合雙重對比損失的數(shù)據(jù)類別適應(yīng)性平衡聯(lián)邦學習方法。
2、本發(fā)明提供的一種融合雙重對比損失的類別適應(yīng)性平衡聯(lián)邦學習方法,適用于聯(lián)邦學習框架,包括中央服務(wù)器和中央服務(wù)器和n個客戶端;每個客戶端擁有其獨立的本地原始數(shù)據(jù)集;中央服務(wù)器在每輪聯(lián)邦學習通信開始時,將全局模型發(fā)送給所有客戶端,在客戶端訓練結(jié)束后,接收來自客戶端的本地模型,對所有接收到的客戶端本地模型進行聚合形成新的全局模型;所述聯(lián)邦學習方法包括如下步驟:
3、所述聯(lián)邦學習方法包括如下步驟:
4、步驟1、在第t輪時,客戶端i收到中央服務(wù)器的全局模型后,對本地原始數(shù)據(jù)集進行過采樣操作;
5、步驟2、客戶端i在過采樣后的數(shù)據(jù)集上,結(jié)合分類損失函數(shù)和雙重對比損失函數(shù)得到客戶端本地訓練損失函數(shù),對客戶端的本地模型進行訓練得到更新后的本地模型;
6、步驟3、客戶端將更新后的本地模型發(fā)給中央服務(wù)器,由中央服務(wù)器按客戶端數(shù)據(jù)比例進行聚合形成新的第t+1輪全局模型。
7、進一步的,
8、所述步驟1的具體步驟如下:
9、步驟1.1、客戶端在開始每輪訓練之前執(zhí)行過采樣操作,首先統(tǒng)計本地原始數(shù)據(jù)集中每個類別樣本的數(shù)量,統(tǒng)計公式如下:
10、
11、其中,1(·)為指示函數(shù),表示客戶端i中含有的類別為c的樣本個數(shù),di表示第i個客戶端的本地原始數(shù)據(jù)集,xi表示di中的樣本,yi表示對應(yīng)于xi的標簽。步驟1.2、找到本地原始數(shù)據(jù)集中樣本數(shù)量最多的類別,并根據(jù)該類別的樣本數(shù)來得到一個目標采樣值:
12、
13、其中,r表示目標采樣值,z表示本地原始數(shù)據(jù)集中總的樣本類別數(shù)量,是超參數(shù),用于確定一個合理的標準來衡量本地數(shù)據(jù)集,確保數(shù)據(jù)集的類別分布相對一致且合理。
14、步驟1.3、對于每個類別c,根據(jù)該類別樣本數(shù)與目標值r的大小關(guān)系來決定是否進行過采樣,具體操作如下:
15、若則進行過采樣操作,將過采樣的樣本數(shù)設(shè)為過采樣的子數(shù)據(jù)集表示為:
16、
17、其中,為客戶端i上類別為c的樣本集,為從樣本集中隨機均勻采樣得到的過采樣子集,包含個樣本;
18、步驟1.4、在過采樣之后,將從本地原始數(shù)據(jù)集中隨機采樣得到的過采樣的子數(shù)據(jù)集并到本地原始數(shù)據(jù)集中。
19、進一步的,所述步驟2中,雙重對比損失函數(shù)由類內(nèi)緊湊性損失與有監(jiān)督對比損失這兩個損失函數(shù)組成;
20、所述客戶端本地訓練損失函數(shù)的公式如下:
21、l=lce+λlint+λlcon,
22、其中,λ為控制類內(nèi)緊湊性損失和有監(jiān)督對比損失影響力的超參數(shù);lce表示交叉熵損失,lint類內(nèi)緊湊性損失,lcon表示有監(jiān)督對比損失,l表示客戶端本地訓練損失。
23、進一步的,所述類內(nèi)緊湊性損失共由類間距離項和類中心距離項組成;
24、所述類間距離項的公式為:
25、
26、其中,z(x)=f(θi,x),表示從模型中提取到的樣本特征,樣本x的標簽y=c,θi為客戶端i的本地模型參數(shù);lint,d表示類間距離;z(xi)表示表示從模型中提取到的樣本xi的特征;
27、所述類中心距離項的公式為:
28、
29、
30、其中,lint,c表示類中心距離;表示類中心特征;
31、所述類內(nèi)緊湊性損失函數(shù)公式如下:
32、lint=αlint,d+(1-α)lint,c,
33、其中,lint類內(nèi)緊湊性損失,α表示超參數(shù)。
34、進一步的,所述有監(jiān)督對比損失是對一個batch內(nèi)的樣本特征進行對比,對比操作為減小正樣本對特征之間的距離,增大負樣本對特征之間的距離;正樣本對的距離定義如下:
35、
36、其中,δ為決策邊界,dp表示正樣本對的距離;
37、對于一個batch內(nèi)與樣本x有不同標簽的所有樣本,它們與樣本x形成的負樣本對距離則定義如下:
38、
39、其中,dn表示負樣本對距離;
40、將正樣本對距離與負樣本對距離相加形成有監(jiān)督對比損失,公式如下:
41、lcon=dp+dn,
42、其中,lcon表示有監(jiān)督對比損失。
43、進一步的,所述步驟2中,根據(jù)客戶端本地訓練損失函數(shù),客戶端通過進行如下的梯度下降來更新本地模型:
44、
45、其中,η是模型訓練的學習率,θi表示客戶端i當前的本地模型,表示客戶端本地梯度。
46、與現(xiàn)有技術(shù)相比,本發(fā)明的改進之處在于:
47、1)本發(fā)明在數(shù)據(jù)預(yù)處理階段進行過采樣操作,通過增加少數(shù)類別樣本在訓練中的出現(xiàn)次數(shù),平衡了訓練樣本的類別分布,確保了每個類別在訓練中都能得到充分的關(guān)注和公平對待,從而提高了模型的泛化能力。
48、2)本發(fā)明設(shè)計了類內(nèi)緊湊性損失和有監(jiān)督對比損失兩種對比性質(zhì)的損失函數(shù),保證了同類別間樣本特征的緊湊性,同時增大了不同類別樣本之間的差異。這兩種對比性質(zhì)的損失函數(shù)能夠解決過采樣重復樣本可能帶來的過擬合問題,并促進模型在潛在的特征空間中學習到更有效的特征表示。
49、3)本發(fā)明提出的方法能夠有效地解決客戶端數(shù)據(jù)非獨立同分布時性能下降的問題。與其它先進算法相比,本發(fā)明方法在性能方面展現(xiàn)出更為顯著的優(yōu)勢,且無需額外的通信開銷,同時也不會增加數(shù)據(jù)隱私泄露的風險。