1. 使用 DPO 微調(diào) Llama 2

        共 8045字,需瀏覽 17分鐘

         ·

        2023-08-23 15:37

        簡(jiǎn)介

        基于人類反饋的強(qiáng)化學(xué)習(xí) (Reinforcement Learning from Human Feedback,RLHF) 事實(shí)上已成為 GPT-4 或 Claude 等 LLM 訓(xùn)練的最后一步,它可以確保語言模型的輸出符合人類在閑聊或安全性等方面的期望。然而,它也給 NLP 引入了一些 RL 相關(guān)的復(fù)雜性: 既要構(gòu)建一個(gè)好的獎(jiǎng)勵(lì)函數(shù),并訓(xùn)練一個(gè)模型用以估計(jì)每個(gè)狀態(tài)的價(jià)值 (value); 又要注意最終生成的 LLM 不能與原始模型相差太遠(yuǎn),如果太遠(yuǎn)的話會(huì)使得模型容易產(chǎn)生亂碼而非有意義的文本。該過程非常復(fù)雜,涉及到許多復(fù)雜的組件,而這些組件本身在訓(xùn)練過程中又是動(dòng)態(tài)變化的,因此把它們料理好并不容易。

        Rafailov、Sharma、Mitchell 等人最近發(fā)表了一篇論文 Direct Preference Optimization,論文提出將現(xiàn)有方法使用的基于強(qiáng)化學(xué)習(xí)的目標(biāo)轉(zhuǎn)換為可以通過簡(jiǎn)單的二元交叉熵?fù)p失直接優(yōu)化的目標(biāo),這一做法大大簡(jiǎn)化了 LLM 的提純過程。

        本文介紹了直接偏好優(yōu)化 (Direct Preference Optimization,DPO) 法,該方法現(xiàn)已集成至 TRL 庫 中。同時(shí),我們還展示了如何在 stack-exchange preference 數(shù)據(jù)集上微調(diào)最新的 Llama v2 7B 模型, stack-exchange preference 數(shù)據(jù)集中包含了各個(gè) stack-exchange 門戶上的各種問題及其排序后的回答。

        DPO 與 PPO

        在通過 RL 優(yōu)化人類衍生偏好時(shí),一直以來的傳統(tǒng)做法是使用一個(gè)輔助獎(jiǎng)勵(lì)模型來微調(diào)目標(biāo)模型,以通過 RL 機(jī)制最大化目標(biāo)模型所能獲得的獎(jiǎng)勵(lì)。直觀上,我們使用獎(jiǎng)勵(lì)模型向待優(yōu)化模型提供反饋,以促使它多生成高獎(jiǎng)勵(lì)輸出,少生成低獎(jiǎng)勵(lì)輸出。同時(shí),我們使用凍結(jié)的參考模型來確保輸出偏差不會(huì)太大,且繼續(xù)保持輸出的多樣性。這通常需要在目標(biāo)函數(shù)設(shè)計(jì)時(shí),除了獎(jiǎng)勵(lì)最大化目標(biāo)外再添加一個(gè)相對(duì)于參考模型的 KL 懲罰項(xiàng),這樣做有助于防止模型學(xué)習(xí)作弊或鉆營(yíng)獎(jiǎng)勵(lì)模型。

        DPO 繞過了建模獎(jiǎng)勵(lì)函數(shù)這一步,這源于一個(gè)關(guān)鍵洞見: 從獎(jiǎng)勵(lì)函數(shù)到最優(yōu) RL 策略的分析映射。這個(gè)映射直觀地度量了給定獎(jiǎng)勵(lì)函數(shù)與給定偏好數(shù)據(jù)的匹配程度。有了它,作者就可與將基于獎(jiǎng)勵(lì)和參考模型的 RL 損失直接轉(zhuǎn)換為僅基于參考模型的損失,從而直接在偏好數(shù)據(jù)上優(yōu)化語言模型!因此,DPO 從尋找最小化 RLHF 損失的最佳方案開始,通過改變參量的方式推導(dǎo)出一個(gè) 僅需 參考模型的損失!

        有了它,我們可以直接優(yōu)化該似然目標(biāo),而不需要獎(jiǎng)勵(lì)模型或繁瑣的強(qiáng)化學(xué)習(xí)優(yōu)化過程。

        如何使用 TRL 進(jìn)行訓(xùn)練

        如前所述,一個(gè)典型的 RLHF 流水線通常包含以下幾個(gè)環(huán)節(jié):

        1. 有監(jiān)督微調(diào) (supervised fine-tuning,SFT)
        2. 用偏好標(biāo)簽標(biāo)注數(shù)據(jù)
        3. 基于偏好數(shù)據(jù)訓(xùn)練獎(jiǎng)勵(lì)模型
        4. RL 優(yōu)化

        TRL 庫包含了所有這些環(huán)節(jié)所需的工具程序。而 DPO 訓(xùn)練直接消滅了獎(jiǎng)勵(lì)建模和 RL 這兩個(gè)環(huán)節(jié) (環(huán)節(jié) 3 和 4),直接根據(jù)標(biāo)注好的偏好數(shù)據(jù)優(yōu)化 DPO 目標(biāo)。

        使用 DPO,我們?nèi)匀恍枰獔?zhí)行環(huán)節(jié) 1,但我們僅需在 TRL 中向 DPOTrainer 提供環(huán)節(jié) 2 準(zhǔn)備好的偏好數(shù)據(jù),而不再需要環(huán)節(jié) 3 和 4。標(biāo)注好的偏好數(shù)據(jù)需要遵循特定的格式,它是一個(gè)含有以下 3 個(gè)鍵的字典:

        • prompt : 即推理時(shí)輸入給模型的提示
        • chosen : 即針對(duì)給定提示的較優(yōu)回答
        • rejected :  即針對(duì)給定提示的較劣回答或非給定提示的回答

        例如,對(duì)于 stack-exchange preference 數(shù)據(jù)集,我們可以通過以下工具函數(shù)將數(shù)據(jù)集中的樣本映射至上述字典格式并刪除所有原始列:

        def return_prompt_and_responses(samples) -> Dict[str, str, str]:
            return {
                "prompt": [
                    "Question: " + question + "\n\nAnswer: "
                    for question in samples["question"]
                ],
                "chosen": samples["response_j"], # rated better than k
                "rejected": samples["response_k"], # rated worse than j
            }

        dataset = load_dataset(
            "lvwerra/stack-exchange-paired",
            split="train",
            data_dir="data/rl"
        )
        original_columns = dataset.column_names

        dataset.map(
            return_prompt_and_responses,
            batched=True,
            remove_columns=original_columns
        )

        一旦有了排序數(shù)據(jù)集,DPO 損失其實(shí)本質(zhì)上就是一種有監(jiān)督損失,其經(jīng)由參考模型獲得隱式獎(jiǎng)勵(lì)。因此,從上層來看,DPOTrainer 需要我們輸入待優(yōu)化的基礎(chǔ)模型以及參考模型:

        dpo_trainer = DPOTrainer(
            model, # 經(jīng) SFT 的基礎(chǔ)模型
            model_ref, # 一般為經(jīng) SFT 的基礎(chǔ)模型的一個(gè)拷貝
            beta=0.1# DPO 的溫度超參
            train_dataset=dataset, # 上文準(zhǔn)備好的數(shù)據(jù)集
            tokenizer=tokenizer, # 分詞器
            args=training_args, # 訓(xùn)練參數(shù),如: batch size, 學(xué)習(xí)率等
        )

        其中,超參 beta 是 DPO 損失的溫度,通常在 0.10.5 之間。它控制了我們對(duì)參考模型的關(guān)注程度,beta 越小,我們就越忽略參考模型。對(duì)訓(xùn)練器初始化后,我們就可以簡(jiǎn)單調(diào)用以下方法,使用給定的 training_args 在給定數(shù)據(jù)集上進(jìn)行訓(xùn)練了:

        dpo_trainer.train()

        基于 Llama v2 進(jìn)行實(shí)驗(yàn)

        在 TRL 中實(shí)現(xiàn) DPO 訓(xùn)練器的好處是,人們可以利用 TRL 及其依賴庫 (如 Peft 和 Accelerate) 中已有的 LLM 相關(guān)功能。有了這些庫,我們甚至可以使用 bitsandbytes 庫提供的 QLoRA 技術(shù) 來訓(xùn)練 Llama v2 模型。

        有監(jiān)督微調(diào)

        如上文所述,我們先用 TRL 的 SFTTrainer 在 SFT 數(shù)據(jù)子集上使用 QLoRA 對(duì) 7B Llama v2 模型進(jìn)行有監(jiān)督微調(diào):

        # load the base model in 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        base_model = AutoModelForCausalLM.from_pretrained(
            script_args.model_name, # "meta-llama/Llama-2-7b-hf"
            quantization_config=bnb_config,
            device_map={""0},
            trust_remote_code=True,
            use_auth_token=True,
        )
        base_model.config.use_cache = False

        # add LoRA layers on top of the quantized base model
        peft_config = LoraConfig(
            r=script_args.lora_r,
            lora_alpha=script_args.lora_alpha,
            lora_dropout=script_args.lora_dropout,
            target_modules=["q_proj""v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        ...
        trainer = SFTTrainer(
            model=base_model,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            peft_config=peft_config,
            packing=True,
            max_seq_length=None,
            tokenizer=tokenizer,
            args=training_args, # HF Trainer arguments
        )
        trainer.train()

        DPO 訓(xùn)練

        SFT 結(jié)束后,我們保存好生成的模型。接著,我們繼續(xù)進(jìn)行 DPO 訓(xùn)練,我們把 SFT 生成的模型作為 DPO 的基礎(chǔ)模型和參考模型,并在上文生成的 stack-exchange preference 數(shù)據(jù)上,以 DPO 為目標(biāo)函數(shù)訓(xùn)練模型。我們選擇對(duì)模型進(jìn)行 LoRa 微調(diào),因此我們使用 Peft 的 AutoPeftModelForCausalLM 函數(shù)加載模型:

        model = AutoPeftModelForCausalLM.from_pretrained(
            script_args.model_name_or_path, # location of saved SFT model
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            load_in_4bit=True,
            is_trainable=True,
        )
        model_ref = AutoPeftModelForCausalLM.from_pretrained(
            script_args.model_name_or_path, # same model as the main one
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            load_in_4bit=True,
        )
        ...
        dpo_trainer = DPOTrainer(
            model,
            model_ref,
            args=training_args,
            beta=script_args.beta,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            peft_config=peft_config,
        )
        dpo_trainer.train()
        dpo_trainer.save_model()

        可以看出,我們以 4 比特的方式加載模型,然后通過 peft_config 參數(shù)選擇 QLora 方法對(duì)其進(jìn)行訓(xùn)練。訓(xùn)練器還會(huì)用評(píng)估數(shù)據(jù)集評(píng)估訓(xùn)練進(jìn)度,并報(bào)告一些關(guān)鍵指標(biāo),例如可以選擇通過 WandB 記錄并顯示隱式獎(jiǎng)勵(lì)。最后,我們可以將訓(xùn)練好的模型推送到 HuggingFace Hub。

        總結(jié)

        SFT 和 DPO 訓(xùn)練腳本的完整源代碼可在該目錄 examples/stack_llama_2 處找到,訓(xùn)好的已合并模型也已上傳至 HF Hub (見 此處)。

        你可以在 這兒 找到我們的模型在訓(xùn)練過程的 WandB 日志,其中包含了 DPOTrainer 在訓(xùn)練和評(píng)估期間記錄下來的以下獎(jiǎng)勵(lì)指標(biāo):

        • rewards/chosen (較優(yōu)回答的獎(jiǎng)勵(lì)) : 針對(duì)較優(yōu)回答,策略模型與參考模型的對(duì)數(shù)概率二者之差的均值,按 beta 縮放。
        • rewards/rejected (較劣回答的獎(jiǎng)勵(lì)) : 針對(duì)較劣回答,策略模型與參考模型的對(duì)數(shù)概率二者之差的均值,按 beta 縮放。
        • rewards/accuracy (獎(jiǎng)勵(lì)準(zhǔn)確率) : 較優(yōu)回答的獎(jiǎng)勵(lì)大于相應(yīng)較劣回答的獎(jiǎng)勵(lì)的頻率的均值
        • rewards/margins (獎(jiǎng)勵(lì)余裕值) : 較優(yōu)回答的獎(jiǎng)勵(lì)與相應(yīng)較劣回答的獎(jiǎng)勵(lì)二者之差的均值。

        直觀上講,在訓(xùn)練過程中,我們希望余裕值增加并且準(zhǔn)確率達(dá)到 1.0,換句話說,較優(yōu)回答的獎(jiǎng)勵(lì)高于較劣回答的獎(jiǎng)勵(lì) (或余裕值大于零)。隨后,我們還可以在評(píng)估數(shù)據(jù)集上計(jì)算這些指標(biāo)。

        我們希望我們代碼的發(fā)布可以降低讀者的入門門檻,讓大家可以在自己的數(shù)據(jù)集上嘗試這種大語言模型對(duì)齊方法,我們迫不及待地想看到你會(huì)用它做哪些事情!如果你想試試我們訓(xùn)練出來的模型,可以玩玩這個(gè) space: trl-lib/stack-llama。

        ?? 寶子們可以戳 閱讀原文 查看文中所有的外部鏈接喲!


        英文原文: https://hf.co/blog/dpo-trl

        原文作者: Kashif Rasul, Younes Belkada, Leandro von Werra

        譯者: Matrix Yao (姚偉峰),英特爾深度學(xué)習(xí)工程師,工作方向?yàn)?transformer-family 模型在各模態(tài)數(shù)據(jù)上的應(yīng)用及大規(guī)模模型的訓(xùn)練推理。

        審校/排版: zhongdongy (阿東)

        瀏覽 122
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
          
          

            1. 成人 视频免费观看网站 | 成人影院久久 | 色色热热 | 闺蜜下药揉捏我的大乳 | 欧美一级乱黄 |