TensorFlow 汽车分类

网友投稿 551 2022-10-11

tensorflow 汽车分类

tensorflow 汽车分类

1. data_processing.py

"""Download the UCI Car Evaluation data set and one-hot encode it."""
import pandas as pd
from urllib.request import urlretrieve

# Raw, header-less CSV of the UCI Car Evaluation data set (1728 rows). The
# scraped copy of this script had lost the URL, leaving the syntax error
# `urlretrieve(""car.csv")`.
DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/car/car.data"
DATA_PATH = "car.csv"
COL_NAMES = ["buying", "maint", "doors", "persons", "lug_boot", "safety", "class"]


def load_data(download=True, *, url=DATA_URL, path=DATA_PATH):
    """Return the car data set as a DataFrame with named columns.

    Args:
        download: if True, fetch ``url`` and save it to ``path`` first;
            otherwise ``path`` must already exist locally.
        url: location of the raw, header-less CSV data.
        path: local file that is written (when downloading) and then read.

    Returns:
        A DataFrame whose columns are ``COL_NAMES``.
    """
    if download:
        urlretrieve(url, path)
        print("Downloaded to", path)
    # The raw file has no header row, so the column names are supplied here.
    return pd.read_csv(path, names=list(COL_NAMES))


def convert2onehot(data):
    """Convert the categorical columns of ``data`` to a one-hot representation.

    Each encoded column ``col`` becomes indicator columns named
    ``col_<value>``, in the original column order. ``dtype="uint8"`` pins
    0/1 integers: pandas >= 2.0 would otherwise produce booleans, which
    ``to_csv`` writes as True/False.
    """
    return pd.get_dummies(data, prefix=data.columns, dtype="uint8")


if __name__ == "__main__":
    data = load_data(download=True)
    new_data = convert2onehot(data)

    print(data.head())
    print("\nNum of data: ", len(data), "\n")  # 1728
    # view the distinct values of every column
    for name in data.keys():
        print(name, pd.unique(data[name]))
    print("\n", new_data.head(2))
    new_data.to_csv("car_onehot.csv", index=False)

2. model.py

"""Train a small fully connected classifier on the one-hot car data set.

NOTE(review): written against the TensorFlow 1.x graph API (tf.placeholder,
tf.layers, tf.Session); under TensorFlow 2 these live in tf.compat.v1 and
need eager execution disabled -- confirm the TF version this runs with.
"""
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import data_processing

LABEL_PREFIX = "class_"
# Display names for the raw class values; unknown values are shown as-is.
CLASS_DISPLAY_NAMES = {
    "acc": "accepted",
    "good": "good",
    "unacc": "unaccepted",
    "vgood": "very good",
}

data = data_processing.load_data(download=True)
new_data = data_processing.convert2onehot(data)

# Put every one-hot label column after the feature columns so the network
# input can be split with a single column index (21 features, 4 labels for
# this data set) instead of hard-coded magic numbers.
label_cols = [c for c in new_data.columns if c.startswith(LABEL_PREFIX)]
feature_cols = [c for c in new_data.columns if not c.startswith(LABEL_PREFIX)]
new_data = new_data[feature_cols + label_cols]
n_feature, n_label = len(feature_cols), len(label_cols)
tick_labels = [
    CLASS_DISPLAY_NAMES.get(c[len(LABEL_PREFIX):], c[len(LABEL_PREFIX):])
    for c in label_cols
]

# prepare training data
new_data = new_data.values.astype(np.float32)  # numpy array of float32
np.random.shuffle(new_data)
sep = int(0.7 * len(new_data))
train_data = new_data[:sep]  # training data (70%)
test_data = new_data[sep:]  # test data (30%)

# build network
tf_input = tf.placeholder(tf.float32, [None, n_feature + n_label], "input")
tfx = tf_input[:, :n_feature]
tfy = tf_input[:, n_feature:]

l1 = tf.layers.dense(tfx, 128, tf.nn.relu, name="l1")
l2 = tf.layers.dense(l1, 128, tf.nn.relu, name="l2")
out = tf.layers.dense(l2, n_label, name="l3")
prediction = tf.nn.softmax(out, name="pred")

loss = tf.losses.softmax_cross_entropy(onehot_labels=tfy, logits=out)
# Accuracy on exactly the rows fed in. tf.metrics.accuracy is a *streaming*
# metric: its update op accumulates totals across every sess.run, so it
# reported a running average over all past test evaluations instead.
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(tf.argmax(tfy, axis=1), tf.argmax(out, axis=1)), tf.float32)
)
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
train_op = opt.minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# The true class counts of the test set never change: count them once.
target_cls = np.argmax(test_data[:, n_feature:], axis=1)
target_counts = [int((target_cls == c).sum()) for c in range(n_label)]

# training
plt.ion()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
accuracies, steps = [], []
for t in range(4000):
    # training step on a random mini-batch of 32 rows
    batch_index = np.random.randint(len(train_data), size=32)
    sess.run(train_op, {tf_input: train_data[batch_index]})

    if t % 50 == 0:
        # testing
        acc_, pred_, loss_ = sess.run([accuracy, prediction, loss], {tf_input: test_data})
        accuracies.append(acc_)
        steps.append(t)
        print("Step: %i" % t, "| Accurate: %.2f" % acc_, "| Loss: %.2f" % loss_)

        # visualize testing: predicted vs. true count per class
        pred_cls = np.argmax(pred_, axis=1)
        ax1.cla()
        for c in range(n_label):
            bp = ax1.bar(c + 0.1, height=int((pred_cls == c).sum()), width=0.2, color="red")
            bt = ax1.bar(c - 0.1, height=target_counts[c], width=0.2, color="blue")
        # set_xticklabels works on every matplotlib version; passing labels as
        # set_xticks' 2nd positional argument only works from 3.5 (before
        # that it is the `minor` flag and the labels are silently dropped).
        ax1.set_xticks(range(n_label))
        ax1.set_xticklabels(tick_labels)
        ax1.legend(handles=[bp, bt], labels=["prediction", "target"])
        ax1.set_ylim((0, 400))
        ax2.cla()
        ax2.plot(steps, accuracies, label="accuracy")
        ax2.set_ylim(top=1)
        ax2.set_ylabel("accuracy")
        plt.pause(0.01)

plt.ioff()
sess.close()
plt.show()

完整源码：https://github.com/MorvanZhou/train-classifier-from-scratch

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:tensorflow 保存读取
下一篇:QQQ TLT 股债平衡策略
相关文章

 发表评论

暂时没有评论,来抢沙发吧~