Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

p225 のbatch_size宣言について #13

Open
roadto93ds opened this issue Oct 10, 2021 · 1 comment
Open

p225 のbatch_size宣言について #13

roadto93ds opened this issue Oct 10, 2021 · 1 comment

Comments

@roadto93ds
Copy link

p225 のdataloaderを導入した訓練についてですが、

for文の外で batch_sizeを決めて動かしたのですが、

for文の中で
imgs.shape[0] という形で改めて定義しないと動かないのはなぜでしょうか?

# epoch回数
n_epochs = 100
# batch_size
batch_size = 64

# data_loader
train_data_loader = DataLoader(cifar2, batch_size=batch_size, shuffle=True)

# 損失関数
loss_fn = nn.NLLLoss()
# 最適化戦略
optimizer = optim.SGD(model.parameters(), lr=1e-2)


for epoch in range(n_epochs):
  for imgs, labels in train_data_loader:
    # model出力
    outputs = model(imgs.view(batch_size,-1)) # batch_size行,3*32*32列にする
    # 損失関数での評価
    train_loss = loss_fn(outputs, labels)

    optimizer.zero_grad()

    train_loss.backward() 

    optimizer.step()

  print("Epoch:{}, Loss:{}".format(epoch, float(train_loss)))
@Gin5050
Copy link
Owner

Gin5050 commented Oct 22, 2021

roadto93ds 様

ご質問ありがとうございます。
また、ご連絡が遅くなり申し訳ありません。

ご質問の件ですが、データ数がバッチサイズで割り切れないのが理由だと思われます。
(※ 詳しいエラー内容を見てないので推測です)
該当部分ではデータ数10000をサイズ64のバッチにするので、バッチサイズ64のtensorが156個できます。
この時64*156=9984となり、内側のfor文のラストのimgsはimgs.shape[0]=16でバッチサイズ64にできません。

そのため、ここではfor文の外ではなく内側で batch_sizeを再計算してバッチサイズを64にできない場合でも計算できるようにしています。

余談ですが、DataLoaderのオプションでdrop_last=Trueにすると割り切れない部分は落としてくれるので、内側でbatch_sizeを計算しなくても動きます。

参考:https://pytorch.org/docs/stable/data.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants