前提・実現したいこと
R言語を使ってmnistの手書き数字3と4について、3と4の数字を主成分分析で教師なし分類したいのですができません。片方の数字3だけを主成分分析で圧縮するコードを下に貼ります。
該当のソースコード
R
1library(ggplot2) 2library(dplyr) 3#install.packages("R.utils") 4library(R.utils) # unzip()を使う 5library(gclus) 6library(MASS) 7#install.packages("recommenderlab") 8library("recommenderlab") 9 10#download data from http://yann.lecun.com/exdb/mnist/ 11#download.file("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", 12# "train-images-idx3-ubyte.gz") 13#download.file("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", 14# "train-labels-idx1-ubyte.gz") 15#download.file("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", 16# "t10k-images-idx3-ubyte.gz") 17#download.file("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", 18# "t10k-labels-idx1-ubyte.gz") 19 20# gunzip the file 21#R.utils::gunzip("train-images-idx3-ubyte.gz") 22#R.utils::gunzip("train-labels-idx1-ubyte.gz") 23#R.utils::gunzip("t10k-images-idx3-ubyte.gz") 24#R.utils::gunzip("t10k-labels-idx1-ubyte.gz") 25 26# load image files 27load_image_file = function(filename) { 28 ret = list() 29 f = file(filename, 'rb') 30 readBin(f, 'integer', n = 1, size = 4, endian = 'big') 31 n = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 32 nrow = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 33 ncol = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 34 x = readBin(f, 'integer', n = n * nrow * ncol, size = 1, signed = FALSE) 35 close(f) 36 data.frame(matrix(x, ncol = nrow * ncol, byrow = TRUE)) 37} 38 39# load label files 40load_label_file = function(filename) { 41 f = file(filename, 'rb') 42 readBin(f, 'integer', n = 1, size = 4, endian = 'big') 43 n = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 44 y = readBin(f, 'integer', n = n, size = 1, signed = FALSE) 45 close(f) 46 y 47} 48 49# load images 50train = load_image_file("train-images-idx3-ubyte") 51test = load_image_file("t10k-images-idx3-ubyte") 52 53# load labels 54train$y = as.factor(load_label_file("train-labels-idx1-ubyte")) 55test$y = as.factor(load_label_file("t10k-labels-idx1-ubyte")) 56 57# helper function for visualization 58show_digit = function(arr784, col = gray(12:1 / 12), ...) { 59 image(matrix(as.matrix(arr784[-785]), nrow = 28)[, 28:1], col = col, ...) 60} 61 62# load image files 63load_image_file = function(filename) { 64 ret = list() 65 f = file(filename, 'rb') 66 readBin(f, 'integer', n = 1, size = 4, endian = 'big') 67 n = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 68 nrow = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 69 ncol = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 70 x = readBin(f, 'integer', n = n * nrow * ncol, size = 1, signed = FALSE) 71 close(f) 72 data.frame(matrix(x, ncol = nrow * ncol, byrow = TRUE)) 73} 74 75# load label files 76load_label_file = function(filename) { 77 f = file(filename, 'rb') 78 readBin(f, 'integer', n = 1, size = 4, endian = 'big') 79 n = readBin(f, 'integer', n = 1, size = 4, endian = 'big') 80 y = readBin(f, 'integer', n = n, size = 1, signed = FALSE) 81 close(f) 82 y 83} 84 85# load images 86train = load_image_file("../input/mnistdt/train-images-idx3-ubyte") 87test = load_image_file("../input/mnistdt/t10k-images-idx3-ubyte") 88 89# load labels 90train$y = as.factor(load_label_file("../input/mnistdt/train-labels-idx1-ubyte")) 91test$y = as.factor(load_label_file("../input/mnistdt/t10k-labels-idx1-ubyte")) 92 93#数字3の最初の100個のデータ 94X <- train[train$y==3,][1:100,-785] 95#平均ベクトル 96mu.X = colMeans(X) 97show_digit( 255-mu.X) # 平均的な手書きの3の数字の図 98 99# 誤差Z 100Z <- t(apply(X, 1, function(x, m){x- m}, m= mu.X)) 101show_digit( Z[1,]) 102 103show_digit( Z[10,]) 104 105show_digit( Z[100,]) 106 107cov.Z <- cov(Z) 108dim(cov.Z) 109 110pca.Z <- eigen(cov.Z) 111 112show_digit( (255-X[1,])) # Fullの情報 113 114#k=50 115U.50 <- pca.Z$vectors[,1:50] 116Z1.50 <-t(pca.Z$vectors[,1:50])%*%Z[1,] %>% as.numeric 117UX.50 <- U.50%*%as.matrix(Z1.50,ncol=1) 118show_digit(255-(UX.50+as.matrix(mu.X))) 119 120#k=100 121U.100 <- pca.Z$vectors[,1:100] 122Z1.100 <-t(pca.Z$vectors[,1:100])%*%Z[1,] %>% as.numeric 123UX.100 <- U.100%*%as.matrix(Z1.100,ncol=1) 124show_digit(255- (UX.100+as.matrix(mu.X))) 125 126#k=150 127U.150 <- pca.Z$vectors[,1:150] 128Z1.150 <-t(pca.Z$vectors[,1:150])%*%Z[1,] %>% as.numeric 129UX.150 <- U.150%*%as.matrix(Z1.150,ncol=1) 130show_digit(255- (UX.150+as.matrix(mu.X))) 131#k=200 132U.200 <- pca.Z$vectors[,1:200] 133Z1.200 <-t(pca.Z$vectors[,1:200])%*%Z[1,] %>% as.numeric 134UX.200 <- U.200%*%as.matrix(Z1.200,ncol=1) 135show_digit(255-(UX.200+as.matrix(mu.X))) 136 137plot(1:784,pca.Z$values/sum(pca.Z$values),type="o",col=2, lwd=2, xlab="dimension",ylab="variance explained",cex=0.4) 138 139plot(1:784,cumsum(pca.Z$values)/sum(pca.Z$values),type="o",col=2, lwd=2, xlab="dimension",ylab="variance explained",cex=0.4) 140 141U <- pca.Z$vectors[,1:8] 142 143show_digit(255-U[,1]) 144show_digit(255-U[,2]) 145show_digit(255-U[,3]) 146show_digit(255-U[,4]) 147show_digit(255-U[,5]) 148show_digit(255-U[,6]) 149show_digit(255-U[,7]) 150show_digit(255-U[,8])
試したこと
上記の1つの数字を分類するコードで
X <- train[train$y==3,][1:100,-785] + train[train$y==4,][1:100,-785]
や
X <- train[train$y==3||4,][1:100,-785]
と試したりしましたが、3と4が混ざったような画像が出力されるだけで教師なし分類ができません。
補足情報(FW/ツールのバージョンなど)
kaggleのnotepad上のRエディタを使用しています。
どううまくいかないのかを書いてください。
加筆しました。3のみを主成分分析するコードを少し変えても教師なし分類した結果はでず、困っているということです。
sessionInfo()を実行した結果と、str(train[1,])の結果を貼れますか?
おそらく貼れますが、現在出先ですので1,2時間ほどお待ちください。すいません。
sessionInfo()とstr(train[1,])の結果です。
R version 3.6.0 (2019-04-26)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Debian GNU/Linux 9 (stretch)
Matrix products: default
BLAS/LAPACK: /usr/lib/libopenblasp-r0.2.19.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=C
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] recommenderlab_0.2-5 proxy_0.4-23 arules_1.6-4
[4] Matrix_1.2-17 NMF_0.21.0 synchronicity_1.3.5
[7] bigmemory_4.5.33 rngtools_1.4 pkgmaker_0.27
[10] registry_0.5-1 MASS_7.3-51.4 gclus_1.3.2
[13] cluster_2.1.0 R.utils_2.9.0 R.oo_1.22.0
[16] R.methodsS3_1.7.1 forcats_0.4.0 stringr_1.4.0
[19] dplyr_0.8.3 purrr_0.3.2 readr_1.3.1
[22] tidyr_1.0.0 tibble_2.1.3 ggplot2_3.2.1.9000
[25] tidyverse_1.2.1 bigrquery_1.2.0 httr_1.4.1
loaded via a namespace (and not attached):
[1] nlme_3.1-141 fs_1.3.1 lubridate_1.7.4
[4] bit64_0.9-7 doParallel_1.0.15 RColorBrewer_1.1-2
[7] repr_1.0.1.9000 tools_3.6.0 backports_1.1.5
[10] R6_2.4.0 irlba_2.3.3 DBI_1.0.0
[13] colorspace_1.4-1 withr_2.1.2 tidyselect_0.2.5
[16] bit_1.1-14 compiler_3.6.0 cli_1.1.0
[19] rvest_0.3.4 Cairo_1.5-10 xml2_1.2.2
[22] scales_1.0.0 pbdZMQ_0.3-3 digest_0.6.21
[25] base64enc_0.1-3 pkgconfig_2.0.3 htmltools_0.4.0
[28] bibtex_0.4.2 rlang_0.4.0 readxl_1.3.1
[31] rstudioapi_0.10 recosystem_0.4.2 generics_0.0.2
[34] jsonlite_1.6 magrittr_1.5 Rcpp_1.0.2
[37] IRkernel_1.0.2.9000 munsell_0.5.0 lifecycle_0.1.0
[40] stringi_1.4.3 plyr_1.8.4 grid_3.6.0
[43] parallel_3.6.0 bigmemory.sri_0.1.3 crayon_1.3.4
[46] lattice_0.20-38 IRdisplay_0.7.0.9000 haven_2.1.1
[49] hms_0.5.1 zeallot_0.1.0 pillar_1.4.2
[52] uuid_0.1-2 reshape2_1.4.3 codetools_0.2-16
[55] glue_1.3.1 evaluate_0.14 getPass_0.2-2
[58] modelr_0.1.4 vctrs_0.2.0 foreach_1.4.7
[61] cellranger_1.1.0 gtable_0.3.0 assertthat_0.2.1
[64] gridBase_0.4-7 xtable_1.8-4 broom_0.5.2
[67] gargle_0.4.0 iterators_1.0.12
'data.frame': 1 obs. of 785 variables:
$ X1 : int 0
$ X2 : int 0
$ X3 : int 0
$ X4 : int 0
$ X5 : int 0
$ X6 : int 0
$ X7 : int 0
$ X8 : int 0
$ X9 : int 0
$ X10 : int 0
$ X11 : int 0
$ X12 : int 0
$ X13 : int 0
$ X14 : int 0
$ X15 : int 0
$ X16 : int 0
$ X17 : int 0
$ X18 : int 0
$ X19 : int 0
$ X20 : int 0
$ X21 : int 0
$ X22 : int 0
$ X23 : int 0
$ X24 : int 0
$ X25 : int 0
$ X26 : int 0
$ X27 : int 0
$ X28 : int 0
$ X29 : int 0
$ X30 : int 0
$ X31 : int 0
$ X32 : int 0
$ X33 : int 0
$ X34 : int 0
$ X35 : int 0
$ X36 : int 0
$ X37 : int 0
$ X38 : int 0
$ X39 : int 0
$ X40 : int 0
$ X41 : int 0
$ X42 : int 0
$ X43 : int 0
$ X44 : int 0
$ X45 : int 0
$ X46 : int 0
$ X47 : int 0
$ X48 : int 0
$ X49 : int 0
$ X50 : int 0
$ X51 : int 0
$ X52 : int 0
$ X53 : int 0
$ X54 : int 0
$ X55 : int 0
$ X56 : int 0
$ X57 : int 0
$ X58 : int 0
$ X59 : int 0
$ X60 : int 0
$ X61 : int 0
$ X62 : int 0
$ X63 : int 0
$ X64 : int 0
$ X65 : int 0
$ X66 : int 0
$ X67 : int 0
$ X68 : int 0
$ X69 : int 0
$ X70 : int 0
$ X71 : int 0
$ X72 : int 0
$ X73 : int 0
$ X74 : int 0
$ X75 : int 0
$ X76 : int 0
$ X77 : int 0
$ X78 : int 0
$ X79 : int 0
$ X80 : int 0
$ X81 : int 0
$ X82 : int 0
$ X83 : int 0
$ X84 : int 0
$ X85 : int 0
$ X86 : int 0
$ X87 : int 0
$ X88 : int 0
$ X89 : int 0
$ X90 : int 0
$ X91 : int 0
$ X92 : int 0
$ X93 : int 0
$ X94 : int 0
$ X95 : int 0
$ X96 : int 0
$ X97 : int 0
$ X98 : int 0
$ X99 : int 0
[list output truncated]
回答1件
あなたの回答
tips
プレビュー