35. 演習 Sequence-to-Sequence (Seq2Seq) モデル
Sequence-to-Sequence (Seq2Seq) モデルは、系列を入力として系列を出力するモデルです。
入力系列をRNNで固定長のベクトルに変換(= Encode)し、そのベクトルを用いて系列を出力(=
Decode)することから、Encoder-Decoder モデルとも呼ばれます。
RNNの代わりにLSTMやGRUでも可能です。
機械翻訳のほか、文書要約や対話生成にも使われます。
今回は機械翻訳を例にとって解説していきます。
編集するにはダブルクリックするか Enter キーを押してください
Collecting wheel==0.34.2
Downloading wheel-0.34.2-py2.py3-none-any.whl (26 kB)
Installing collected packages: wheel
Attempting uninstall: wheel
Found existing installation: wheel 0.37.0
Uninstalling wheel-0.37.0:
Successfully uninstalled wheel-0.37.0
ERROR: pip's dependency resolver does not currently take into account all the packages that a
tensorflow 2.7.0 requires tensorboard~=2.6, but you have tensorboard 2.2.2 which is incompati
tensorflow 2.7.0 requires tensorflow-estimator<2.8,~=2.7.0rc0, but you have tensorflow-estima
Successfully installed wheel-0.34.2
WARNING: The following packages were previously imported in this runtime:
[wheel]
You must restart the runtime in order to use newly installed versions.
RESTART RUNTIME
%pip install "wheel==0.34.2"
from os import path
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'
!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.0-{platform}-linux_x86_64.w
import torch
print(torch.__version__)
print(torch.cuda.is_available())
ERROR: HTTP error 403 while getting http://download.pytorch.org/whl/cu80/torch-0.4.0-cp37-c
ERROR: Could not install requirement torch==0.4.0 from http://download.pytorch.org/whl/cu80/t
1.10.0+cu111
True
37. Resolving uc20f19ec7a1a9d6d90e6737f64a.dl.dropboxusercontent.com (uc20f19ec7a1a9d6d90e6737f
Connecting to uc20f19ec7a1a9d6d90e6737f64a.dl.dropboxusercontent.com (uc20f19ec7a1a9d6d90e6
HTTP request sent, awaiting response... 200 OK
Length: 2784447 (2.7M) [text/plain]
Saving to: ‘data/train.ja’
train.ja 100%[===================>] 2.66M --.-KB/s in 0.07s
2021-12-22 23:57:41 (36.2 MB/s) - ‘data/train.ja’ saved [2784447/2784447]
! ls data
dev.en dev.ja test.en test.ja train.en train.ja
import random
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from nltk import bleu_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from utils import Vocab
# デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1)
random_state = 42
print(torch.__version__)
1.10.0+cu111
英語-日本語の対訳コーパスである、Tanaka Corpus (
http://www.edrdg.org/wiki/index.php/Tanaka_Corpus )を使います。
今回はそのうちの一部分を取り出したsmall_parallel_enja: 50k En/Ja Parallel Corpus for Testing
SMT Methods ( https://github.com/odashi/small_parallel_enja )を使用します。
train.enとtrain.jaの中身を見てみましょう。
1.データセットの準備
! head -10 data/train.en
i can 't tell who will arrive first .
many animals have been destroyed by men .
38. i 'm in the tennis club .
emi looks happy .
please bear this fact in mind .
she takes care of my children .
we want to be international .
you ought not to break your promise .
when you cross the street , watch out for cars .
i have nothing to live for .
! head -10 ./data/train.ja
誰 が 一番 に 着 く か 私 に は 分か り ま せ ん 。
多く の 動物 が 人間 に よ っ て 滅ぼ さ れ た 。
私 は テニス 部員 で す 。
エミ は 幸せ そう に 見え ま す 。
この 事実 を 心 に 留め て お い て 下さ い 。
彼女 は 私 たち の 世話 を し て くれ る 。
私 達 は 国際 人 に な り た い と 思 い ま す 。
約束 を 破 る べ き で は あ り ま せ ん 。
道路 を 横切 る とき は 車 に 注意 し なさ い 。
私 に は 生き 甲斐 が な い 。
それぞれの文章が英語-日本語で対応しているのがわかります。
1.1データの読み込みと単語の分割
def load_data(file_path):
# テキストファイルからデータを読み込むメソッド
data = []
for line in open(file_path, encoding='utf-8'):
words = line.strip().split() # スペースで単語を分割
data.append(words)
return data
train_X = load_data('./data/train.en')
train_Y = load_data('./data/train.ja')
# 訓練データと検証データに分割
train_X, valid_X, train_Y, valid_Y = train_test_split(train_X, train_Y, test_size=0.2, random_state
この時点で入力と教師データは以下のようになっています
print('train data', train_X[0])
print('valid data', valid_X[0])
train data ['where', 'shall', 'we', 'eat', 'tonight', '?']
valid data ['you', 'may', 'extend', 'your', 'stay', 'in', 'tokyo', '.']