1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        像 Keras 一樣優(yōu)雅地使用 pytorch-lightning

        共 7417字,需瀏覽 15分鐘

         ·

        2021-01-27 19:40

        你好,我是云哥。本篇文章為大家介紹一個(gè)可以幫助大家優(yōu)雅地進(jìn)行深度學(xué)習(xí)研究的工具:pytorch-lightning。

        公眾號(hào)后臺(tái)回復(fù)關(guān)鍵字:源碼,獲取本文源代碼!

        pytorch-lightning 是建立在pytorch之上的高層次模型接口,pytorch-lightning之于pytorch,就如同keras之于tensorflow。

        關(guān)于pytorch-lightning的完整入門介紹,可以參考我的另外一篇文章。

        使用pytorch-lightning漂亮地進(jìn)行深度學(xué)習(xí)研究

        我用了約80行代碼對(duì) pytorch-lightning 做了進(jìn)一步封裝,使得對(duì)它不熟悉的用戶可以用類似Keras的風(fēng)格使用它,輕而易舉地實(shí)現(xiàn)如下功能:

        • 模型訓(xùn)練(cpu,gpu,多GPU)

        • 模型評(píng)估 (自定義評(píng)估指標(biāo))

        • 最優(yōu)模型參數(shù)保存(ModelCheckPoint)

        • 自定義學(xué)習(xí)率 (lr_schedule)

        • 畫出優(yōu)美的Loss和Metric曲線

        它甚至?xí)菿eras還要更加簡單和好用一些。

        這個(gè)封裝的類 LightModel 添加到了我的開源倉庫 torchkeras 中,用戶可以用pip進(jìn)行安裝。

        pip?install?-U?torchkeras

        以下是一個(gè)通過LightModel使用DNN模型進(jìn)行二分類的完整范例。

        在本例的最后,云哥將向大家表演一個(gè)"金蟬脫殼"的絕技。不要離開。????

        import?numpy?as?np?
        import?pandas?as?pd?
        from?matplotlib?import?pyplot?as?plt
        import?torch
        from?torch?import?nn
        import?torch.nn.functional?as?F
        from?torch.utils.data?import?Dataset,DataLoader,TensorDataset
        import?datetime

        #attention?these?two?lines
        import?pytorch_lightning?as?pl?
        import?torchkeras?

        一,準(zhǔn)備數(shù)據(jù)

        %matplotlib?inline
        %config?InlineBackend.figure_format?=?'svg'

        #number?of?samples
        n_positive,n_negative?=?2000,2000

        #positive?samples
        r_p?=?5.0?+?torch.normal(0.0,1.0,size?=?[n_positive,1])?
        theta_p?=?2*np.pi*torch.rand([n_positive,1])
        Xp?=?torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis?=?1)
        Yp?=?torch.ones_like(r_p)

        #negative?samples
        r_n?=?8.0?+?torch.normal(0.0,1.0,size?=?[n_negative,1])?
        theta_n?=?2*np.pi*torch.rand([n_negative,1])
        Xn?=?torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis?=?1)
        Yn?=?torch.zeros_like(r_n)

        #concat?positive?and?negative?samples
        X?=?torch.cat([Xp,Xn],axis?=?0)
        Y?=?torch.cat([Yp,Yn],axis?=?0)


        #visual?samples
        plt.figure(figsize?=?(6,6))
        plt.scatter(Xp[:,0],Xp[:,1],c?=?"r")
        plt.scatter(Xn[:,0],Xn[:,1],c?=?"g")
        plt.legend(["positive","negative"]);


        #?split?samples?into?train?and?valid?data.
        ds?=?TensorDataset(X,Y)
        ds_train,ds_valid?=?torch.utils.data.random_split(ds,[int(len(ds)*0.7),len(ds)-int(len(ds)*0.7)])
        dl_train?=?DataLoader(ds_train,batch_size?=?100,shuffle=True,num_workers=4)
        dl_valid?=?DataLoader(ds_valid,batch_size?=?100,num_workers=4)

        二,定義模型

        #define?the?network?like?torch
        class?Net(nn.Module):??
        ????def?__init__(self):
        ????????super().__init__()
        ????????self.fc1?=?nn.Linear(2,6)
        ????????self.fc2?=?nn.Linear(6,12)?
        ????????self.fc3?=?nn.Linear(12,1)
        ????????
        ????def?forward(self,x):
        ????????x?=?F.relu(self.fc1(x))
        ????????x?=?F.relu(self.fc2(x))
        ????????y?=?nn.Sigmoid()(self.fc3(x))
        ????????return?y???????
        class?Model(torchkeras.LightModel):
        ????def?shared_step(self,batch):
        ????????x,?y?=?batch
        ????????prediction?=?self(x)
        ????????loss?=?nn.BCELoss()(prediction,y)
        ????????preds?=?torch.where(prediction>0.5,torch.ones_like(prediction),torch.zeros_like(prediction))
        ????????acc?=?pl.metrics.functional.accuracy(preds,?y)
        ????????#?attention:?there?must?be?a?key?of?"loss"?in?the?returned?dict?
        ????????dic?=?{"loss":loss,"acc":acc}?
        ????????return?dic
        ????
        ????#optimizer,and?optional?lr_scheduler
        ????def?configure_optimizers(self):
        ????????optimizer?=?torch.optim.Adam(self.parameters(),?lr=1e-2)
        ????????lr_scheduler?=?torch.optim.lr_scheduler.StepLR(optimizer,?step_size=10,?gamma=0.0001)
        ????????return?{"optimizer":optimizer,"lr_scheduler":lr_scheduler}
        ?

        注意,下面我們把網(wǎng)絡(luò)結(jié)構(gòu)net包裝在一個(gè)model的殼之中。????

        pl.seed_everything(123)

        #?we?wrap?the?network?into?a?Model?
        net?=?Net()
        model?=?Model(net)

        torchkeras.summary(model,input_shape?=(2,))

        ----------------------------------------------------------------
        ????????Layer?(type)???????????????Output?Shape?????????Param?#
        ================================================================
        ????????????Linear-1????????????????????[-1,?4]??????????????12
        ????????????Linear-2????????????????????[-1,?8]??????????????40
        ????????????Linear-3????????????????????[-1,?1]???????????????9
        ================================================================
        Total?params:?61
        Trainable?params:?61
        Non-trainable?params:?0
        ----------------------------------------------------------------
        Input?size?(MB):?0.000008
        Forward/backward?pass?size?(MB):?0.000099
        Params?size?(MB):?0.000233
        Estimated?Total?Size?(MB):?0.000340
        ----------------------------------------------------------------

        三,訓(xùn)練模型


        ckpt_callback?=?pl.callbacks.ModelCheckpoint(
        ????monitor='val_loss',
        ????save_top_k=1,
        ????mode='min'
        )

        #?gpus=0?則使用cpu訓(xùn)練,gpus=1則使用1個(gè)gpu訓(xùn)練,gpus=2則使用2個(gè)gpu訓(xùn)練,gpus=-1則使用所有g(shù)pu訓(xùn)練,
        #?gpus=[0,1]則指定使用0號(hào)和1號(hào)gpu訓(xùn)練,?gpus="0,1,2,3"則使用0,1,2,3號(hào)gpu訓(xùn)練
        #?tpus=1?則使用1個(gè)tpu訓(xùn)練

        trainer?=?pl.Trainer(max_epochs=10,gpus=0,callbacks?=?[ckpt_callback])?

        #斷點(diǎn)續(xù)訓(xùn)
        #trainer?=?pl.Trainer(resume_from_checkpoint='./lightning_logs/version_31/checkpoints/epoch=02-val_loss=0.05.ckpt')

        trainer.fit(model,dl_train,dl_valid)
        GPU?available:?False,?used:?False
        TPU?available:?None,?using:?0?TPU?cores

        ??|?Name?|?Type?|?Params
        ------------------------------
        0?|?net??|?Net??|?115???
        ------------------------------
        115???????Trainable?params
        0?????????Non-trainable?params
        115???????Total?params

        ================================================================================2021-01-24?20:47:39
        epoch?=??0
        {'val_loss':?0.6492899060249329,?'val_acc':?0.6033333539962769}
        {'acc':?0.5374999642372131,?'loss':?0.6766871809959412}

        ================================================================================2021-01-24?20:47:40
        epoch?=??1
        {'val_loss':?0.5390750765800476,?'val_acc':?0.763333261013031}
        {'acc':?0.676428496837616,?'loss':?0.5993633270263672}

        ================================================================================2021-01-24?20:47:41
        epoch?=??2
        {'val_loss':?0.3617284595966339,?'val_acc':?0.8608333468437195}
        {'acc':?0.8050000071525574,?'loss':?0.4533742070198059}

        ================================================================================2021-01-24?20:47:42
        epoch?=??3
        {'val_loss':?0.21798092126846313,?'val_acc':?0.9158334732055664}
        {'acc':?0.8910714387893677,?'loss':?0.28334707021713257}

        ================================================================================2021-01-24?20:47:43
        epoch?=??4
        {'val_loss':?0.18157465755939484,?'val_acc':?0.9208333492279053}
        {'acc':?0.926428496837616,?'loss':?0.20261192321777344}

        ================================================================================2021-01-24?20:47:44
        epoch?=??5
        {'val_loss':?0.17406059801578522,?'val_acc':?0.9300000071525574}
        {'acc':?0.9203571677207947,?'loss':?0.1980973333120346}

        ================================================================================2021-01-24?20:47:45
        epoch?=??6
        {'val_loss':?0.16323940455913544,?'val_acc':?0.935833215713501}
        {'acc':?0.9242857694625854,?'loss':?0.1862144023180008}

        ================================================================================2021-01-24?20:47:46
        epoch?=??7
        {'val_loss':?0.16635416448116302,?'val_acc':?0.9300000071525574}
        {'acc':?0.925000011920929,?'loss':?0.18595384061336517}

        ================================================================================2021-01-24?20:47:47
        epoch?=??8
        {'val_loss':?0.1665605753660202,?'val_acc':?0.9258332848548889}
        {'acc':?0.9267856478691101,?'loss':?0.18308643996715546}

        ================================================================================2021-01-24?20:47:48
        epoch?=??9
        {'val_loss':?0.1757962554693222,?'val_acc':?0.9300000071525574}
        {'acc':?0.9246429204940796,?'loss':?0.18282662332057953}
        #?visual?the?results
        fig,?(ax1,ax2)?=?plt.subplots(nrows=1,ncols=2,figsize?=?(12,5))
        ax1.scatter(Xp[:,0],Xp[:,1],?c="r")
        ax1.scatter(Xn[:,0],Xn[:,1],c?=?"g")
        ax1.legend(["positive","negative"]);
        ax1.set_title("y_true")

        Xp_pred?=?X[torch.squeeze(model.forward(X)>=0.5)]
        Xn_pred?=?X[torch.squeeze(model.forward(X)<0.5)]

        ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c?=?"r")
        ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c?=?"g")
        ax2.legend(["positive","negative"]);
        ax2.set_title("y_pred")

        四,評(píng)估模型

        import?pandas?as?pd?

        history?=?model.history
        dfhistory?=?pd.DataFrame(history)?
        dfhistory?
        %matplotlib?inline
        %config?InlineBackend.figure_format?=?'svg'

        import?matplotlib.pyplot?as?plt

        def?plot_metric(dfhistory,?metric):
        ????train_metrics?=?dfhistory[metric]
        ????val_metrics?=?dfhistory['val_'+metric]
        ????epochs?=?range(1,?len(train_metrics)?+?1)
        ????plt.plot(epochs,?train_metrics,?'bo--')
        ????plt.plot(epochs,?val_metrics,?'ro-')
        ????plt.title('Training?and?validation?'+?metric)
        ????plt.xlabel("Epochs")
        ????plt.ylabel(metric)
        ????plt.legend(["train_"+metric,?'val_'+metric])
        ????plt.show()
        ????
        plot_metric(dfhistory,"loss")

        plot_metric(dfhistory,"acc")


        results?=?trainer.test(model,?test_dataloaders=dl_valid,?verbose?=?False)
        print(results[0])

        {'test_loss':?0.15939873456954956,?'test_acc':?0.9599999785423279}

        五,使用模型

        def?predict(model,dl):
        ????model.eval()
        ????result?=?torch.cat([model.forward(t[0].to(model.device))?for?t?in?dl])
        ????return(result.data)

        result?=?predict(model,dl_valid)

        result?
        tensor([[9.8850e-01],
        ????????[2.3642e-03],
        ????????[1.2128e-04],
        ????????...,
        ????????[9.9002e-01],
        ????????[9.6689e-01],
        ????????[1.5238e-02]])

        六,保存模型

        最優(yōu)模型默認(rèn)保存在 trainer.checkpoint_callback.best_model_path 的目錄下,可以直接加載。

        print(trainer.checkpoint_callback.best_model_path)
        print(trainer.checkpoint_callback.best_model_score)
        model_clone?=?Model.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
        trainer_clone?=?pl.Trainer(max_epochs=10)?
        results?=?trainer_clone.test(model_clone,?test_dataloaders=dl_valid,?verbose?=?False)
        print(results[0])

        {'test_loss':?0.20505842566490173,?'test_acc':?0.9399999976158142}

        最后,給大家表演一個(gè)金蟬脫殼的絕技。????

        使用LightModel之殼訓(xùn)練后,可丟棄該軀殼,直接手動(dòng)保存最優(yōu)的網(wǎng)絡(luò)結(jié)構(gòu)net的權(quán)重。

        best_net?=?model.net?
        torch.save(best_net.state_dict(),"best_net.pt")

        #加載權(quán)重
        net_clone?=?Net()
        net_clone.load_state_dict(torch.load("best_net.pt"))


        data,label?=?next(iter(dl_valid))
        with?torch.no_grad():
        ????preds??=?model(data)
        ????preds_clone?=?net_clone(data)
        ????
        print("model?prediction:\n",preds[0:10],"\n")
        print("net_clone?prediction:\n",preds_clone[0:10])

        model?prediction:
        ?tensor([[9.8850e-01],
        ????????[2.3642e-03],
        ????????[1.2128e-04],
        ????????[1.0022e-04],
        ????????[9.3577e-01],
        ????????[4.9769e-02],
        ????????[9.8537e-01],
        ????????[9.9940e-01],
        ????????[4.1117e-04],
        ????????[9.4009e-01]])?

        net_clone?prediction:
        ?tensor([[9.8850e-01],
        ????????[2.3642e-03],
        ????????[1.2128e-04],
        ????????[1.0022e-04],
        ????????[9.3577e-01],
        ????????[4.9769e-02],
        ????????[9.8537e-01],
        ????????[9.9940e-01],
        ????????[4.1117e-04],
        ????????[9.4009e-01]])


        以上。


        如果對(duì)本文內(nèi)容理解上有需要進(jìn)一步和作者交流的地方,歡迎在公眾號(hào)"算法美食屋"下留言。作者時(shí)間和精力有限,會(huì)酌情予以回復(fù)。


        也可以在公眾號(hào)后臺(tái)回復(fù)關(guān)鍵字:加群,加入讀者交流群和大家討論。


        原創(chuàng)不易,不想被白嫖。歡迎大家三連支持云哥????:點(diǎn)贊,在看,分享。感謝。


        瀏覽 48
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            国产在线aaa | 把公主带到乳刑室用乳形调教视频 | 国产又粗又猛又大又爽的视频 | 波多野结衣光棍影院 | 体育生全黄h全肉短篇禁乱 | 亚洲 欧美 日韩 偷 妻 乱 | 97精品一区二区三区 | 大骚逼人人干 | 操小逼逼 | 两男一女两根同进去舒服吗动漫 |