本公開涉及人工智能,具體地,涉及一種基于互信息正則化的聯(lián)邦學習后驗推理方法及系統(tǒng)。
背景技術:
1、聯(lián)邦學習是一種分布式機器學習方法,在多個本地節(jié)點上訓練模型,無需將原始數(shù)據集傳輸?shù)街行姆掌?,通常用于隱私任務,例如醫(yī)療保健、金融領域。標準的聯(lián)邦學習的目標為訓練一個泛化性能較好的全局模型,完成訓練后的最優(yōu)全局模型的參數(shù)為固定值,即模型參數(shù)的點估計。然而,從概率論的角度來看,采用模型參數(shù)的點估計作為模型權重,容易在客戶端的稀缺訓練數(shù)據上出現(xiàn)過擬合。更嚴重的是,這種方法未能評估模型的不確定性,進而可能導致模型做出過于自信的決策。在某些安全關鍵的聯(lián)邦學習應用中,如自動駕駛、醫(yī)療健康和金融領域,可靠的不確定性評估尤為重要。
2、由此,提出了對模型權重進行后驗推斷,這在理論上能夠有效防止對稀缺訓練數(shù)據的過擬合,同時提供了一種方式來評估權重估計中的不確定性,并將其傳播到模型的預測中。然而,在聯(lián)邦學習場景下,數(shù)據異構性導致的本地后驗偏差問題仍未得到有效解決。
技術實現(xiàn)思路
1、針對現(xiàn)有技術中的缺陷,本公開的目的是提供一種基于互信息正則化的聯(lián)邦學習后驗推理方法及系統(tǒng)。
2、為實現(xiàn)上述目的,根據本公開的第一方面,提供一種基于互信息正則化的聯(lián)邦學習后驗推理方法,應用于客戶端,所述客戶端包括本地模型,包括:
3、接收服務器端發(fā)送的經過初始化的全局模型或者更新的全局模型;
4、根據所述服務器端發(fā)送的所述經過初始化的全局模型或者所述更新的全局模型,對本地模型進行初始化,確定經過初始化的所述本地模型以及所述本地模型的先驗概率分布;
5、采用本地數(shù)據對所述經過初始化的所述本地模型進行模型訓練,確定所述本地模型的損失函數(shù),所述本地模型的損失函數(shù)包括所述本地數(shù)據的訓練標簽與真實標簽之間的交叉熵損失和所述本地模型的先驗概率分布的正則項;
6、根據所述本地模型的損失函數(shù),基于反向傳播對所述本地模型進行第一更新處理,確定經過第一更新處理的本地模型;
7、根據從預設的正態(tài)分布中采樣的高斯噪聲對所述經過第一更新處理的本地模型進行第二更新處理,確定服從局部后驗分布的本地模型;
8、根據經過末次所述第一更新處理、末次所述第二更新處理后的服從局部后驗分布的本地模型和初始的本地模型,確定本地模型改變量;
9、將所述本地模型改變量發(fā)送至所述服務器端。
10、可選地,所述根據所述本地模型的損失函數(shù),基于反向傳播對所述本地模型進行第一更新處理,確定經過第一更新處理的本地模型,包括:
11、根據所述本地模型的損失函數(shù),基于反向傳播確定所述本地模型參數(shù)的更新梯度;
12、將所述本地模型與所述本地模型參數(shù)的更新梯度和學習率的乘積作差,確定所述經過第一更新處理的本地模型。
13、可選地,所述根據從預設的正態(tài)分布中采樣的高斯噪聲對所述經過第一更新處理的本地模型進行第二更新處理,確定服從局部后驗分布的本地模型,包括:
14、在預設的正態(tài)分布進行噪聲采樣,確定采樣的高斯噪聲;
15、將所述采樣的高斯噪聲與所述經過第一更新處理的本地模型作和,確定所述服從局部后驗分布的本地模型。
16、可選地,所述方法還包括:
17、當所述第一更新處理、所述第二更新處理的更新次數(shù)均達到預設的更新閾值時,所述本地模型停止更新,確定所述經過末次所述第一更新處理、末次所述第二更新處理后的服從局部后驗分布的本地模型。
18、可選地,所述根據經過末次所述第一更新處理、末次所述第二更新處理后的服從局部后驗分布的本地模型和初始的本地模型,確定本地模型改變量,包括:
19、將所述經過末次所述第一更新處理、末次所述第二更新處理后的服從局部后驗分布的本地模型和所述初始的本地模型作差,確定所述本地模型改變量。
20、根據本公開的第二方面,提供一種基于互信息正則化的聯(lián)邦學習后驗推理方法,應用于服務器端,包括:
21、初始化全局模型,確定經過初始化的全局模型;
22、將所述經過初始化的全局模型發(fā)送至客戶端;
23、接收所述客戶端發(fā)送的本地模型改變量;
24、根據所述客戶端發(fā)送的本地模型改變量,確定更新的全局模型;
25、將所述更新的全局模型發(fā)送至所述客戶端;
26、當采樣所述更新的全局模型收斂至服從最優(yōu)全局后驗分布樣本時,確定目標全局模型。
27、可選地,所述根據所述客戶端發(fā)送的本地模型改變量,確定更新的全局模型,包括:
28、對所述客戶端發(fā)送的本地模型改變量進行全局聚合處理,確定聚合改變量;
29、對所述聚合改變量進行滑動平均處理,確定經過滑動平均處理的聚合改變量;
30、將所述經過滑動平均處理的聚合改變量和本輪原始的全局模型進行線性相加,確定更新的全局模型。
31、根據本公開的第三方面,提供一種基于互信息正則化的聯(lián)邦學習后驗推理系統(tǒng),應用于客戶端,所述客戶端包括本地模型,包括:
32、客戶端接收模塊,接收服務器端發(fā)送的經過初始化的全局模型或者更新的全局模型;
33、客戶端初始化模塊,用于根據所述服務器端發(fā)送的所述經過初始化的全局模型或者所述更新的全局模型,對本地模型進行初始化,確定經過初始化的所述本地模型以及所述本地模型的先驗概率分布;
34、客戶端模型訓練模塊,用于采用本地數(shù)據對所述經過初始化的所述本地模型進行模型訓練,確定所述本地模型的損失函數(shù),所述本地模型的損失函數(shù)包括所述本地數(shù)據的訓練標簽與真實標簽之間的交叉熵損失和所述本地模型的先驗概率分布的正則項;
35、客戶端第一更新模塊,用于根據所述本地模型的損失函數(shù),基于反向傳播對所述本地模型進行第一更新處理,確定經過第一更新處理的本地模型;
36、客戶端第二更新模塊,用于根據從預設的正態(tài)分布中采樣的高斯噪聲對所述經過第一更新處理的本地模型進行第二更新處理,確定服從局部后驗分布的本地模型;
37、客戶端模型改變量確定模塊,用于根據經過末次所述第一更新處理、末次所述第二更新處理后的服從局部后驗分布的本地模型和初始的本地模型,確定本地模型改變量;
38、客戶端發(fā)送模塊,用于將所述本地模型改變量發(fā)送至所述服務器端。
39、根據本公開的第四方面,提供一種基于互信息正則化的聯(lián)邦學習后驗推理系統(tǒng),應用于服務器端,包括:
40、服務器端初始化模塊,用于初始化全局模型,確定經過初始化的全局模型;
41、服務器端第一發(fā)送模塊,用于將所述經過初始化的全局模型發(fā)送至客戶端;
42、服務器端接收模塊,用于接收所述客戶端發(fā)送的本地模型改變量;
43、服務器端全局聚合模塊,用于根據所述客戶端發(fā)送的本地模型改變量,確定更新的全局模型;
44、服務器端第二發(fā)送模塊,用于將所述更新的全局模型發(fā)送至所述客戶端;
45、服務器端目標全局模型確定模塊,用于當采樣所述更新的全局模型收斂至服從最優(yōu)全局后驗分布樣本時,確定目標全局模型。
46、根據本公開的第五方面,提供一種基于互信息正則化的聯(lián)邦學習后驗推理系統(tǒng),包括:
47、本地模型更新模塊,用于在本地訓練節(jié)點根據本地模型的損失函數(shù),基于反向傳播對所述本地模型進行第一更新處理,并采用從預設的正態(tài)分布中采樣的高斯噪聲對經過第一更新處理的本地模型進行第二更新處理;
48、模型改變量確定模塊,用于在所述第一更新處理、所述第二更新處理的更新次數(shù)均達到預設的更新閾值時,確定本地模型改變量;
49、全局聚合模塊,用于將客戶端本輪發(fā)送的本地模型改變量進行全局聚合處理,并更新全局模型;
50、目標全局模型確定模塊,用于確定目標全局模型,當采樣所述更新的全局模型收斂至服從最優(yōu)全局后驗分布樣本時,確定為所述目標全局模型。
51、通信模塊,用于將所述客戶端的本地模型改變量傳輸至服務器端,并將所述服務器端的經過初始化的全局模型或者更新的全局模型傳輸至所述客戶端。
52、與現(xiàn)有技術相比,本公開實施例具有如下至少一種有益效果:
53、通過上述技術方案,客戶端接收服務器端發(fā)送的經過初始化的全局模型或者更新的全局模型,并基于經過初始化的全局模型或更新的全局模型,對本地模型進行初始化以及獲取本地模型的先驗概率分布,客戶端采用本地數(shù)據進行模型訓練,以進行局部后驗推斷,在局部后驗推斷過程中,在本地模型的損失函數(shù)中引入全局模型與數(shù)據間互信息正則化項,即本地模型的先驗概率分布的正則項,從而獲得服從局部后驗分布的本地模型,緩解全局后驗推斷中的偏差,進而提高全局模型的泛化能力,在不確定性校準方面實現(xiàn)更優(yōu)的性能。
54、本公開的實施例,基于隨機梯度朗之萬動力學算法進行局部后驗推斷,實現(xiàn)對帶有互信息正則化的局部目標函數(shù)的最優(yōu)后驗的高效采樣。