質問編集履歴
2
エラー文の追記
test
CHANGED
File without changes
|
test
CHANGED
@@ -8,7 +8,15 @@
|
|
8
8
|
|
9
9
|
(1)は通っているのですが、(2)が通りません。原因を教えていただきたいです。
|
10
10
|
|
11
|
+
以下のようなエラーが出ます。
|
11
12
|
|
13
|
+
```text
|
14
|
+
|
15
|
+
ModuleNotFoundError: No module named 'nn'
|
16
|
+
|
17
|
+
|
18
|
+
|
19
|
+
```
|
12
20
|
|
13
21
|
全体のコードも示しておきます。
|
14
22
|
|
1
コード全体の追記
test
CHANGED
File without changes
|
test
CHANGED
@@ -7,3 +7,129 @@
|
|
7
7
|
```
|
8
8
|
|
9
9
|
(1)は通っているのですが、(2)が通りません。原因を教えていただきたいです。
|
10
|
+
|
11
|
+
|
12
|
+
|
13
|
+
全体のコードも示しておきます。
|
14
|
+
|
15
|
+
```python
|
16
|
+
|
17
|
+
# Training script for the dialogue model (PyTorch).
#
# Reads all settings from ``Config``, prepares the tokenizer and training
# data, builds the model / loss / optimizer, optionally resumes from a saved
# checkpoint, then trains for epochs ``start_epoch .. Config.n_epoch - 1``.
#
# NOTE(review): the pasted original lost all indentation; the nesting below
# (``make_itf`` running after both data branches, ``evaluate`` running once
# per epoch inside the loop) is reconstructed — confirm against the real file.
import logging
import os
import pickle

import torch
import torch.nn as nn
import torch.optim as optim

from config import Config
# NOTE(review): this import is what raises the reported
# "ModuleNotFoundError: No module named 'nn'". It refers to a project-local
# top-level module ``nn`` (it must provide ``build_model``, which ``torch.nn``
# does not), so the directory containing that module has to be on
# ``sys.path``. Presumably the script is being launched from another working
# directory or from a notebook whose cwd is not the project root; run it from
# the project root, or add that directory to PYTHONPATH. The local name also
# clashes visually with ``torch.nn`` imported as ``nn`` above — consider
# renaming the local package.
from nn import build_model
from tokenizer import Tokenizer
from utils import (DialogDataset, one_cycle, evaluate,
                   seed_everything, BalancedDataLoader,
                   make_train_data_from_txt, make_itf)

logging.basicConfig(level=logging.INFO)


if __name__ == '__main__':
    logging.info('*** Initializing ***')

    # makedirs(exist_ok=True) replaces the isdir()/mkdir() pair: there is no
    # check-then-create race, and missing parent directories are created too.
    os.makedirs(Config.data_dir, exist_ok=True)

    seed_everything(Config.seed)
    device = torch.device(Config.device)

    start_epoch = 0
    tokenizer = Tokenizer.from_pretrained(Config.model_name)

    logging.info('Preparing training data')
    if Config.use_pickle:
        # SECURITY: pickle.load can execute arbitrary code; only load a
        # pickle file this project produced itself.
        with open(Config.pickle_path, 'rb') as f:
            train_data = pickle.load(f)
    else:
        train_data = make_train_data_from_txt(Config, tokenizer)
    # ``itf`` is not referenced again in this script; the call is kept as in
    # the original (make_itf may have side effects not visible here).
    itf = make_itf(train_data, Config.vocab_size)
    dataset = DialogDataset(train_data, tokenizer)

    logging.info('Define Models')
    model = build_model(Config).to(device)
    model.unfreeze()

    logging.info('Define Loss and Optimizer')
    # reduction='none' yields per-element losses; presumably one_cycle masks
    # and reduces them itself — not visible here.
    criterion = nn.CrossEntropyLoss(reduction='none')
    optimizer = optim.AdamW(model.parameters(), lr=Config.lr,
                            betas=Config.betas, eps=1e-9)

    if Config.load:
        # map_location moves the saved tensors onto the current device, so a
        # checkpoint written on a GPU can be resumed on e.g. a CPU-only host.
        state_dict = torch.load(f'{Config.data_dir}/{Config.fn}.pth',
                                map_location=device)
        # TODO(review): hard-coded resume epoch; presumably it should match
        # the epoch at which the checkpoint was saved.
        start_epoch = 10
        logging.info('Start Epoch: %d', start_epoch)
        model.load_state_dict(state_dict['model'])
        optimizer.load_state_dict(state_dict['opt'])

    logging.info('Start Training')
    for epoch in range(start_epoch, Config.n_epoch):
        # A fresh loader is built every epoch, as in the original.
        one_cycle(epoch, Config, model, optimizer, criterion,
                  BalancedDataLoader(dataset, tokenizer.pad_token_id),
                  tokenizer, device)
        # Presumably generates a reply to a fixed prompt so progress can be
        # inspected after each epoch.
        evaluate(Config, 'おはよーーー', tokenizer, model, device)
|
134
|
+
|
135
|
+
```
|