本公開(kāi)總體上涉及機(jī)器學(xué)習(xí)。更具體地,本公開(kāi)涉及用于校準(zhǔn)從教師模型到學(xué)生模型的蒸餾學(xué)習(xí)的技術(shù)。
背景技術(shù):
1、在機(jī)器學(xué)習(xí)中,知識(shí)蒸餾通??芍复鷮⒅R(shí)(例如,經(jīng)由蒸餾訓(xùn)練)從教師模型轉(zhuǎn)移到學(xué)生模型的過(guò)程。通常,盡管不一定,但是教師模型會(huì)比學(xué)生模型更大(例如,就參數(shù)的數(shù)量而言)。具體地,雖然大型模型(諸如非常深的神經(jīng)網(wǎng)絡(luò)或許多模型的集成)比小型模型具有更高的知識(shí)容量,但是這種容量可能不會(huì)在所有情況下都得到充分利用或都是需要的。例如,由于較小模型評(píng)估成本較低,因此可將它們部署在功能較弱的硬件(諸如移動(dòng)裝置)上。更一般地說(shuō),學(xué)生模型可被設(shè)計(jì)成更簡(jiǎn)單,訓(xùn)練更快,和/或根據(jù)部署(例如,系統(tǒng)約束)限制進(jìn)行部署。教師模型不必遵守此類(lèi)限制,并且可花費(fèi)更多時(shí)間進(jìn)行訓(xùn)練。因此,從教師模型到學(xué)生模型的知識(shí)蒸餾可帶來(lái)益處的情況有很多種。
技術(shù)實(shí)現(xiàn)思路
1、本公開(kāi)的實(shí)施例的各方面和優(yōu)點(diǎn)將在以下描述中部分地闡述,或者可從描述中學(xué)習(xí),或者可通過(guò)實(shí)施例的實(shí)踐來(lái)學(xué)習(xí)。
2、本公開(kāi)的一個(gè)示例方面涉及一種用于以改進(jìn)的計(jì)算效率執(zhí)行蒸餾訓(xùn)練的計(jì)算系統(tǒng)。該計(jì)算系統(tǒng)包括:一個(gè)或多個(gè)處理器;教師模型,該教師模型包括教師模型主體、教師logit頭和教師預(yù)測(cè)頭,其中教師模型主體被配置為處理輸入以生成教師中間表示,其中教師logit頭被配置為處理教師中間表示以生成教師logit值,并且其中教師預(yù)測(cè)頭被配置為處理教師logit值以生成教師概率值;學(xué)生模型,該學(xué)生模型包括學(xué)生模型主體、第一學(xué)生logit頭、第二學(xué)生logit頭和學(xué)生預(yù)測(cè)頭,其中學(xué)生模型主體被配置為處理輸入以生成學(xué)生中間表示,其中第一學(xué)生logit頭被配置為處理學(xué)生中間表示以生成第一學(xué)生logit值,其中第二學(xué)生logit頭被配置為處理學(xué)生中間表示以生成第二學(xué)生logit值,并且其中學(xué)生預(yù)測(cè)頭被配置為處理第一學(xué)生logit值和第二學(xué)生logit值以生成學(xué)生概率值;以及一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì),該一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)共同存儲(chǔ)指令,這些指令在由一個(gè)或多個(gè)處理器執(zhí)行時(shí)使計(jì)算系統(tǒng)執(zhí)行操作。這些操作包括:基于教師logit值和第一學(xué)生logit值來(lái)評(píng)估第一損失函數(shù);基于第一損失函數(shù)來(lái)修改至少第一學(xué)生logit頭的一個(gè)或多個(gè)參數(shù);基于教師概率值和學(xué)生概率值來(lái)評(píng)估第二不同的損失函數(shù);以及基于第二損失函數(shù)來(lái)修改至少第二學(xué)生logit頭的一個(gè)或多個(gè)參數(shù)。
3、本公開(kāi)的另一個(gè)示例方面涉及一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì),該一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)共同存儲(chǔ):機(jī)器學(xué)習(xí)學(xué)生模型,其中:機(jī)器學(xué)習(xí)學(xué)生模型包括學(xué)生模型主體、第一學(xué)生logit頭、第二學(xué)生logit頭和學(xué)生預(yù)測(cè)頭,學(xué)生模型主體被配置為處理輸入以生成學(xué)生中間表示,第一學(xué)生logit頭被配置為處理學(xué)生中間表示以生成第一學(xué)生logit值,第二學(xué)生logit頭被配置為處理學(xué)生中間表示以生成第二學(xué)生logit值,學(xué)生預(yù)測(cè)頭被配置為處理第一學(xué)生logit值和第二學(xué)生logit值以生成學(xué)生概率值,第一學(xué)生logit頭已使用第一損失函數(shù)進(jìn)行了訓(xùn)練,第一損失函數(shù)評(píng)估第一學(xué)生logit值和由教師模型生成的教師logit值,并且第二學(xué)生logit頭已使用第二損失函數(shù)進(jìn)行了訓(xùn)練,第二損失函數(shù)評(píng)估學(xué)生概率值和由教師模型生成的教師概率值;以及用于運(yùn)行機(jī)器學(xué)習(xí)學(xué)生模型以處理輸入以生成學(xué)生概率值的指令。
4、本公開(kāi)的另一個(gè)示例方面涉及一種用于以改進(jìn)的計(jì)算效率執(zhí)行蒸餾訓(xùn)練的計(jì)算系統(tǒng),該計(jì)算系統(tǒng)包括:一個(gè)或多個(gè)處理器;教師模型,該教師模型包括教師模型主體、教師logit頭和教師預(yù)測(cè)頭,其中教師模型主體被配置為處理輸入以生成教師中間表示,其中教師logit頭被配置為處理教師中間表示以生成教師logit值,并且其中教師預(yù)測(cè)頭被配置為處理教師logit值以生成教師概率值;多個(gè)學(xué)生模型,其中每個(gè)學(xué)生模型包括學(xué)生模型主體、第一學(xué)生logit頭和第二學(xué)生logit頭,其中學(xué)生模型主體被配置為處理輸入以生成學(xué)生中間表示,其中第一學(xué)生logit頭被配置為處理學(xué)生中間表示以生成第一學(xué)生logit值,其中第二學(xué)生logit頭被配置為處理學(xué)生中間表示以生成第二學(xué)生logit值;學(xué)生集成預(yù)測(cè)頭,該學(xué)生集成預(yù)測(cè)頭被配置為根據(jù)來(lái)自多個(gè)學(xué)生模型的多個(gè)第一學(xué)生logit值和多個(gè)第二學(xué)生logit值生成學(xué)生概率值;以及一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì),該一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)共同存儲(chǔ)指令,這些指令在由一個(gè)或多個(gè)處理器執(zhí)行時(shí)使計(jì)算系統(tǒng)執(zhí)行操作。這些操作包括對(duì)于多個(gè)學(xué)生模型中的每個(gè)學(xué)生模型:基于教師logit值和第一學(xué)生logit值來(lái)評(píng)估第一損失函數(shù);基于第一損失函數(shù)來(lái)修改至少第一學(xué)生logit頭的一個(gè)或多個(gè)參數(shù);基于教師概率值和學(xué)生概率值來(lái)評(píng)估第二不同的損失函數(shù);以及基于第二損失函數(shù)來(lái)修改每個(gè)學(xué)生模型的第二學(xué)生logit頭的一個(gè)或多個(gè)參數(shù)。
5、本公開(kāi)的另一個(gè)示例方面涉及一種用于以改進(jìn)的計(jì)算效率執(zhí)行蒸餾訓(xùn)練的計(jì)算系統(tǒng)。該計(jì)算系統(tǒng)包括:一個(gè)或多個(gè)處理器;教師模型,該教師模型包括教師模型主體、第一教師評(píng)分頭和第二教師評(píng)分頭,其中教師模型主體被配置為處理輸入以生成教師中間表示,其中第一教師評(píng)分頭被配置為處理教師中間表示以在第一評(píng)分域中生成第一教師評(píng)分值,并且其中第二教師評(píng)分頭被配置為處理第一教師評(píng)分值以在第二評(píng)分域中生成第二教師評(píng)分值,其中第二評(píng)分域?qū)?yīng)于教師模型的目標(biāo);學(xué)生模型,該學(xué)生模型包括學(xué)生模型主體、第一學(xué)生評(píng)分頭、第二學(xué)生評(píng)分頭和第三學(xué)生評(píng)分頭,其中學(xué)生模型主體被配置為處理輸入以生成學(xué)生中間表示,其中第一學(xué)生評(píng)分頭被配置為處理學(xué)生中間表示以在第一評(píng)分域中生成第一學(xué)生評(píng)分值,其中第二學(xué)生評(píng)分頭被配置為處理學(xué)生中間表示以在第一評(píng)分域中生成第二學(xué)生評(píng)分值,并且其中第三學(xué)生評(píng)分頭被配置為處理第一學(xué)生評(píng)分值和第二學(xué)生評(píng)分值以在第二評(píng)分域中生成第三學(xué)生評(píng)分值;以及一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì),該一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)共同存儲(chǔ)指令,這些指令在由一個(gè)或多個(gè)處理器執(zhí)行時(shí)使計(jì)算系統(tǒng)執(zhí)行操作。這些操作包括基于第一教師評(píng)分值和第一學(xué)生評(píng)分值來(lái)評(píng)估第一損失函數(shù);基于第一損失函數(shù)來(lái)修改至少第一學(xué)生評(píng)分頭的一個(gè)或多個(gè)參數(shù);基于第二教師評(píng)分值和第三學(xué)生評(píng)分值來(lái)評(píng)估第二不同的損失函數(shù);以及基于第二損失函數(shù)來(lái)修改至少第二學(xué)生評(píng)分頭的一個(gè)或多個(gè)參數(shù)。
6、本公開(kāi)的其他方面涉及各種系統(tǒng)、設(shè)備、非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)、用戶(hù)界面和電子裝置。
7、將參考以下描述和隨附權(quán)利要求更好地理解本公開(kāi)的各種實(shí)施例的這些和其他特征、方面和優(yōu)點(diǎn)。并入本說(shuō)明書(shū)中并構(gòu)成本說(shuō)明書(shū)的一部分的附圖示出了本公開(kāi)的示例實(shí)施例,并且連同描述一起用于解釋相關(guān)原理。
1.一種用于以改進(jìn)的計(jì)算效率執(zhí)行蒸餾訓(xùn)練的計(jì)算系統(tǒng),所述計(jì)算系統(tǒng)包括:
2.如權(quán)利要求1所述的計(jì)算系統(tǒng),其中所述第一損失函數(shù)包括平方損失、huber損失、平滑分位數(shù)損失、分位數(shù)回歸損失或平滑損失中的一者。
3.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中所述第一損失函數(shù)包括lp損失函數(shù)。
4.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中所述第一損失函數(shù)比所述第二損失函數(shù)收斂得更快。
5.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中所述第二損失函數(shù)包括適當(dāng)?shù)脑u(píng)分規(guī)則,所述適當(dāng)?shù)脑u(píng)分規(guī)則在由所述教師產(chǎn)生的預(yù)測(cè)分布的適當(dāng)域中的期望統(tǒng)計(jì)數(shù)據(jù)點(diǎn)處最小化。
6.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中所述第一損失函數(shù)和所述第二損失函數(shù)中的一者或兩者是以下項(xiàng)中的一者或兩者:圍繞收斂最優(yōu)值對(duì)稱(chēng)凸或強(qiáng)凸。
7.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中所述第二損失函數(shù)包括交叉熵?fù)p失函數(shù),所述交叉熵?fù)p失函數(shù)在關(guān)于預(yù)測(cè)分布的預(yù)測(cè)平均概率處給出最小值。
8.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),還包括:
9.如權(quán)利要求1至7中任一項(xiàng)所述的計(jì)算系統(tǒng),還包括:
10.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中:
11.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中所述學(xué)生預(yù)測(cè)頭包括邏輯函數(shù)并且所述學(xué)生概率值包括邏輯回歸輸出。
12.如任一項(xiàng)前述權(quán)利要求所述的計(jì)算系統(tǒng),其中所述教師概率值存儲(chǔ)在非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)中并從所述非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)訪問(wèn)以用于訓(xùn)練所述學(xué)生模型。
13.一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì),所述一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì)共同存儲(chǔ):
14.如權(quán)利要求13所述的一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì),其中:
15.如權(quán)利要求13或14所述的一個(gè)或多個(gè)非暫時(shí)性計(jì)算機(jī)可讀介質(zhì),其中所述第一損失函數(shù)包括平方損失函數(shù),并且所述第二損失函數(shù)包括交叉熵?fù)p失函數(shù)。
16.一種用于以改進(jìn)的計(jì)算效率執(zhí)行蒸餾訓(xùn)練的計(jì)算系統(tǒng),所述計(jì)算系統(tǒng)包括:
17.如權(quán)利要求16所述的計(jì)算系統(tǒng),其中所述第一損失函數(shù)包括平方損失、huber損失、平滑分位數(shù)損失、分位數(shù)回歸損失或平滑損失中的一者。
18.如權(quán)利要求16或17所述的計(jì)算系統(tǒng),其中所述第一損失函數(shù)包括lp損失函數(shù)。
19.如權(quán)利要求16、17或18所述的計(jì)算系統(tǒng),其中所述第一損失函數(shù)比所述第二損失函數(shù)收斂得更快。
20.如權(quán)利要求16、17、18或19所述的計(jì)算系統(tǒng),其中所述第二損失函數(shù)收斂到關(guān)于在對(duì)所述學(xué)生模型表現(xiàn)為相同的示例上的教師預(yù)測(cè)的分布給出最小損失的點(diǎn)。
21.一種用于以改進(jìn)的計(jì)算效率執(zhí)行蒸餾訓(xùn)練的計(jì)算系統(tǒng),所述計(jì)算系統(tǒng)包括:
22.如權(quán)利要求21所述的計(jì)算系統(tǒng),其中所述第一評(píng)分域包括logit域,并且其中所述第二評(píng)分域包括概率域。