【Nested NER with Biaffine】Things You Didn't Know

24,557 characters · about 50 minutes to read

2021-05-08 15:14

Author: Yang Xi

Project: https://github.com/km1994/nlp_paper_study

Paper: https://www.aclweb.org/anthology/2020.acl-main.577/

Code: https://github.com/juntaoy/biaffine-ner

Code (Chinese, PyTorch): https://github.com/suolyer/PyTorch_BERT_Biaffine_NER

About the author: Hello everyone, I'm Yang Xi. This project collects what I have seen, thought, and learned while reading top-conference papers and reproducing classic ones. It may contain some misunderstandings, and corrections are very welcome.

• 【Nested NER with Biaffine】Things You Didn't Know
  • Abstract
  • I. Data Processing
    • 1.1 Raw data format
    • 1.2 Preprocessing module data_pre()
      • 1.2.1 Main preprocessing function
      • 1.2.2 Loading training data: load_data(file_path)
      • 1.2.3 Data encoding: encoder(sentence, argument)
    • 1.3 Converting data into a MyDataset object
    • 1.4 Building the data iterator
    • 1.5 Final data format
  • II. Model Construction
    • 2.1 Overall architecture
    • 2.2 Embedding layer
    • 2.3 BiLSTM
    • 2.4 FFNN
    • 2.5 Biaffine model
    • 2.6 Conflict resolution
    • 2.7 Loss function
  • III. Learning-Rate Decay
  • IV. Loss Functions
    • 4.1 span_loss
    • 4.2 focal_loss
  • V. Model Training
  • References

Abstract

• Motivation: NER research has focused on flat NER and largely ignored nested entities;

• Method: in this paper we use ideas from graph-based dependency parsing to give the model a global view of the input via a biaffine model. The biaffine model scores pairs of start and end tokens in a sentence, and these scores are used to explore all spans, so that the model can predict named entities accurately.

• Approach: in this work we reformulate NER as the task of identifying start and end indices and assigning a category to the span defined by each pair. Our system uses a biaffine model on top of a multi-layer BiLSTM to assign scores to all possible spans in a sentence. Instead of building a dependency tree, we rank the candidate spans by their scores and return the top-ranked spans that comply with the constraints of flat or nested NER;

• Results: we evaluate the system on three nested NER benchmarks (ACE 2004, ACE 2005, GENIA) and five flat NER corpora (CoNLL 2002 (Dutch, Spanish), CoNLL 2003 (English, German), and OntoNotes). The system achieves SoTA results on all three nested NER corpora and all five flat NER corpora, with gains of up to 2.2 absolute percentage points over the previous SoTA.

I. Data Processing

1.1 Raw data format

The raw data looks like this:

{
    "text": "當(dāng)希望工程救助的百萬兒童成長起來,科教興國蔚然成風(fēng)時(shí),今天有收藏價(jià)值的書你沒買,明日就叫你悔不當(dāng)初!",
    "entity_list": []
}
{
    "text": "藏書本來就是所有傳統(tǒng)收藏門類中的第一大戶,只是我們結(jié)束溫飽的時(shí)間太短而已。",
    "entity_list": []
}
{
    "text": "因有關(guān)日寇在京掠奪文物詳情,藏界較為重視,也是我們收藏北京史料中的要件之一。",
    "entity_list": [
        {"type": "ns", "argument": "北京"}
    ]
}
...

1.2 Preprocessing module data_pre()

1.2.1 Main preprocessing function

• Steps:

1. Load the data;

2. Encode the data and convert it into the training format.

• Code:

def data_pre(file_path):
    sentences, arguments = load_data(file_path)
    data = []
    for i in tqdm(range(len(sentences))):
        encode_sent, token_type_ids, attention_mask, span_label, span_mask = encoder(
            sentences[i], arguments[i])

        tmp = {}
        tmp['input_ids'] = encode_sent
        tmp['input_seg'] = token_type_ids
        tmp['input_mask'] = attention_mask
        tmp['span_label'] = span_label
        tmp['span_mask'] = span_mask
        data.append(tmp)

    return data

• Output:

        data[0:2]:
        [
        {
        'input_ids': [
        101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        'input_seg': [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        'input_mask': [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        'span_label': array(
        [
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]
        ]
        ),
        'span_mask': [
        [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ], ...
        ]
        }, ...
        ]

1.2.2 Loading training data: load_data(file_path)

• Code:

def load_data(file_path):
    with open(file_path, 'r', encoding='utf8') as f:
        lines = f.readlines()
    sentences = []
    arguments = []
    for line in lines:
        data = json.loads(line)
        text = data['text']
        entity_list = data['entity_list']
        args_dict = {}
        if entity_list != []:
            for entity in entity_list:
                entity_type = entity['type']
                entity_argument = entity['argument']
                args_dict[entity_type] = entity_argument
        sentences.append(text)
        arguments.append(args_dict)
    return sentences, arguments

• Output:

        print(f"sentences[0:2]:{sentences[0:2]}")
        print(f"arguments[0:2]:{arguments[0:2]}")

        >>>
        sentences[0:2]:['因有關(guān)日寇在京掠奪文物詳情,藏界較為重視,也是我們收藏北京史料中的要件之一。', '我們藏有一冊(cè)1945年 6月油印的《北京文物保存保管狀態(tài)之調(diào)查報(bào)告》,調(diào)查范圍涉及故宮、歷博、古研所、北大清華圖書館、北圖、日偽資料庫等二十幾家,言及文物二十萬件以上,洋洋三萬余言,是珍貴的北京史料。']
        arguments[0:2]:[{'ns': '北京'}, {'ns': '北京', 'nt': '古研所'}]

1.2.3 Data encoding: encoder(sentence, argument)

• Code:

# step 1: get the BERT tokenizer
tokenizer = tools.get_tokenizer()
# step 2: get the label <-> id mapping tables
label2id, id2label, num_labels = tools.load_schema()

def encoder(sentence, argument):
    # step 3: encode the sentence with the tokenizer
    encode_dict = tokenizer.encode_plus(
        sentence,
        max_length=args.max_length,
        pad_to_max_length=True,
        truncation=True
    )
    encode_sent = encode_dict['input_ids']
    token_type_ids = encode_dict['token_type_ids']
    attention_mask = encode_dict['attention_mask']

    # step 4: build span_mask
    zero = [0 for i in range(args.max_length)]
    span_mask = [attention_mask for i in range(sum(attention_mask))]
    span_mask.extend([zero for i in range(sum(attention_mask), args.max_length)])

    # step 5: build span_label
    span_label = [0 for i in range(args.max_length)]
    span_label = [span_label for i in range(args.max_length)]
    span_label = np.array(span_label)
    for entity_type, arg in argument.items():
        encode_arg = tokenizer.encode(arg)
        start_idx = tools.search(encode_arg[1:-1], encode_sent)
        end_idx = start_idx + len(encode_arg[1:-1]) - 1
        span_label[start_idx, end_idx] = label2id[entity_type] + 1

    return encode_sent, token_type_ids, attention_mask, span_label, span_mask

• Steps:

1. Get the BERT tokenizer;

2. Get the label-to-id mapping tables;

3. encode_plus returns all the encoded information:

        encode_dict:
        {
        'input_ids': [101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        }

Note:
'input_ids': as the name suggests, the ids of the tokens in the vocabulary;
'token_type_ids': distinguishes the two sentences of an input pair;
'attention_mask': marks which positions take part in self-attention (1 for real tokens, 0 for padding).

4. Build span_mask

5. Build span_label

• Description: this step builds a matrix of size args.max_length × args.max_length that records where each span sits in the sentence (start position, end position): the row index of a span is its start position, the column index is its end position, and the value of that cell is the id of the span's entity type;

• Toy example:

>>>
import numpy as np
span_label = [0 for i in range(10)]
span_label = [span_label for i in range(10)]
span_label = np.array(span_label)
start = [1, 3, 7]
end = [2, 9, 9]
label2id = [1, 2, 4]
for i in range(len(label2id)):
    span_label[start[i], end[i]] = label2id[i]

>>>
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 2],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

> Note: the row index is the start, the column index is the end, and the cell value is the label id.

1.3 Converting data into a MyDataset object

Convert the data into torch.tensor types:

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        item = self.data[index]
        one_data = {
            "input_ids": torch.tensor(item['input_ids']).long(),
            "input_seg": torch.tensor(item['input_seg']).long(),
            "input_mask": torch.tensor(item['input_mask']).float(),
            "span_label": torch.tensor(item['span_label']).long(),
            "span_mask": torch.tensor(item['span_mask']).long()
        }
        return one_data

1.4 Building the data iterator

def yield_data(file_path):
    tmp = MyDataset(data_pre(file_path))
    return DataLoader(tmp, batch_size=args.batch_size, shuffle=True)
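A minimal usage sketch (assuming args.train_path points at the JSONL data from section 1.1, args.max_length = 128, and args.batch_size = 8; the shapes shown are what those settings would produce):

train_loader = yield_data(args.train_path)
batch = next(iter(train_loader))  # the default collate_fn stacks the per-sample tensors
print(batch['input_ids'].shape)   # torch.Size([8, 128])
print(batch['span_label'].shape)  # torch.Size([8, 128, 128])
print(batch['span_mask'].shape)   # torch.Size([8, 128, 128])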

1.5 Final data format

        data[0:2]:
        [
        {
        'input_ids': [
        101, 1728, 3300, 1068, 3189, 2167, 1762, 776, 2966, 1932, 3152, 4289, 6422, 2658, 8024, 5966, 4518, 6772, 711, 7028, 6228, 8024, 738, 3221, 2769, 812, 3119, 5966, 1266, 776, 1380, 3160, 704, 4638, 6206, 816, 722, 671, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        'input_seg': [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        'input_mask': [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        'span_label': array(
        [
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]
        ]
        ),
        'span_mask': [
        [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ],
        [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        ], ...
        ]
        }, ...
        ]

II. Model Construction

2.1 Overall architecture

The model consists of four parts: an embedding layer, a BiLSTM, two FFNNs, and a biaffine model.
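Before going through the parts one by one, here is a condensed sketch of how they are chained in forward(). This is not the repo's actual code; the attribute names follow the snippets below, and the shapes assume a BERT hidden size of 768 and an FFNN output size of 128:

def forward_sketch(model, input_ids, input_mask, input_seg):
    # Illustrative data flow only; the real forward() is shown piece by piece below.
    rep = model.roberta_encoder(input_ids=input_ids,
                                attention_mask=input_mask,
                                token_type_ids=input_seg)[0]     # [batch, len, 768]   embedding layer
    rep, _ = model.lstm(rep)                                     # [batch, len, 2*768] BiLSTM
    start_logits = model.start_layer(rep)                        # [batch, len, 128]   start FFNN
    end_logits = model.end_layer(rep)                            # [batch, len, 128]   end FFNN
    span_logits = model.biaffne_layer(start_logits, end_logits)  # [batch, len, len, c] biaffine scores
    return span_logits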

2.2 Embedding layer

1. BERT: following (Kantor and Globerson, 2019), the contextual embedding of each target token is computed using 64 surrounding tokens on each side;

2. Character-based word embeddings: a CNN encodes the characters of each token.

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        super().__init__()
        self.roberta_encoder = BertModel.from_pretrained(pre_train_dir)
        self.roberta_encoder.resize_token_embeddings(len(tokenizer))
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        bert_output = self.roberta_encoder(input_ids=input_ids,
                                           attention_mask=input_mask,
                                           token_type_ids=input_seg)
        encoder_rep = bert_output[0]
        ...
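Note that this PyTorch reimplementation uses only BERT as the embedding layer; the character-level CNN from the paper is not part of the repo. A minimal sketch of such a character encoder (all names and sizes here are hypothetical) could look like:

import torch
import torch.nn as nn

class CharCNNEncoder(nn.Module):
    """Hypothetical character-level encoder: embed characters, run a 1-D CNN,
    then max-pool over the characters of each token to get one vector per token."""
    def __init__(self, num_chars, char_dim=30, out_dim=50, kernel_size=3):
        super().__init__()
        self.char_emb = nn.Embedding(num_chars, char_dim, padding_idx=0)
        self.conv = nn.Conv1d(char_dim, out_dim, kernel_size, padding=kernel_size // 2)

    def forward(self, char_ids):                   # char_ids: [batch, seq_len, max_chars]
        b, l, c = char_ids.shape
        x = self.char_emb(char_ids)                # [b, l, c, char_dim]
        x = x.view(b * l, c, -1).transpose(1, 2)   # [b*l, char_dim, c]
        x = torch.relu(self.conv(x))               # [b*l, out_dim, c]
        x = x.max(dim=-1).values                   # [b*l, out_dim] max-pool over characters
        return x.view(b, l, -1)                    # [b, seq_len, out_dim]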

2.3 BiLSTM

The char and word embeddings are concatenated and fed into a BiLSTM to obtain word representations;

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        super().__init__()
        ...
        self.lstm = torch.nn.LSTM(input_size=768, hidden_size=768,
                                  num_layers=1, batch_first=True,
                                  dropout=0.5, bidirectional=True)
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        encoder_rep, _ = self.lstm(encoder_rep)
        ...

2.4 FFNN

After obtaining the word representations from the BiLSTM, two separate FFNNs create different representations (hs/he) for the start and the end of each span. Using different representations for span starts and ends lets the system learn to identify the two independently. Compared with models that feed the LSTM outputs directly into the scorer, this improves accuracy, because the context of an entity's start differs from the context of its end.

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        ...
        self.start_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=2 * 768, out_features=128),
            torch.nn.ReLU()
        )
        self.end_layer = torch.nn.Sequential(
            torch.nn.Linear(in_features=2 * 768, out_features=128),
            torch.nn.ReLU()
        )
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        start_logits = self.start_layer(encoder_rep)
        end_logits = self.end_layer(encoder_rep)
        ...

2.5 Biaffine model

A biaffine model over the sentence produces an l×l×c scoring tensor r_m, where l is the sentence length and c is the number of NER categories + 1 (for "non-entity"):

r_m(i) = h_s(s_i)^T U_m h_e(e_i) + W_m (h_s(s_i) ⊕ h_e(e_i)) + b_m

where s_i and e_i are the start and end indices of span i, U_m is a d×c×d tensor, W_m is a 2d×c matrix, and b_m is the bias.

• Definition

class biaffine(nn.Module):
    def __init__(self, in_size, out_size, bias_x=True, bias_y=True):
        super().__init__()
        self.bias_x = bias_x
        self.bias_y = bias_y
        self.out_size = out_size
        self.U = torch.nn.Parameter(
            torch.Tensor(in_size + int(bias_x), out_size, in_size + int(bias_y)))

    def forward(self, x, y):
        if self.bias_x:
            x = torch.cat((x, torch.ones_like(x[..., :1])), dim=-1)
        if self.bias_y:
            y = torch.cat((y, torch.ones_like(y[..., :1])), dim=-1)
        bilinar_mapping = torch.einsum('bxi,ioj,byj->bxyo', x, self.U, y)
        return bilinar_mapping
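A quick shape check of this layer (the numbers here are illustrative: batch 2, sequence length 5, feature size 128, 10 classes):

import torch

num_label = 10                  # hypothetical number of classes (entity types + non-entity)
layer = biaffine(in_size=128, out_size=num_label)
start = torch.randn(2, 5, 128)  # [batch, seq_len, hidden] start representations
end = torch.randn(2, 5, 128)    # [batch, seq_len, hidden] end representations
print(layer(start, end).shape)  # torch.Size([2, 5, 5, 10]): one score vector per (start, end) pair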

• Usage in the model

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        ...
        self.biaffne_layer = biaffine(128, num_label)
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        span_logits = self.biaffne_layer(start_logits, end_logits)
        span_logits = span_logits.contiguous()
        ...

2.6 Conflict resolution

The tensor r_m provides scores for all possible spans that could form a named entity, subject to the constraint s_i ≤ e_i (an entity's start precedes its end). Each span is assigned the NER category y'(i) = arg max_y r_m(i, y).

We then sort all spans whose category is not "non-entity" in descending order of their category score r_m(i, y'(i)) and apply the following post-processing constraint: for nested NER, a span is selected only if it does not clash with a higher-ranked span. An entity i clashes with an entity j if s_i < s_j ≤ e_i < e_j or s_j < s_i ≤ e_j < e_i, i.e. the two spans partially overlap; in that case only the span with the higher category score is kept.

e.g.:
In the sentence "In the Bank of China", the boundary of the entity "the Bank" clashes with that of the entity "Bank of China".

Note: for flat NER we apply one further constraint: any span that contains, or is contained within, a higher-ranked span is also not selected. The learning objective of our named entity recognizer is to assign the correct category (including "non-entity") to every valid span.
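The repo does not expose this decoding step as a separate function; a minimal sketch of the ranking-and-conflict procedure described above might look like this (decode_spans is a hypothetical helper, and the extra flat-NER containment constraint is omitted):

import numpy as np

def decode_spans(span_prob, seq_len):
    """Hypothetical decoder: rank candidate spans by their best non-"non-entity" score,
    then greedily keep spans that do not partially overlap a higher-ranked span."""
    candidates = []
    for s in range(seq_len):
        for e in range(s, seq_len):                        # enforce s_i <= e_i
            label = int(np.argmax(span_prob[s, e]))
            if label != 0:                                  # 0 = non-entity class
                candidates.append((span_prob[s, e, label], s, e, label))
    candidates.sort(key=lambda x: x[0], reverse=True)       # rank by category score

    selected = []
    for score, s, e, label in candidates:
        clash = any(s < s2 <= e < e2 or s2 < s <= e2 < e for _, s2, e2, _ in selected)
        if not clash:                                       # nested NER: only partial overlaps conflict
            selected.append((score, s, e, label))
    return selected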

2.7 Loss function

Since the task is a multi-class classification problem, a softmax is applied over the span scores; during training the raw logits are returned for the loss, and at inference the probabilities are returned:

class myModel(nn.Module):
    def __init__(self, pre_train_dir: str, dropout_rate: float):
        ...

    def forward(self, input_ids, input_mask, input_seg, is_training=False):
        ...
        span_prob = torch.nn.functional.softmax(span_logits, dim=-1)

        if is_training:
            return span_logits
        else:
            return span_prob

III. Learning-Rate Decay

class WarmUp_LinearDecay:
    def __init__(self, optimizer: optim.AdamW, init_rate, warm_up_epoch, decay_epoch, min_lr_rate=1e-8):
        self.optimizer = optimizer
        self.init_rate = init_rate
        self.epoch_step = train_data_length / args.batch_size
        self.warm_up_steps = self.epoch_step * warm_up_epoch
        self.decay_steps = self.epoch_step * decay_epoch
        self.min_lr_rate = min_lr_rate
        self.optimizer_step = 0
        self.all_steps = args.epoch * (train_data_length / args.batch_size)

    def step(self):
        self.optimizer_step += 1
        if self.optimizer_step <= self.warm_up_steps:
            rate = (self.optimizer_step / self.warm_up_steps) * self.init_rate
        elif self.warm_up_steps < self.optimizer_step <= self.decay_steps:
            rate = self.init_rate
        else:
            rate = (1.0 - ((self.optimizer_step - self.decay_steps) / (self.all_steps - self.decay_steps))) * self.init_rate
            if rate < self.min_lr_rate:
                rate = self.min_lr_rate
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self.optimizer.step()
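To make the schedule's three phases concrete, here is a standalone restatement of the logic in step() with illustrative numbers (the real class derives warm_up_steps, decay_steps, and all_steps from train_data_length and args):

def lr_at_step(step, init_rate, warm_up_steps, decay_steps, all_steps, min_lr_rate=1e-8):
    # Linear warm-up, then a constant plateau, then linear decay down to min_lr_rate.
    if step <= warm_up_steps:
        return (step / warm_up_steps) * init_rate
    if step <= decay_steps:
        return init_rate
    rate = (1.0 - (step - decay_steps) / (all_steps - decay_steps)) * init_rate
    return max(rate, min_lr_rate)

# e.g. with init_rate=2e-5, warm_up_steps=100, decay_steps=300, all_steps=1000:
# step 50 -> 1e-5 (warming up), step 200 -> 2e-5 (plateau), step 650 -> 1e-5 (decaying)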

IV. Loss Functions

4.1 span_loss

• Core idea: for all the start and end positions the model produces, build a start-end matching task: decide whether a given start position and a given end position pair up to form an entity (predict 1 if they do, 0 otherwise). This effectively turns the problem into a classification task in which the positive samples are real entity spans and the negative samples are position pairs that do not form an entity.

import torch
from torch import nn
from utils.arguments_parse import args
from data_preprocessing import tools

label2id, id2label, num_labels = tools.load_schema()
num_label = num_labels + 1

class Span_loss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_func = torch.nn.CrossEntropyLoss(reduction="none")

    def forward(self, span_logits, span_label, seq_mask):
        # batch_size, seq_len, hidden = span_label.shape
        span_label = span_label.view(size=(-1,))
        span_logits = span_logits.view(size=(-1, num_label))
        span_loss = self.loss_func(input=span_logits, target=span_label)
        # start_extend = seq_mask.unsqueeze(2).expand(-1, -1, seq_len)
        # end_extend = seq_mask.unsqueeze(1).expand(-1, seq_len, -1)
        span_mask = seq_mask.view(size=(-1,))
        span_loss *= span_mask
        avg_se_loss = torch.sum(span_loss) / seq_mask.size()[0]
        # avg_se_loss = torch.sum(sum_loss) / bsz
        return avg_se_loss

Note: view() reshapes a tensor without copying its data, similar to reshape() in NumPy.
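A tiny illustration:

import torch
t = torch.arange(6)   # tensor([0, 1, 2, 3, 4, 5])
print(t.view(2, 3))   # tensor([[0, 1, 2], [3, 4, 5]]) -- same data, new shape
print(t.view(-1))     # flattened back: tensor([0, 1, 2, 3, 4, 5])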

• Reference paper: "A Unified MRC Framework for Named Entity Recognition"

4.2 focal_loss

• Goal: a loss designed to handle class imbalance and differences in classification difficulty;

• Idea: down-weight the large number of easy negative samples during training; it can also be viewed as a form of hard example mining.

• Loss formulation:

Focal loss is a modification of the cross-entropy loss. Recall the binary cross-entropy loss:

CE = -y·log(y') - (1 - y)·log(1 - y')

y' is the output of the activation function, so it lies between 0 and 1. With plain cross-entropy, the larger the predicted probability of a positive sample, the smaller its loss, and the smaller the predicted probability of a negative sample, the smaller its loss. With a large number of easy samples, optimization is slow and may never reach a good optimum. How does focal loss improve on this?

First, a modulating factor is added to the original loss:

FL = -(1 - y')^γ·log(y')  for positive samples,  -y'^γ·log(1 - y')  for negative samples

where γ > 0 reduces the loss of easy samples, so that training focuses more on hard, misclassified ones.

For example, with γ = 2, a positive sample predicted at 0.95 is clearly easy, so (1 - 0.95)^γ is tiny and its loss shrinks dramatically, while a sample predicted at 0.3 keeps a relatively large loss. The same holds for negative samples: a prediction of 0.1 yields a far smaller loss than a prediction of 0.7. At a prediction of 0.5 the loss is only scaled by 0.25, so these hard-to-separate samples receive relatively more attention. This suppresses the influence of easy samples; only a very large number of them added together can still have a noticeable effect.

In addition, a balancing factor α is introduced to correct the imbalance between positive and negative samples:

FL = -α·(1 - y')^γ·log(y')  for positive samples,  -(1 - α)·y'^γ·log(1 - y')  for negative samples

Adding α alone balances the importance of positive and negative samples, but does not address the easy-versus-hard problem.

γ controls how quickly the weight of easy samples decays: with γ = 0 the loss reduces to cross-entropy, and the larger γ is, the stronger the effect of the modulating factor. Experiments found γ = 2 to work best.
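A quick numeric check of the modulating factor with γ = 2 (a small illustrative snippet):

def focal_weight(p_t, gamma=2.0):
    # modulating factor (1 - p_t)^gamma applied to the cross-entropy term -log(p_t)
    return (1 - p_t) ** gamma

print(focal_weight(0.95))  # ≈ 0.0025 -> easy sample, loss almost removed
print(focal_weight(0.5))   # 0.25     -> borderline sample, loss scaled by 0.25
print(focal_weight(0.3))   # ≈ 0.49   -> hard sample, loss kept almost intact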

• Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    '''Multi-class Focal loss implementation'''
    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """
        input: [N, C]
        target: [N, ]
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        logpt = (1 - pt) ** self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
        return loss
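The training script in the next section uses Span_loss by default; plugging FocalLoss in instead would look roughly like this (the shapes and the num_label value are illustrative):

import torch

num_label = 10                                            # hypothetical number of classes
focal = FocalLoss(gamma=2)
span_logits = torch.randn(2, 128, 128, num_label)         # [batch, seq, seq, classes]
span_label = torch.zeros(2, 128, 128, dtype=torch.long)   # 0 = non-entity
loss = focal(span_logits.view(-1, num_label), span_label.view(-1))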

• Reference paper: "Focal Loss for Dense Object Detection"

V. Model Training

def train():
    # step 1: data preprocessing
    train_data = data_prepro.yield_data(args.train_path)
    test_data = data_prepro.yield_data(args.test_path)

    # step 2: model definition
    model = myModel(pre_train_dir=args.pretrained_model_path, dropout_rate=0.5).to(device)

    # step 3: optimizer definition
    # model.load_state_dict(torch.load(args.checkpoints))
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay_rate': 0.0}
    ]
    optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args.learning_rate)

    schedule = WarmUp_LinearDecay(
        optimizer=optimizer,
        init_rate=args.learning_rate,
        warm_up_epoch=args.warm_up_epoch,
        decay_epoch=args.decay_epoch
    )

    # step 4: span_loss and span metric definition
    span_loss_func = span_loss.Span_loss().to(device)
    span_acc = metrics.metrics_span().to(device)

    # step 5: training loop
    step = 0
    best = 0
    for epoch in range(args.epoch):
        for item in train_data:
            step += 1
            input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
            span_label, span_mask = item['span_label'], item["span_mask"]
            optimizer.zero_grad()
            span_logits = model(
                input_ids=input_ids.to(device),
                input_mask=input_mask.to(device),
                input_seg=input_seg.to(device),
                is_training=True
            )
            span_loss_v = span_loss_func(span_logits, span_label.to(device), span_mask.to(device))
            loss = span_loss_v
            loss = loss.float().mean().type_as(loss)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm)
            schedule.step()
            # optimizer.step()
            if step % 100 == 0:
                span_logits = torch.nn.functional.softmax(span_logits, dim=-1)
                recall, precise, span_f1 = span_acc(span_logits, span_label.to(device))
                logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f' % (
                    epoch, step, loss, recall, precise, span_f1))

        # evaluation at the end of each epoch
        with torch.no_grad():
            count = 0
            span_f1 = 0
            recall = 0
            precise = 0

            for item in test_data:
                count += 1
                input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
                span_label, span_mask = item['span_label'], item["span_mask"]

                optimizer.zero_grad()
                span_logits = model(
                    input_ids=input_ids.to(device),
                    input_mask=input_mask.to(device),
                    input_seg=input_seg.to(device),
                    is_training=False
                )
                tmp_recall, tmp_precise, tmp_span_f1 = span_acc(span_logits, span_label.to(device))
                span_f1 += tmp_span_f1
                recall += tmp_recall
                precise += tmp_precise

            span_f1 = span_f1 / count
            recall = recall / count
            precise = precise / count

            logger.info('-----eval----')
            logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, span_f1 %.4f' % (
                epoch, step, loss, recall, precise, span_f1))
            logger.info('-----eval----')
            if best < span_f1:
                best = span_f1
                torch.save(model.state_dict(), f=args.checkpoints)
                logger.info('-----save the best model----')

References

1. Named Entity Recognition as Dependency Parsing (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.577/

