
        (With code) Experience | A PyTorch training code template for deep learning

        5,877 words, ~12 minute read


        2021-09-20 21:51




        A public account focused on sharing knowledge about object detection and deep learning
        Author | wfnian @ Zhihu

        Link | https://zhuanlan.zhihu.com/p/396666255


        From parameter definition, to model definition, to the training, validation, and test steps, this post distills a fairly intuitive template. The outline:
        1. Import packages and set random seeds
        2. Define hyperparameters as a class
        3. Define your model
        4. Define an early-stopping class (optional)
        5. Define your Dataset and DataLoader
        6. Instantiate the model; set the loss, optimizer, etc.
        7. Training and learning-rate adjustment
        8. Plotting
        9. Prediction


        01


        Import packages and set random seeds
        import random

        import numpy as np
        import pandas as pd
        import torch
        import torch.nn as nn
        from torch.utils.data import DataLoader, Dataset
        from sklearn.model_selection import train_test_split
        import matplotlib.pyplot as plt

        seed = 42
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
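        The three seeds above do not cover CUDA kernels. If you train on GPU and want stricter reproducibility (at some speed cost), a minimal extension — an addition of mine, not part of the original template — might look like this:

        ```python
        import random

        import numpy as np
        import torch

        seed = 42
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # also seed all CUDA devices and make cuDNN deterministic (optional, slower)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        ```

        Note that full determinism can still depend on the ops you use; the flags above cover the common cases.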



        02


        Define hyperparameters as a class
        class argparse():  # note: this class name shadows the standard-library argparse module
            pass

        args = argparse()
        args.epochs, args.learning_rate, args.patience = [30, 0.001, 4]
        args.hidden_size, args.input_size = [40, 30]
        args.device, = [torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), ]
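        Because the container class above shadows Python's built-in argparse module, an equivalent collision-free sketch uses types.SimpleNamespace from the standard library (same fields as the template's args; the device field would be set the same way as above):

        ```python
        from types import SimpleNamespace

        # same attribute-access style as the template's args object
        args = SimpleNamespace(
            epochs=30,
            learning_rate=0.001,
            patience=4,
            hidden_size=40,
            input_size=30,
        )
        print(args.epochs, args.learning_rate)
        ```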


        03

        Define your model


        class Your_model(nn.Module):
            def __init__(self):
                super(Your_model, self).__init__()
                pass

            def forward(self, x):
                # fill in your forward pass here
                return x
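        As a concrete filling-in of the skeleton (my own illustration, not part of the original template), a two-layer MLP wired to the input_size and hidden_size values defined above:

        ```python
        import torch
        import torch.nn as nn

        class Your_model(nn.Module):
            def __init__(self, input_size=30, hidden_size=40):
                super().__init__()
                self.net = nn.Sequential(
                    nn.Linear(input_size, hidden_size),
                    nn.ReLU(),
                    nn.Linear(hidden_size, 1),   # single regression output
                )

            def forward(self, x):
                return self.net(x)

        model = Your_model()
        out = model(torch.randn(8, 30))   # a batch of 8 samples
        print(out.shape)
        ```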



        04


        Define an early-stopping class (optional)
        class EarlyStopping():
            def __init__(self, patience=7, verbose=False, delta=0):
                self.patience = patience
                self.verbose = verbose
                self.counter = 0
                self.best_score = None
                self.early_stop = False
                self.val_loss_min = np.inf
                self.delta = delta

            def __call__(self, val_loss, model, path):
                print("val_loss={}".format(val_loss))
                score = -val_loss
                if self.best_score is None:
                    self.best_score = score
                    self.save_checkpoint(val_loss, model, path)
                elif score < self.best_score + self.delta:
                    self.counter += 1
                    print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
                    if self.counter >= self.patience:
                        self.early_stop = True
                else:
                    self.best_score = score
                    self.save_checkpoint(val_loss, model, path)
                    self.counter = 0

            def save_checkpoint(self, val_loss, model, path):
                if self.verbose:
                    print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
                torch.save(model.state_dict(), path + '/' + 'model_checkpoint.pth')
                self.val_loss_min = val_loss
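        A small usage sketch of the stopping logic with a synthetic loss sequence (condensed version of the class above; the nn.Linear model and temp directory are stand-ins):

        ```python
        import tempfile

        import torch
        import torch.nn as nn

        class EarlyStopping:
            # same patience/counter logic as above, with prints and verbose omitted
            def __init__(self, patience=7, delta=0):
                self.patience, self.delta = patience, delta
                self.counter, self.best_score, self.early_stop = 0, None, False

            def __call__(self, val_loss, model, path):
                score = -val_loss
                if self.best_score is None or score >= self.best_score + self.delta:
                    self.best_score = score
                    torch.save(model.state_dict(), path + '/model_checkpoint.pth')
                    self.counter = 0
                else:
                    self.counter += 1
                    if self.counter >= self.patience:
                        self.early_stop = True

        model = nn.Linear(4, 1)
        stopper = EarlyStopping(patience=2)
        ckpt_dir = tempfile.mkdtemp()
        for loss in [1.0, 0.8, 0.9, 0.95]:   # improves twice, then stalls twice
            stopper(loss, model, ckpt_dir)
        print(stopper.early_stop)
        ```

        With patience=2, two consecutive non-improving epochs flip early_stop to True; the checkpoint on disk still holds the weights from the best epoch.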



        05


        Define your Dataset and DataLoader
        class Dataset_name(Dataset):
            def __init__(self, flag='train'):
                assert flag in ['train', 'test', 'valid']
                self.flag = flag
                self.__load_data__()

            def __getitem__(self, index):
                pass

            def __len__(self):
                pass

            def __load_data__(self, csv_paths: list = None):
                # load your data here, then report the shapes
                print("train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
                      .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))

        train_dataset = Dataset_name(flag='train')
        train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
        valid_dataset = Dataset_name(flag='valid')
        valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)
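        To make the skeleton concrete, here is a filled-in toy version (random tensors stand in for real CSV loading; the shapes are illustrative only):

        ```python
        import torch
        from torch.utils.data import DataLoader, Dataset

        class ToyDataset(Dataset):
            def __init__(self, flag='train'):
                assert flag in ['train', 'test', 'valid']
                n = 200 if flag == 'train' else 50
                self.X = torch.randn(n, 30)   # 30 features, matching args.input_size above
                self.Y = torch.randn(n, 1)

            def __getitem__(self, index):
                return self.X[index], self.Y[index]

            def __len__(self):
                return len(self.X)

        loader = DataLoader(ToyDataset('train'), batch_size=64, shuffle=True)
        xb, yb = next(iter(loader))
        print(xb.shape, yb.shape)
        ```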



        06


        Instantiate the model; set the loss, optimizer, etc.
        model = Your_model().to(args.device)
        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

        train_loss = []
        valid_loss = []
        train_epochs_loss = []
        valid_epochs_loss = []

        early_stopping = EarlyStopping(patience=args.patience, verbose=True)



        07


        Training and learning-rate adjustment
        for epoch in range(args.epochs):
            model.train()
            train_epoch_loss = []
            for idx, (data_x, data_y) in enumerate(train_dataloader, 0):
                data_x = data_x.to(torch.float32).to(args.device)
                data_y = data_y.to(torch.float32).to(args.device)
                outputs = model(data_x)
                optimizer.zero_grad()
                loss = criterion(outputs, data_y)
                loss.backward()
                optimizer.step()
                train_epoch_loss.append(loss.item())
                train_loss.append(loss.item())
                if idx % (len(train_dataloader) // 2) == 0:
                    print("epoch={}/{}, {}/{} of train, loss={}".format(
                        epoch, args.epochs, idx, len(train_dataloader), loss.item()))
            train_epochs_loss.append(np.average(train_epoch_loss))

            # =====================valid============================
            model.eval()
            valid_epoch_loss = []
            with torch.no_grad():  # no gradients needed for validation
                for idx, (data_x, data_y) in enumerate(valid_dataloader, 0):
                    data_x = data_x.to(torch.float32).to(args.device)
                    data_y = data_y.to(torch.float32).to(args.device)
                    outputs = model(data_x)
                    loss = criterion(outputs, data_y)
                    valid_epoch_loss.append(loss.item())
                    valid_loss.append(loss.item())
            valid_epochs_loss.append(np.average(valid_epoch_loss))
            # ==================early stopping======================
            early_stopping(valid_epochs_loss[-1], model=model, path=r'c:\your_model_to_save')
            if early_stopping.early_stop:
                print("Early stopping")
                break
            # ====================adjust lr=========================
            lr_adjust = {
                2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
                10: 5e-7, 15: 1e-7, 20: 5e-8
            }
            if epoch in lr_adjust.keys():
                lr = lr_adjust[epoch]
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr
                print('Updating learning rate to {}'.format(lr))
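        The manual lr_adjust dictionary can also be written with PyTorch's built-in schedulers. A sketch using torch.optim.lr_scheduler.StepLR (the step size and gamma here are arbitrary stand-ins, not the dictionary's exact schedule):

        ```python
        import torch
        import torch.nn as nn

        model = nn.Linear(30, 1)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # halve the learning rate every 5 epochs
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

        for epoch in range(10):
            # ... run one epoch of training here ...
            scheduler.step()

        print(optimizer.param_groups[0]['lr'])
        ```

        For irregular schedules like the dictionary above, torch.optim.lr_scheduler.MultiStepLR or LambdaLR are closer fits.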



        08


        Plotting
        plt.figure(figsize=(12, 4))
        plt.subplot(121)
        plt.plot(train_loss[:])
        plt.title("train_loss")
        plt.subplot(122)
        plt.plot(train_epochs_loss[1:], '-o', label="train_loss")
        plt.plot(valid_epochs_loss[1:], '-o', label="valid_loss")
        plt.title("epochs_loss")
        plt.legend()
        plt.show()



        09


        Prediction


        # Define a DataLoader for the prediction set here, or simply
        # reshape your prediction data and use batch_size=1.
        model.eval()
        predict = model(data)
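        At prediction time it is also standard to disable gradient tracking. A hedged sketch of the reshape-to-batch_size=1 route (the model and input shape are placeholders):

        ```python
        import torch
        import torch.nn as nn

        model = nn.Linear(30, 1)   # stand-in for your trained model
        model.eval()

        data = torch.randn(30).reshape(1, -1)   # one sample, reshaped to batch_size=1
        with torch.no_grad():                   # no autograd bookkeeping at inference
            predict = model(data)
        print(predict.shape)
        ```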


        END



        Created by a graduate team from a double first-class university, focused on object detection and deep learning. We hope to make sharing a habit!
