質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.50%
深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

強化学習

強化学習とは、ある環境下のエージェントが現状を推測し行動を決定することで報酬を獲得するという見解から、その報酬を最大限に得る方策を学ぶ機械学習のことを指します。問題解決時に得る報酬が選択結果によって変化することで、より良い行動を選択しようと学習する点が特徴です。

AWS Glue

AWS Glueは、分析のためのデータの抽出や変換、ロードを簡単にするフルマネージド型のサービスです。データ処理の自動化の他、データ収集やETL処理も自動化・サーバレス化することが可能。AWSに保存したデータを指定すると、AWS Glueでデータ検索することもできます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

1回答

1287閲覧

過学習の対処法を理解したい

jeesus

総合スコア1

深層学習

深層学習は、多数のレイヤのニューラルネットワークによる機械学習手法。人工知能研究の一つでディープラーニングとも呼ばれています。コンピューター自体がデータの潜在的な特徴を汲み取り、効率的で的確な判断を実現することができます。

強化学習

強化学習とは、ある環境下のエージェントが現状を推測し行動を決定することで報酬を獲得するという見解から、その報酬を最大限に得る方策を学ぶ機械学習のことを指します。問題解決時に得る報酬が選択結果によって変化することで、より良い行動を選択しようと学習する点が特徴です。

AWS Glue

AWS Glueは、分析のためのデータの抽出や変換、ロードを簡単にするフルマネージド型のサービスです。データ処理の自動化の他、データ収集やETL処理も自動化・サーバレス化することが可能。AWSに保存したデータを指定すると、AWS Glueでデータ検索することもできます。

機械学習

機械学習は、データからパターンを自動的に発見し、そこから知能的な判断を下すためのコンピューターアルゴリズムを指します。人工知能における課題のひとつです。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2021/11/04 04:47

編集2022/01/12 10:55

前提・実現したいこと

下記を実行し、まとめたところVal_Lossの値が上昇傾向に見られました。
このことから、過学習と判断したため、どのようにすれば下降に向かうか教えていただきたいです。

該当のソースコード

python

1#! pip install datasets transformers 2from huggingface_hub import notebook_login 3 4notebook_login() 5# !apt install git-lfs 6 7import transformers 8 9print(transformers.__version__) 10 11GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"] 12 13task = "cola" 14model_checkpoint = "distilbert-base-uncased" 15batch_size = 16 16 17from datasets import load_dataset, load_metric 18 19actual_task = "mnli" if task == "mnli-mm" else task 20dataset = load_dataset("glue", actual_task) 21metric = load_metric('glue', actual_task) 22 23dataset 24 25dataset["train"][0] 26 27import datasets 28import random 29import pandas as pd 30from IPython.display import display, HTML 31 32def show_random_elements(dataset, num_examples=10): 33 assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset." 34 picks = [] 35 for _ in range(num_examples): 36 pick = random.randint(0, len(dataset)-1) 37 while pick in picks: 38 pick = random.randint(0, len(dataset)-1) 39 picks.append(pick) 40 41 df = pd.DataFrame(dataset[picks]) 42 for column, typ in dataset.features.items(): 43 if isinstance(typ, datasets.ClassLabel): 44 df[column] = df[column].transform(lambda i: typ.names[i]) 45 display(HTML(df.to_html())) 46 47show_random_elements(dataset["train"]) 48 49metric 50 51import numpy as np 52 53fake_preds = np.random.randint(0, 2, size=(64,)) 54fake_labels = np.random.randint(0, 2, size=(64,)) 55metric.compute(predictions=fake_preds, references=fake_labels) 56 57from transformers import AutoTokenizer 58 59tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) 60 61tokenizer("Hello, this one sentence!", "And this sentence goes with it.") 62 63task_to_keys = { 64 "cola": ("sentence", None), 65 "mnli": ("premise", "hypothesis"), 66 "mnli-mm": ("premise", "hypothesis"), 67 "mrpc": ("sentence1", "sentence2"), 68 "qnli": ("question", "sentence"), 69 "qqp": ("question1", "question2"), 70 "rte": ("sentence1", "sentence2"), 71 "sst2": ("sentence", None), 72 "stsb": ("sentence1", "sentence2"), 73 "wnli": ("sentence1", "sentence2"), 74} 75 76sentence1_key, sentence2_key = task_to_keys[task] 77if sentence2_key is None: 78 print(f"Sentence: {dataset['train'][0][sentence1_key]}") 79else: 80 print(f"Sentence 1: {dataset['train'][0][sentence1_key]}") 81 print(f"Sentence 2: {dataset['train'][0][sentence2_key]}") 82 83def preprocess_function(examples): 84 if sentence2_key is None: 85 return tokenizer(examples[sentence1_key], truncation=True) 86 return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True) 87 88preprocess_function(dataset['train'][:5]) 89 90 91encoded_dataset = dataset.map(preprocess_function, batched=True) 92 93from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer 94 95num_labels = 3 if task.startswith("mnli") else 1 if task=="stsb" else 2 96model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels) 97 98metric_name = "pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy" 99model_name = model_checkpoint.split("/")[-1] 100 101args = TrainingArguments( 102 f"{model_name}-finetuned-{task}", 103 evaluation_strategy = "epoch", 104 save_strategy = "epoch", 105 learning_rate=2e-5, 106 per_device_train_batch_size=batch_size, 107 per_device_eval_batch_size=batch_size, 108 num_train_epochs=5, 109 weight_decay=0.01, 110 load_best_model_at_end=True, 111 metric_for_best_model=metric_name, 112 push_to_hub=True, 113) 114 115def compute_metrics(eval_pred): 116 predictions, labels = eval_pred 117 if task != "stsb": 118 predictions = np.argmax(predictions, axis=1) 119 else: 120 predictions = predictions[:, 0] 121 return metric.compute(predictions=predictions, references=labels) 122 123validation_key = "validation_mismatched" if task == "mnli-mm" else "validation_matched" if task == "mnli" else "validation" 124trainer = Trainer( 125 model, 126 args, 127 train_dataset=encoded_dataset["train"], 128 eval_dataset=encoded_dataset[validation_key], 129 tokenizer=tokenizer, 130 compute_metrics=compute_metrics 131) 132 133trainer.train() 134

試したこと

対策したことといたしましては、
①.shard を用いてデータ数を10分の1にする
②range を用いて、データの範囲指定

などを実行しました。

イメージ説明

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

回答1

0

コード全部理解して回答してる訳ではないのですが、みた感じEarlyStopping使ってなさそうなので使ってみてはいかがでしょうか?
あとは、Batch NormalizationDropoutなど

投稿2021/11/04 08:07

編集2021/11/04 10:21
kyokio

総合スコア560

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

jeesus

2021/11/04 10:36

今、画像を追加したのですが上記が算出された結果となります。 ValidationLossのみ、過学習の傾向が見られるのですが、その際にもEarlyStoppingを用いても、よろしいのでしょうか。
kyokio

2021/11/04 10:55 編集

ありがとうございます。 この結果を見る限り、あまり効果的だとは思えないですね。 (trainデータの学習がある程度進んでいるのにvalデータでのlossが増えているので) ちなみにデータの数とバッチサイズはどれくらいでしょうか?
jeesus

2021/11/05 01:16

バッチサイズは16で、 データ数は、960です。
guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.50%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問