机器学习课的第一次作业:用任意算法解决一个有监督的分类问题。数据为手机接收到的7个不同无线网的信号强度,预测的结果为房间号(1-4),数据共有2000组。部分数据如下表所示:
1 | 2 | 3 | 4 | 5 | 6 | 7 | 房间号 |
---|---|---|---|---|---|---|---|
-64 | -56 | -61 | -66 | -71 | -82 | -81 | 1 |
-68 | -57 | -61 | -65 | -71 | -85 | -85 | 1 |
-48 | -63 | -55 | -42 | -71 | -81 | -79 | 2 |
-44 | -63 | -67 | -46 | -63 | -81 | -78 | 2 |
-57 | -58 | -56 | -55 | -71 | -84 | -87 | 3 |
-71 | -59 | -52 | -62 | -48 | -91 | -83 | 4 |
-57 | -53 | -51 | -57 | -48 | -88 | -85 | 4 |
思路
使用三层MLP(全连接网络)
完整代码
(完整代码如下)
# encoding=utf-8
"""Indoor-localization classifier: a 3-layer MLP (7 -> 12 -> 4) in TF1.

Reads 2000 rows of 7 WiFi signal strengths + room label (1-4) from
wifi_localization.txt, trains with Adam + batch norm, and logs loss and
validation accuracy to TensorBoard under logs/.
"""
import pandas as pd
import tensorflow as tf

# --- data: 90% train / 10% held-out test, split by row index ---
df = pd.read_csv('wifi_localization.txt', sep='\t', header=None)  # read the data file
train = df.sample(frac=0.9)                    # random 90% as the training set
test = df[~df.index.isin(train.index)]         # remaining 10% as the test set
train_values = train.values                    # as a matrix
test_values = test.values                      # as a matrix
test_x = test_values[:, :-1]                   # first 7 columns: signal strengths
test_y = test_values[:, -1]                    # last column: room label

# --- hyperparameters ---
input_units = 7
hidden_units = 12
output_units = 4
batch_size = 100
learning_rate = 0.001
epoch = 50
train_size = len(train_values)                 # fix: was hard-coded 1800; derive from the actual split
iters_per_epoch = train_size // batch_size     # batches that don't fill up are dropped

'''
Placeholders for data, labels, and the BN train/eval switch
'''
with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, input_units])    # 7 input signal strengths
with tf.name_scope('label'):
    y_ = tf.placeholder(tf.int32, [None])                  # 4 possible room classes
    y_onehot = tf.one_hot(y_ - 1, output_units)            # labels are 1-4, so subtract 1 first
is_training = tf.placeholder(tf.bool)                      # True during training so BN uses batch stats

'''
Hidden layer
'''
with tf.name_scope('layer_1'):
    w_1 = tf.Variable(tf.truncated_normal([input_units, hidden_units], stddev=0.1))  # hidden weights
    b_1 = tf.Variable(tf.zeros([hidden_units]))                                      # hidden bias
    h = tf.matmul(x, w_1) + b_1                            # fully connected; ReLU applied after BN below
with tf.name_scope('BN_1'):
    layer = tf.layers.batch_normalization(h, training=is_training)  # batch norm as regularization
    layer = tf.nn.relu(layer)

'''
Output layer
'''
with tf.name_scope('layer_2'):
    w_2 = tf.Variable(tf.truncated_normal([hidden_units, output_units], stddev=0.1))  # output weights
    b_2 = tf.Variable(tf.zeros([output_units]))                                       # output bias
    layer = tf.matmul(layer, w_2) + b_2
with tf.name_scope('BN_2'):
    logits = tf.layers.batch_normalization(layer, training=is_training)  # batch norm on logits

'''
Loss and optimizer
'''
with tf.name_scope('softmax_cross_entropy'):
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_onehot))  # softmax cross-entropy
    tf.summary.scalar('loss', loss)
