[On Biaffine for Nested NER] Things You Didn't Know
Author: 楊夕
Project: https://github.com/km1994/nlp_paper_study
Paper: https://www.aclweb.org/anthology/2020.acl-main.577/
Code: https://github.com/juntaoy/biaffine-ner
Code (Chinese re-implementation): https://github.com/suolyer/PyTorch_BERT_Biaffine_NER
About me: Hi everyone, I'm 楊夕. This project collects what I have seen, thought, and learned while reading top-conference papers and reproducing classic ones. It may contain misunderstandings, and corrections are very welcome.
[On Biaffine for Nested NER] Things You Didn't Know
Abstract
1. Data Processing
1.1 Raw data format
1.2 Data preprocessing: data_pre()
1.2.1 Main preprocessing function
1.2.2 Loading the training data: load_data(file_path)
1.2.3 Data encoding: encoder(sentence, argument)
1.3 Wrapping the data in a MyDataset object
1.4 Building the data iterator
1.5 Final data format
2. Model
2.1 Overall architecture
2.2 embedding layer
2.3 BiLSTM
2.4 FFNN
2.5 biaffine model
2.6 Conflict resolution
2.7 Loss function
3. Learning-rate decay
4. Loss functions
4.1 span_loss
4.2 focal_loss
5. Training
References
Abstract
Motivation: NER research has focused on flat NER and largely ignored the problem of nested entities.
Method: The paper borrows ideas from graph-based dependency parsing to give the model a global view of the input via a biaffine model. The biaffine model scores pairs of start and end tokens in a sentence, and these scores are used to explore all spans, so that the model can predict named entities accurately.
Approach: NER is reformulated as the task of identifying start and end indices and assigning a category to the span defined by each pair. The system uses a biaffine model on top of a multi-layer BiLSTM to assign a score to every possible span in a sentence. Then, instead of building a dependency tree, the candidate spans are ranked by their scores, and the top-ranked spans that comply with the constraints of flat or nested NER are returned.
Results: The system is evaluated on three nested NER benchmarks (ACE 2004, ACE 2005, GENIA) and five flat NER corpora (CoNLL 2002 (Dutch, Spanish), CoNLL 2003 (English, German), and OntoNotes). It achieves SoTA results on all three nested NER corpora and all five flat NER corpora, with gains of up to 2.2 absolute percentage points over the previous SoTA.
1. Data Processing
1.1 Raw data format
The raw data looks like this:
{
"text": "當(dāng)希望工程救助的百萬兒童成長起來,科教興國蔚然成風(fēng)時(shí),今天有收藏價(jià)值的書你沒買,明日就叫你悔不當(dāng)初!",
"entity_list": []
}
{
"text": "藏書本來就是所有傳統(tǒng)收藏門類中的第一大戶,只是我們結(jié)束溫飽的時(shí)間太短而已。",
"entity_list": []
}
{
"text": "因有關(guān)日寇在京掠奪文物詳情,藏界較為重視,也是我們收藏北京史料中的要件之一。",
"entity_list":
[
{"type": "ns", "argument": "北京"}
]
}
...
1.2 Data preprocessing: data_pre()
1.2.1 Main preprocessing function
Steps:
load the data;
encode each sample into the training format.
Code:
def data_pre(file_path):
sentences, arguments = load_data(file_path)
data = []
for i in tqdm(range(len(sentences))):
encode_sent, token_type_ids, attention_mask, span_label, span_mask = encoder(
sentences[i], arguments[i])
tmp = {}
tmp['input_ids'] = encode_sent
tmp['input_seg'] = token_type_ids
tmp['input_mask'] = attention_mask
tmp['span_label'] = span_label
tmp['span_mask'] = span_mask
data.append(tmp)
return data
Output:
data[0:2]:
[
{
'input_ids': [
101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_seg': [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_mask': [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'span_label': array(
[
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]
]
),
'span_mask': [
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], ...
]
}, ...
]
1.2.2 Loading the training data: load_data(file_path)
Code:
def load_data(file_path):
with open(file_path, 'r', encoding='utf8') as f:
lines = f.readlines()
sentences = []
arguments = []
for line in lines:
data = json.loads(line)
text = data['text']
entity_list = data['entity_list']
args_dict={}
if entity_list != []:
for entity in entity_list:
entity_type = entity['type']
entity_argument=entity['argument']
args_dict[entity_type] = entity_argument
sentences.append(text)
arguments.append(args_dict)
return sentences, arguments
Output (note: args_dict is keyed by entity type, so if a sentence contains several entities of the same type, only the last one is kept):
print(f"sentences[0:2]:{sentences[0:2]}")
print(f"arguments[0:2]:{arguments[0:2]}")
>>>
sentences[0:2]:['因有關(guān)日寇在京掠奪文物詳情,藏界較為重視,也是我們收藏北京史料中的要件之一。', '我們藏有一冊(cè)1945年 6月油印的《北京文物保存保管狀態(tài)之調(diào)查報(bào)告》,調(diào)查范圍涉及故宮、歷博、古研所、北大清華圖書館、北圖、日偽資料庫等二十幾家,言及文物二十萬件以上,洋洋三萬余言,是珍貴的北京史料。']
arguments[0:2]:[{'ns': '北京'}, {'ns': '北京', 'nt': '古研所'}]
1.2.3 Data encoding: encoder(sentence, argument)
Code:
# step 1: get the BERT tokenizer
tokenizer = tools.get_tokenizer()
# step 2: get the label <-> id mapping
label2id, id2label, num_labels = tools.load_schema()

def encoder(sentence, argument):
    # step 3: encode the sentence with the tokenizer
    encode_dict = tokenizer.encode_plus(
        sentence,
        max_length=args.max_length,
        pad_to_max_length=True,
        truncation=True
    )
    encode_sent = encode_dict['input_ids']
    token_type_ids = encode_dict['token_type_ids']
    attention_mask = encode_dict['attention_mask']
    # step 4: build span_mask (rows for real tokens copy attention_mask, padding rows are all zeros)
    zero = [0 for i in range(args.max_length)]
    span_mask = [attention_mask for i in range(sum(attention_mask))]
    span_mask.extend([zero for i in range(sum(attention_mask), args.max_length)])
    # step 5: build span_label (row = start position, column = end position, value = shifted label id)
    span_label = [0 for i in range(args.max_length)]
    span_label = [span_label for i in range(args.max_length)]
    span_label = np.array(span_label)
    for entity_type, arg in argument.items():
        encode_arg = tokenizer.encode(arg)
        start_idx = tools.search(encode_arg[1:-1], encode_sent)
        end_idx = start_idx + len(encode_arg[1:-1]) - 1
        span_label[start_idx, end_idx] = label2id[entity_type] + 1
    return encode_sent, token_type_ids, attention_mask, span_label, span_mask
Steps:
get the BERT tokenizer;
get the label-to-id mapping;
encode_plus returns all of the encoding information
encode_dict:
{
'input_ids': [101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
}
Notes:
'input_ids': the ids of the tokens in the vocabulary
'token_type_ids': segment ids that distinguish the two sentences of a pair
'attention_mask': specifies which positions take part in self-attention (1 for real tokens, 0 for padding)
span_mask construction: an args.max_length × args.max_length mask in which row i copies attention_mask when position i is a real token and is all zeros otherwise, so entry (i, j) is 1 only when both the start position i and the end position j are real tokens (a toy example follows the span_label example below);
span_label construction
Description: this builds an args.max_length × args.max_length matrix that locates each span in the sentence by its start and end positions: the row index is the span's start position, the column index is its end position, and the value is the id of the span's entity type.
A toy example:
>>>
import numpy as np
span_label = [0 for i in range(10)]
span_label = [span_label for i in range(10)]
span_label = np.array(span_label)
start = [1, 3, 7]
end = [2, 9, 9]
label2id = [1, 2, 4]
for i in range(len(label2id)):
span_label[start[i], end[i]] = label2id[i]
>>>
array( [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
> Note: the row index is the start, the column index is the end, and the value is the label id.
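By analogy, here is a toy example of the span_mask construction from step 4 of encoder() (a minimal sketch assuming max_length = 6 and a sentence of 4 real tokens; not part of the repo code):
import numpy as np

max_length = 6                       # hypothetical args.max_length
attention_mask = [1, 1, 1, 1, 0, 0]  # 4 real tokens, 2 padding positions

zero = [0] * max_length
# rows for real-token start positions copy attention_mask ...
span_mask = [attention_mask for _ in range(sum(attention_mask))]
# ... rows for padding start positions are all zeros
span_mask.extend(zero for _ in range(sum(attention_mask), max_length))

print(np.array(span_mask))
# [[1 1 1 1 0 0]
#  [1 1 1 1 0 0]
#  [1 1 1 1 0 0]
#  [1 1 1 1 0 0]
#  [0 0 0 0 0 0]
#  [0 0 0 0 0 0]]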
1.3 Wrapping the data in a MyDataset object
Convert each field of the data to torch.tensor:
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
item = self.data[index]
one_data = {
"input_ids": torch.tensor(item['input_ids']).long(),
"input_seg": torch.tensor(item['input_seg']).long(),
"input_mask": torch.tensor(item['input_mask']).float(),
"span_label": torch.tensor(item['span_label']).long(),
"span_mask": torch.tensor(item['span_mask']).long()
}
return one_data
1.4 Building the data iterator
def yield_data(file_path):
tmp = MyDataset(data_pre(file_path))
return DataLoader(tmp, batch_size=args.batch_size, shuffle=True)
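A quick way to inspect a batch (the file path and batch size are hypothetical; the field shapes follow from args.max_length):
train_loader = yield_data("data/train.json")   # hypothetical path

batch = next(iter(train_loader))
for name, tensor in batch.items():
    print(name, tuple(tensor.shape))
# input_ids  (batch_size, max_length)
# input_seg  (batch_size, max_length)
# input_mask (batch_size, max_length)
# span_label (batch_size, max_length, max_length)
# span_mask  (batch_size, max_length, max_length)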
1.5 Final data format
data[0:2]:
[
{
'input_ids': [
101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_seg': [
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'input_mask': [
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
'span_label': array(
[
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]
]
),
'span_mask': [
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
],
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
], ...
]
}, ...
]
2. Model
2.1 Overall architecture

The model consists of four parts: an embedding layer, a BiLSTM, FFNN layers, and a biaffine model; a minimal sketch of the overall forward pass follows.
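To make the data flow concrete before walking through each part, here is a small, self-contained sketch of the forward pass with random tensors (all names and sizes are illustrative assumptions; the repo's actual myModel code appears in the subsections below):
import torch
import torch.nn as nn

batch, seq_len, hidden, ffnn_dim, num_classes = 2, 40, 768, 128, 11   # illustrative sizes

token_reps = torch.randn(batch, seq_len, hidden)              # 1) embedding layer output (e.g. BERT)
lstm = nn.LSTM(hidden, hidden, batch_first=True, bidirectional=True)
context, _ = lstm(token_reps)                                  # 2) BiLSTM -> (batch, seq_len, 2*hidden)
start_ffnn = nn.Sequential(nn.Linear(2 * hidden, ffnn_dim), nn.ReLU())
end_ffnn = nn.Sequential(nn.Linear(2 * hidden, ffnn_dim), nn.ReLU())
hs, he = start_ffnn(context), end_ffnn(context)                # 3) separate start / end representations
U = torch.randn(ffnn_dim, num_classes, ffnn_dim)
span_scores = torch.einsum('bxi,icj,byj->bxyc', hs, U, he)     # 4) biaffine scoring (bias terms omitted)
print(span_scores.shape)                                       # torch.Size([2, 40, 40, 11])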
2.2 embedding layer
BERT: following (Kantor and Globerson, 2019), context-dependent embeddings for each target token are obtained using 64 surrounding tokens on each side;
character-based word embeddings: a CNN is used to encode the characters of the tokens.
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
super().__init__()
self.roberta_encoder = BertModel.from_pretrained(pre_train_dir)
self.roberta_encoder.resize_token_embeddings(len(tokenizer))
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
bert_output = self.roberta_encoder(input_ids=input_ids,
attention_mask=input_mask,
token_type_ids=input_seg)
encoder_rep = bert_output[0]
...
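Note that the PyTorch re-implementation above uses only BERT; the character-based word embeddings described in the paper are not included. A hedged sketch of such a char-CNN encoder (all names and sizes are assumptions, not repo code) might look like:
import torch
import torch.nn as nn

class CharCNNEmbedding(nn.Module):
    """Hypothetical char-CNN: embed characters, convolve, then max-pool over the character axis."""
    def __init__(self, num_chars, char_dim=50, out_dim=100, kernel_size=3):
        super().__init__()
        self.char_emb = nn.Embedding(num_chars, char_dim, padding_idx=0)
        self.conv = nn.Conv1d(char_dim, out_dim, kernel_size, padding=kernel_size // 2)

    def forward(self, char_ids):
        # char_ids: (batch, seq_len, max_word_len)
        b, s, w = char_ids.shape
        x = self.char_emb(char_ids).view(b * s, w, -1).transpose(1, 2)   # (b*s, char_dim, w)
        x = torch.relu(self.conv(x)).max(dim=-1).values                  # (b*s, out_dim)
        return x.view(b, s, -1)                                          # (batch, seq_len, out_dim)
In the paper these character features are concatenated with the other embeddings before being fed to the BiLSTM.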
2.3 BiLSTM
The char embeddings and word embeddings are concatenated and fed into a BiLSTM to obtain the word representations;
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
super().__init__()
...
self.lstm=torch.nn.LSTM(input_size=768,hidden_size=768, \
num_layers=1,batch_first=True, \
dropout=0.5,bidirectional=True)
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
encoder_rep,_ = self.lstm(encoder_rep)
...
2.4 FFNN
After obtaining the word representations from the BiLSTM, two separate FFNNs are applied to create different representations for the start and the end of each span (hs / he). Using separate representations for span starts and ends allows the system to learn to identify the start and the end of a span independently. This improves accuracy compared with models that use the LSTM outputs directly, because the context of an entity's start and of its end differ:

$$h_s(i) = \mathrm{FFNN}_s(x_i), \qquad h_e(i) = \mathrm{FFNN}_e(x_i)$$

where $x_i$ is the BiLSTM output at position $i$.
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
...
self.start_layer = torch.nn.Sequential(
torch.nn.Linear(in_features=2*768, out_features=128),
torch.nn.ReLU()
)
self.end_layer = torch.nn.Sequential(
torch.nn.Linear(in_features=2*768, out_features=128),
torch.nn.ReLU()
)
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
start_logits = self.start_layer(encoder_rep)
end_logits = self.end_layer(encoder_rep)
...
2.5 biaffine model
A biaffine model is applied over the sentence to create an l × l × c scoring tensor r_m, where l is the length of the sentence and c is the number of NER categories + 1 (for non-entity):

$$r_m(i) = h_s(s_i)^{\top} U_m\, h_e(e_i) + W_m \big( h_s(s_i) \oplus h_e(e_i) \big) + b_m$$

where s_i and e_i are the start and end indices of span i, U_m is a d × c × d tensor, W_m is a 2d × c matrix, and b_m is the bias.
Definition
class biaffine(nn.Module):
def __init__(self, in_size, out_size, bias_x=True, bias_y=True):
super().__init__()
self.bias_x = bias_x
self.bias_y = bias_y
self.out_size = out_size
self.U = torch.nn.Parameter(torch.Tensor(in_size + int(bias_x),out_size,in_size + int(bias_y)))
def forward(self, x, y):
if self.bias_x:
x = torch.cat((x, torch.ones_like(x[..., :1])), dim=-1)
if self.bias_y:
y = torch.cat((y, torch.ones_like(y[..., :1])), dim=-1)
bilinar_mapping = torch.einsum('bxi,ioj,byj->bxyo', x, self.U, y)
return bilinar_mapping
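As a quick sanity check of the shapes (hypothetical sizes: batch 2, sequence length 40, feature size 128, 11 classes), the einsum maps the (batch, len, dim) start/end representations to a (batch, len, len, classes) score tensor:
import torch

biaffine_layer = biaffine(in_size=128, out_size=11)   # the class defined above
hs = torch.randn(2, 40, 128)                          # start representations
he = torch.randn(2, 40, 128)                          # end representations
print(biaffine_layer(hs, he).shape)                   # torch.Size([2, 40, 40, 11])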
Usage in the model
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
...
self.biaffne_layer = biaffine(128,num_label)
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
span_logits = self.biaffne_layer(start_logits,end_logits)
span_logits = span_logits.contiguous()
...
2.6 Conflict resolution
The tensor r_m provides scores for all possible spans that could form a named entity, under the constraint s_i ≤ e_i (the start of an entity comes before its end). Each span i is assigned the NER category with the highest score:

$$y'(i) = \arg\max_{c}\, r_m(i)_c$$

All spans whose predicted category is not "non-entity" are then sorted in descending order of their category score r_m(i)_{y'(i)}, and the following post-processing constraint is applied: for nested NER, a span is selected only if it does not clash boundaries with a higher-ranked entity. Entity i clashes with entity j if s_i < s_j ≤ e_i < e_j or s_j < s_i ≤ e_j < e_i (they partially overlap); in that case only the span with the higher category score is selected.
Example:
In the sentence "In the Bank of China", the boundary of the entity "the Bank" clashes with that of the entity "Bank of China".
Note: for flat NER, one additional constraint is applied: any entity that contains or is contained within a higher-ranked entity is also rejected. The learning objective of the named entity recognizer is to assign the correct category (including non-entity) to every valid span. A sketch of this decoding step follows.
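A minimal sketch of this decoding step, written from the description above (it is not taken from either repo, and the names are illustrative):
def decode_spans(span_scores, flat=False):
    """span_scores: a (seq_len, seq_len, num_classes) array of span scores.
    Class 0 is 'non-entity'; the other class ids follow label2id[...] + 1 as in encoder() above."""
    seq_len = span_scores.shape[0]
    candidates = []
    for s in range(seq_len):
        for e in range(s, seq_len):                        # enforce start <= end
            c = int(span_scores[s, e].argmax())
            if c != 0:
                candidates.append((float(span_scores[s, e, c]), s, e, c))
    candidates.sort(reverse=True)                          # rank by category score, highest first

    selected = []
    for score, s, e, c in candidates:
        clash = False
        for _, s2, e2, _ in selected:
            if s < s2 <= e < e2 or s2 < s <= e2 < e:       # partial overlap: always a clash
                clash = True
            elif flat and ((s2 <= s and e <= e2) or (s <= s2 and e2 <= e)):
                clash = True                                # flat NER: containment also clashes
            if clash:
                break
        if not clash:
            selected.append((score, s, e, c))
    return [(s, e, c) for _, s, e, c in selected]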
2.7 Loss function
Assigning one of the c categories (including non-entity) to every valid span is a multi-class classification problem, so the span scores are turned into probabilities with a softmax and trained with a cross-entropy loss (implemented as span_loss in section 4.1):

$$\mathcal{L} = -\sum_{i} \log\, \mathrm{softmax}\big(r_m(i)\big)_{y_i}$$
class myModel(nn.Module):
def __init__(self, pre_train_dir: str, dropout_rate: float):
...
def forward(self, input_ids, input_mask, input_seg, is_training=False):
...
span_prob = torch.nn.functional.softmax(span_logits, dim=-1)
if is_training:
return span_logits
else:
return span_prob
3. Learning-rate decay
class WarmUp_LinearDecay:
def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_epoch, decay_epoch, min_lr_rate=1e-8):
self.optimizer = optimizer
self.init_rate = init_rate
self.epoch_step = train_data_length / args.batch_size
self.warm_up_steps = self.epoch_step * warm_up_epoch
self.decay_steps = self.epoch_step * decay_epoch
self.min_lr_rate = min_lr_rate
self.optimizer_step = 0
self.all_steps = args.epoch*(train_data_length/args.batch_size)
def step(self):
self.optimizer_step += 1
if self.optimizer_step <= self.warm_up_steps:
rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate
elif self.warm_up_steps < self.optimizer_step <= self.decay_steps:
rate = self.init_rate
else:
rate = (1.0 - ((self.optimizer_step - self.decay_steps) / (self.all_steps-self.decay_steps))) * self.init_rate
if rate < self.min_lr_rate:
rate = self.min_lr_rate
for p in self.optimizer.param_groups:
p["lr"] = rate
self.optimizer.step()
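The schedule warms up linearly for warm_up_epoch epochs, holds the peak learning rate until decay_epoch, and then decays linearly towards min_lr_rate. A small standalone re-implementation of the rate formula as a sanity check (all numbers are hypothetical):
def lr_at(step, init_rate=2e-5, steps_per_epoch=100,
          warm_up_epoch=2, decay_epoch=5, total_epoch=10, min_lr_rate=1e-8):
    warm_up_steps = steps_per_epoch * warm_up_epoch
    decay_steps = steps_per_epoch * decay_epoch
    all_steps = steps_per_epoch * total_epoch
    if step <= warm_up_steps:                    # linear warm-up
        rate = step / warm_up_steps * init_rate
    elif step <= decay_steps:                    # hold at the peak rate
        rate = init_rate
    else:                                        # linear decay
        rate = (1.0 - (step - decay_steps) / (all_steps - decay_steps)) * init_rate
    return max(rate, min_lr_rate)

for s in (50, 200, 350, 700, 1000):
    print(s, f"{lr_at(s):.2e}")   # 5.00e-06, 2.00e-05, 2.00e-05, 1.20e-05, 1.00e-08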
4. Loss functions
4.1 span_loss
Core idea: over all the start and end positions learned by the model, build a start-end matching task, i.e. decide whether a given start position and a given end position pair up to form an entity. Positive samples are the position pairs of true entities, negative samples are non-entity position pairs; in the binary view each pair would be predicted as 1 or 0, while in this repo each cell holds the entity-type id, so the matching is trained as a multi-class cross-entropy over the span-label matrix.
import torch
from torch import nn
from utils.arguments_parse import args
from data_preprocessing import tools
label2id,id2label,num_labels=tools.load_schema()
num_label = num_labels+1
class Span_loss(nn.Module):
def __init__(self):
super().__init__()
self.loss_func = torch.nn.CrossEntropyLoss(reduction="none")
def forward(self,span_logits,span_label,seq_mask):
# batch_size,seq_len,hidden=span_label.shape
span_label = span_label.view(size=(-1,))
span_logits = span_logits.view(size=(-1, num_label))
span_loss = self.loss_func(input=span_logits, target=span_label)
# start_extend = seq_mask.unsqueeze(2).expand(-1, -1, seq_len)
# end_extend = seq_mask.unsqueeze(1).expand(-1, seq_len, -1)
span_mask = seq_mask.view(size=(-1,))
span_loss *=span_mask
avg_se_loss = torch.sum(span_loss) / seq_mask.size()[0]
# avg_se_loss = torch.sum(sum_loss) / bsz
return avg_se_loss
Note: view() reshapes a tensor, similar to reshape() in numpy.
Reference: "A Unified MRC Framework for Named Entity Recognition"
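A quick shape check of Span_loss with random tensors (the sizes are illustrative; num_label here is assumed to match the value derived from tools.load_schema(), taken as 11 purely for illustration):
import torch

batch, seq_len, num_label = 2, 40, 11
span_logits = torch.randn(batch, seq_len, seq_len, num_label)
span_label = torch.randint(0, num_label, (batch, seq_len, seq_len))
span_mask = torch.randint(0, 2, (batch, seq_len, seq_len))

loss_fn = Span_loss()                                  # the class defined above
print(loss_fn(span_logits, span_label, span_mask))     # a scalar tensor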
4.2 focal_loss
Goal: a loss designed to handle class imbalance and varying classification difficulty;
Idea: it down-weights the large number of easy negative samples during training, and can also be seen as a form of hard example mining.
Loss function:
Focal loss is a modification of the cross-entropy loss. First recall the binary cross-entropy:

$$L_{CE} = -y \log y' - (1 - y) \log(1 - y')$$

where y' is the output after the activation function, so it lies in (0, 1). For positive samples, the higher the predicted probability the smaller the loss; for negative samples, the lower the predicted probability the smaller the loss. With a large number of easy samples, this loss converges slowly and may not reach a good optimum. Focal loss modifies it as follows:

$$L_{fl} = -(1 - y')^{\gamma}\, y \log y' - y'^{\gamma}\, (1 - y) \log(1 - y')$$

A modulating factor with gamma > 0 is added to reduce the loss of easily classified samples, so that training focuses more on hard, misclassified samples.
For example, with gamma = 2: for a positive sample, a prediction of 0.95 is clearly easy, so (1 - 0.95) to the power gamma is tiny and the loss becomes much smaller, while a sample predicted at 0.3 keeps a relatively large loss. The same holds for negative samples: a prediction of 0.1 yields a far smaller loss than a prediction of 0.7. For a prediction of 0.5 the loss is only scaled by 0.25, so these hard-to-separate samples receive relatively more attention. This reduces the influence of easy samples, so that only a large number of samples with small predicted probability can add up to a comparable effect.
In addition, a balancing factor alpha is introduced to balance the uneven ratio of positive and negative samples:

$$L_{fl} = -\alpha (1 - y')^{\gamma}\, y \log y' - (1 - \alpha)\, y'^{\gamma}\, (1 - y) \log(1 - y')$$

Adding alpha alone balances the importance of positive and negative samples, but it does not address the easy-versus-hard problem.
gamma controls how quickly the weight of easy samples is reduced: with gamma = 0 the loss reduces to the ordinary cross-entropy, and as gamma increases the effect of the modulating factor grows. Experiments found gamma = 2 to be optimal.
Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
'''Multi-class Focal loss implementation'''
def __init__(self, gamma=2, weight=None, ignore_index=-100):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.ignore_index = ignore_index
def forward(self, input, target):
"""
input: [N, C]
target: [N, ]
"""
logpt = F.log_softmax(input, dim=1)
pt = torch.exp(logpt)
logpt = (1 - pt) ** self.gamma * logpt
loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
return loss
Reference: "Focal Loss for Dense Object Detection"
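A quick usage check of FocalLoss with random inputs (the shapes follow the docstring; the numbers are illustrative):
import torch

loss_fn = FocalLoss(gamma=2)
logits = torch.randn(8, 5)               # input: [N, C] = 8 samples, 5 classes
target = torch.randint(0, 5, (8,))       # target: [N] class indices
print(loss_fn(logits, target))           # scalar tensor; gamma = 0 recovers plain cross-entropy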
5. Training
def train():
# step 1: data preprocessing
train_data = data_prepro.yield_data(args.train_path)
test_data = data_prepro.yield_data(args.test_path)
# step 2: model definition
model = myModel(pre_train_dir=args.pretrained_model_path, dropout_rate=0.5).to(device)
# step 3: optimizer definition
# model.load_state_dict(torch.load(args.checkpoints))
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
'weight_decay_rate': 0.0}
]
optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args.learning_rate)
schedule = WarmUp_LinearDecay(
optimizer = optimizer,
init_rate = args.learning_rate,
warm_up_epoch = args.warm_up_epoch,
decay_epoch = args.decay_epoch
)
# step 4: span_loss and evaluation metric definition
span_loss_func = span_loss.Span_loss().to(device)
span_acc = metrics.metrics_span().to(device)
# step 5: training loop
step=0
best=0
for epoch in range(args.epoch):
for item in train_data:
step+=1
input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
span_label,span_mask = item['span_label'],item["span_mask"]
optimizer.zero_grad()
span_logits = model(
input_ids=input_ids.to(device),
input_mask=input_mask.to(device),
input_seg=input_seg.to(device),
is_training=True
)
span_loss_v = span_loss_func(span_logits,span_label.to(device),span_mask.to(device))
loss = span_loss_v
loss = loss.float().mean().type_as(loss)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm)
schedule.step()
# optimizer.step()
if step%100 == 0:
span_logits = torch.nn.functional.softmax(span_logits, dim=-1)
recall,precise,span_f1=span_acc(span_logits,span_label.to(device))
logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f'% (epoch,step,loss,recall,precise,span_f1))
with torch.no_grad():
count=0
span_f1=0
recall=0
precise=0
for item in test_data:
count+=1
input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
span_label,span_mask = item['span_label'],item["span_mask"]
optimizer.zero_grad()
span_logits = model(
input_ids=input_ids.to(device),
input_mask=input_mask.to(device),
input_seg=input_seg.to(device),
is_training=False
)
tmp_recall,tmp_precise,tmp_span_f1=span_acc(span_logits,span_label.to(device))
span_f1+=tmp_span_f1
recall+=tmp_recall
precise+=tmp_precise
span_f1 = span_f1/count
recall=recall/count
precise=precise/count
logger.info('-----eval----')
logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f'% (epoch,step,loss,recall,precise,span_f1))
logger.info('-----eval----')
if best < span_f1:
best=span_f1
torch.save(model.state_dict(), f=args.checkpoints)
logger.info('-----save the best model----')
References
Named Entity Recognition as Dependency Parsing (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.577/