質問編集履歴
2
エラー文の追記
title
CHANGED
File without changes
|
body
CHANGED
@@ -3,7 +3,11 @@
|
|
3
3
|
from nn import build_model (2)
|
4
4
|
```
|
5
5
|
(1)は通っているのですが、(2)が通りません.原因を教えていただきたいです
|
6
|
+
以下のようなエラーが出ます
|
7
|
+
```ここに言語を入力
|
8
|
+
ModuleNotFoundError: No module named 'nn'
|
6
9
|
|
10
|
+
```
|
7
11
|
全体のコードも示しておきます
|
8
12
|
````python
|
9
13
|
import logging
|
1
コード全体の追記
title
CHANGED
File without changes
|
body
CHANGED
@@ -2,4 +2,67 @@
|
|
2
2
|
import torch.nn as nn (1)
|
3
3
|
from nn import build_model (2)
|
4
4
|
```
|
5
|
-
(1)は通っているのですが、(2)が通りません.原因を教えていただきたいです
|
5
|
+
(1)は通っているのですが、(2)が通りません.原因を教えていただきたいです
|
6
|
+
|
7
|
+
全体のコードも示しておきます
|
8
|
+
````python
|
9
|
+
import logging
|
10
|
+
import os
|
11
|
+
import pickle
|
12
|
+
|
13
|
+
import torch
|
14
|
+
import torch.nn as nn
|
15
|
+
import torch.optim as optim
|
16
|
+
|
17
|
+
from config import Config
|
18
|
+
from nn import build_model
|
19
|
+
from tokenizer import Tokenizer
|
20
|
+
from utils import (DialogDataset, one_cycle, evaluate,
|
21
|
+
seed_everything, BalancedDataLoader,
|
22
|
+
make_train_data_from_txt, make_itf)
|
23
|
+
|
24
|
+
logging.basicConfig(level=logging.INFO)
|
25
|
+
|
26
|
+
if __name__ == '__main__':
|
27
|
+
logging.info('*** Initializing ***')
|
28
|
+
|
29
|
+
if not os.path.isdir(Config.data_dir):
|
30
|
+
os.mkdir(Config.data_dir)
|
31
|
+
|
32
|
+
seed_everything(Config.seed)
|
33
|
+
device = torch.device(Config.device)
|
34
|
+
|
35
|
+
start_epoch = 0
|
36
|
+
tokenizer = Tokenizer.from_pretrained(Config.model_name)
|
37
|
+
|
38
|
+
logging.info('Preparing training data')
|
39
|
+
if Config.use_pickle:
|
40
|
+
with open(f'{Config.pickle_path}', 'rb') as f:
|
41
|
+
train_data = pickle.load(f)
|
42
|
+
else:
|
43
|
+
train_data = make_train_data_from_txt(Config, tokenizer)
|
44
|
+
itf = make_itf(train_data, Config.vocab_size)
|
45
|
+
dataset = DialogDataset(train_data, tokenizer)
|
46
|
+
|
47
|
+
logging.info('Define Models')
|
48
|
+
model = build_model(Config).to(device)
|
49
|
+
model.unfreeze()
|
50
|
+
|
51
|
+
logging.info('Define Loss and Optimizer')
|
52
|
+
criterion = nn.CrossEntropyLoss(reduction='none')
|
53
|
+
optimizer = optim.AdamW(model.parameters(), lr=Config.lr, betas=Config.betas, eps=1e-9)
|
54
|
+
|
55
|
+
if Config.load:
|
56
|
+
state_dict = torch.load(f'{Config.data_dir}/{Config.fn}.pth')
|
57
|
+
start_epoch = 10
|
58
|
+
print(f'Start Epoch: {start_epoch}')
|
59
|
+
model.load_state_dict(state_dict['model'])
|
60
|
+
optimizer.load_state_dict(state_dict['opt'])
|
61
|
+
|
62
|
+
logging.info('Start Training')
|
63
|
+
for epoch in range(start_epoch, Config.n_epoch):
|
64
|
+
one_cycle(epoch, Config, model, optimizer, criterion,
|
65
|
+
BalancedDataLoader(dataset, tokenizer.pad_token_id),
|
66
|
+
tokenizer, device)
|
67
|
+
evaluate(Config, 'おはよーーー', tokenizer, model, device)
|
68
|
+
```
|