# BN's moving-average update ops must run before each optimizer step
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)
with tf.name_scope('accuracy'):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(y_onehot, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('val_acc', accuracy)

with tf.Session() as sess:
    merged = tf.summary.merge_all()
    writer = tf.summary.FileWriter("logs/", sess.graph)    # TensorBoard visualization
    sess.run(tf.global_variables_initializer())
    for ep in range(epoch):                                # was `_`, which was then referenced — use a real name
        for i in range(iters_per_epoch):                   # leftover rows beyond a full batch are skipped
            batch = train_values[i * batch_size:(i + 1) * batch_size]
            batch_x = batch[:, :-1]                        # first 7 columns: signal strengths
            batch_y = batch[:, -1]                         # last column: room label
            _, summary = sess.run([optimizer, merged],
                                  feed_dict={x: batch_x, y_: batch_y, is_training: True})
            # fix: global step was `ep * 28 + i` with 28 hard-coded; use the real iteration count
            step = ep * iters_per_epoch + i
            writer.add_summary(summary, step)
            if i % 7 == 0:
                trainloss = sess.run(loss,
                                     feed_dict={x: batch_x, y_: batch_y, is_training: False})
                print('Epoch:%d/%d, Iter:%d/%d, loss:%f'
                      % (ep, epoch, i, iters_per_epoch, trainloss))
            # fix: accuracy was evaluated every iteration but printed rarely — only compute when printed
            if i % 20 == 0 and ep % 5 == 0:
                acc = sess.run(accuracy,
                               feed_dict={x: test_x, y_: test_y, is_training: False})
                # fix: message said 'Batch:' but prints the epoch (matches the recorded log output)
                print('Epoch:%d/%d, Iter:%d/%d, Validation acc:%f'
                      % (ep, epoch, i, iters_per_epoch, acc))
运行结果
(训练日志如下)
Epoch:0/40, Iter:0/28, loss:2.349822 Epoch:0/40, Iter:0/28, Validation acc:0.280000 Epoch:0/40, Iter:7/28, loss:2.129261 Epoch:0/40, Iter:14/28, loss:2.130534 Epoch:0/40, Iter:20/28, Validation acc:0.260000 Epoch:0/40, Iter:21/28, loss:2.416539 Epoch:1/40, Iter:0/28, loss:2.350163 Epoch:1/40, Iter:7/28, loss:2.406591 Epoch:1/40, Iter:14/28, loss:2.599873 Epoch:1/40, Iter:21/28, loss:2.445563 Epoch:2/40, Iter:0/28, loss:2.170509 Epoch:2/40, Iter:7/28, loss:2.029882 Epoch:2/40, Iter:14/28, loss:2.228309 Epoch:2/40, Iter:21/28, loss:2.061807 Epoch:3/40, Iter:0/28, loss:1.749540 Epoch:3/40, Iter:7/28, loss:1.604661 Epoch:3/40, Iter:14/28, loss:1.763323 Epoch:3/40, Iter:21/28, loss:1.611572 Epoch:4/40, Iter:0/28, loss:1.411176 Epoch:4/40, Iter:7/28, loss:1.302577 Epoch:4/40, Iter:14/28, loss:1.370307 Epoch:4/40, Iter:21/28, loss:1.286783 Epoch:5/40, Iter:0/28, loss:1.115393 Epoch:5/40, Iter:0/28, Validation acc:0.430000 Epoch:5/40, Iter:7/28, loss:1.061099 Epoch:5/40, Iter:14/28, loss:1.086004 Epoch:5/40, Iter:20/28, Validation acc:0.660000 Epoch:5/40, Iter:21/28, loss:1.068062 Epoch:6/40, Iter:0/28, loss:0.949269 Epoch:6/40, Iter:7/28, loss:0.946468 Epoch:6/40, Iter:14/28, loss:0.951335 Epoch:6/40, Iter:21/28, loss:0.980757 Epoch:7/40, Iter:0/28, loss:0.887072 Epoch:7/40, Iter:7/28, loss:0.899593 Epoch:7/40, Iter:14/28, loss:0.884317 Epoch:7/40, Iter:21/28, loss:0.931468 Epoch:8/40, Iter:0/28, loss:0.843412 Epoch:8/40, Iter:7/28, loss:0.856387 Epoch:8/40, Iter:14/28, loss:0.839142 Epoch:8/40, Iter:21/28, loss:0.891715 Epoch:9/40, Iter:0/28, loss:0.804764 Epoch:9/40, Iter:7/28, loss:0.822369 Epoch:9/40, Iter:14/28, loss:0.797064 Epoch:9/40, Iter:21/28, loss:0.841469 Epoch:10/40, Iter:0/28, loss:0.776134 Epoch:10/40, Iter:0/28, Validation acc:0.915000 Epoch:10/40, Iter:7/28, loss:0.780300 Epoch:10/40, Iter:14/28, loss:0.768762 Epoch:10/40, Iter:20/28, Validation acc:0.930000 Epoch:10/40, Iter:21/28, loss:0.788761 Epoch:11/40, Iter:0/28, loss:0.738675 Epoch:11/40, 
Iter:7/28, loss:0.737966 Epoch:11/40, Iter:14/28, loss:0.739051 Epoch:11/40, Iter:21/28, loss:0.741672 Epoch:12/40, Iter:0/28, loss:0.703390 Epoch:12/40, Iter:7/28, loss:0.700485 Epoch:12/40, Iter:14/28, loss:0.704121 Epoch:12/40, Iter:21/28, loss:0.699641 Epoch:13/40, Iter:0/28, loss:0.672843 Epoch:13/40, Iter:7/28, loss:0.666568 Epoch:13/40, Iter:14/28, loss:0.665154 Epoch:13/40, Iter:21/28, loss:0.652875 Epoch:14/40, Iter:0/28, loss:0.638008 Epoch:14/40, Iter:7/28, loss:0.627676 Epoch:14/40, Iter:14/28, loss:0.622130 Epoch:14/40, Iter:21/28, loss:0.606235 Epoch:15/40, Iter:0/28, loss:0.602096 Epoch:15/40, Iter:0/28, Validation acc:0.935000 Epoch:15/40, Iter:7/28, loss:0.590936 Epoch:15/40, Iter:14/28, loss:0.580687 Epoch:15/40, Iter:20/28, Validation acc:0.940000 Epoch:15/40, Iter:21/28, loss:0.563828 Epoch:16/40, Iter:0/28, loss:0.565652 Epoch:16/40, Iter:7/28, loss:0.552791 Epoch:16/40, Iter:14/28, loss:0.539287 Epoch:16/40, Iter:21/28, loss:0.525752 Epoch:17/40, Iter:0/28, loss:0.526950 Epoch:17/40, Iter:7/28, loss:0.516006 Epoch:17/40, Iter:14/28, loss:0.500154 Epoch:17/40, Iter:21/28, loss:0.492598 Epoch:18/40, Iter:0/28, loss:0.492676 Epoch:18/40, Iter:7/28, loss:0.482581 Epoch:18/40, Iter:14/28, loss:0.466553 Epoch:18/40, Iter:21/28, loss:0.459159 Epoch:19/40, Iter:0/28, loss:0.461332 Epoch:19/40, Iter:7/28, loss:0.450579 Epoch:19/40, Iter:14/28, loss:0.433754 Epoch:19/40, Iter:21/28, loss:0.424034 Epoch:20/40, Iter:0/28, loss:0.426153 Epoch:20/40, Iter:0/28, Validation acc:0.935000 Epoch:20/40, Iter:7/28, loss:0.414547 Epoch:20/40, Iter:14/28, loss:0.396597 Epoch:20/40, Iter:20/28, Validation acc:0.960000 Epoch:20/40, Iter:21/28, loss:0.392582 Epoch:21/40, Iter:0/28, loss:0.396746 Epoch:21/40, Iter:7/28, loss:0.385771 Epoch:21/40, Iter:14/28, loss:0.366790 Epoch:21/40, Iter:21/28, loss:0.366531 Epoch:22/40, Iter:0/28, loss:0.369163 Epoch:22/40, Iter:7/28, loss:0.352794 Epoch:22/40, Iter:14/28, loss:0.337169 Epoch:22/40, Iter:21/28, loss:0.340527 
Epoch:23/40, Iter:0/28, loss:0.341674 Epoch:23/40, Iter:7/28, loss:0.325479 Epoch:23/40, Iter:14/28, loss:0.315487 Epoch:23/40, Iter:21/28, loss:0.322067 Epoch:24/40, Iter:0/28, loss:0.323417 Epoch:24/40, Iter:7/28, loss:0.307828 Epoch:24/40, Iter:14/28, loss:0.301278 Epoch:24/40, Iter:21/28, loss:0.307972 Epoch:25/40, Iter:0/28, loss:0.310938 Epoch:25/40, Iter:0/28, Validation acc:0.975000 Epoch:25/40, Iter:7/28, loss:0.293329 Epoch:25/40, Iter:14/28, loss:0.287589 Epoch:25/40, Iter:20/28, Validation acc:0.980000 Epoch:25/40, Iter:21/28, loss:0.292356 Epoch:26/40, Iter:0/28, loss:0.295491 Epoch:26/40, Iter:7/28, loss:0.276965 Epoch:26/40, Iter:14/28, loss:0.273189 Epoch:26/40, Iter:21/28, loss:0.277424 Epoch:27/40, Iter:0/28, loss:0.281996 Epoch:27/40, Iter:7/28, loss:0.263847 Epoch:27/40, Iter:14/28, loss:0.260707 Epoch:27/40, Iter:21/28, loss:0.264864 Epoch:28/40, Iter:0/28, loss:0.268509 Epoch:28/40, Iter:7/28, loss:0.249932 Epoch:28/40, Iter:14/28, loss:0.245355 Epoch:28/40, Iter:21/28, loss:0.251433 Epoch:29/40, Iter:0/28, loss:0.259525 Epoch:29/40, Iter:7/28, loss:0.241869 Epoch:29/40, Iter:14/28, loss:0.238430 Epoch:29/40, Iter:21/28, loss:0.244196 Epoch:30/40, Iter:0/28, loss:0.252729 Epoch:30/40, Iter:0/28, Validation acc:0.980000 Epoch:30/40, Iter:7/28, loss:0.234259 Epoch:30/40, Iter:14/28, loss:0.231694 Epoch:30/40, Iter:20/28, Validation acc:0.980000 Epoch:30/40, Iter:21/28, loss:0.234765 Epoch:31/40, Iter:0/28, loss:0.245686 Epoch:31/40, Iter:7/28, loss:0.228102 Epoch:31/40, Iter:14/28, loss:0.229153 Epoch:31/40, Iter:21/28, loss:0.229456 Epoch:32/40, Iter:0/28, loss:0.240966 Epoch:32/40, Iter:7/28, loss:0.221393 Epoch:32/40, Iter:14/28, loss:0.223557 Epoch:32/40, Iter:21/28, loss:0.220024 Epoch:33/40, Iter:0/28, loss:0.234356 Epoch:33/40, Iter:7/28, loss:0.216340 Epoch:33/40, Iter:14/28, loss:0.222685 Epoch:33/40, Iter:21/28, loss:0.216450 Epoch:34/40, Iter:0/28, loss:0.235182 Epoch:34/40, Iter:7/28, loss:0.212004 Epoch:34/40, Iter:14/28, 
loss:0.218237 Epoch:34/40, Iter:21/28, loss:0.208279 Epoch:35/40, Iter:0/28, loss:0.226007 Epoch:35/40, Iter:0/28, Validation acc:0.980000 Epoch:35/40, Iter:7/28, loss:0.203423 Epoch:35/40, Iter:14/28, loss:0.211753 Epoch:35/40, Iter:20/28, Validation acc:0.985000 Epoch:35/40, Iter:21/28, loss:0.199003 Epoch:36/40, Iter:0/28, loss:0.217677 Epoch:36/40, Iter:7/28, loss:0.194841 Epoch:36/40, Iter:14/28, loss:0.206121 Epoch:36/40, Iter:21/28, loss:0.189889 Epoch:37/40, Iter:0/28, loss:0.212066 Epoch:37/40, Iter:7/28, loss:0.189514 Epoch:37/40, Iter:14/28, loss:0.211716 Epoch:37/40, Iter:21/28, loss:0.186845 Epoch:38/40, Iter:0/28, loss:0.211035 Epoch:38/40, Iter:7/28, loss:0.188129 Epoch:38/40, Iter:14/28, loss:0.213900 Epoch:38/40, Iter:21/28, loss:0.183004 Epoch:39/40, Iter:0/28, loss:0.207326 Epoch:39/40, Iter:7/28, loss:0.184741 Epoch:39/40, Iter:14/28, loss:0.210827 Epoch:39/40, Iter:21/28, loss:0.177730 |