博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
LeNet 分类 FashionMNIST
阅读量:7239 次
发布时间:2019-06-29

本文共 3151 字,大约阅读时间需要 10 分钟。

import mxnet as mxfrom mxnet import autograd, gluon, init, ndfrom mxnet.gluon import loss as gloss, nnfrom mxnet.gluon import data as gdataimport timeimport sysnet = nn.Sequential()net.add(nn.Conv2D(channels=6, kernel_size=5, activation='sigmoid'),        nn.MaxPool2D(pool_size=2, strides=2),        nn.Conv2D(channels=16, kernel_size=5, activation='sigmoid'),        nn.MaxPool2D(pool_size=2, strides=2),        # Dense 会默认将(批量大小,通道,高,宽)形状的输入转换成        # (批量大小,通道 * 高 * 宽)形状的输入。        nn.Dense(120, activation='sigmoid'),        nn.Dense(84, activation='sigmoid'),        nn.Dense(10))X = nd.random.uniform(shape=(1, 1, 28, 28))net.initialize()for layer in net:    X = layer(X)    print(layer.name, 'output shape:\t', X.shape)# batch_size = 256# train_iter, test_iter = gb.load_data_fashion_mnist(batch_size=batch_size)mnist_train = gdata.vision.FashionMNIST(train=True)mnist_test = gdata.vision.FashionMNIST(train=False)batch_size = 256transformer = gdata.vision.transforms.ToTensor()if sys.platform.startswith('win'):    num_workers = 0else:    num_workers = 4# 小批量数据迭代器(在cpu上)train_iter = gdata.DataLoader(mnist_train.transform_first(transformer), batch_size=batch_size, shuffle=True,                              num_workers=num_workers)test_iter = gdata.DataLoader(mnist_test.transform_first(transformer), batch_size=batch_size, shuffle=False,                             num_workers=num_workers)def try_gpu4():    try:        ctx = mx.gpu()        _ = nd.zeros((1,), ctx=ctx)    except mx.base.MXNetError:        ctx = mx.cpu()    return ctxctx = try_gpu4()def accuracy(y_hat,y):    return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()def evaluate_accuracy(data_iter, net, ctx):    acc = nd.array([0], ctx=ctx)    for X, y in data_iter:        # 如果 ctx 是 GPU,将数据复制到 GPU 上。        X, y = X.as_in_context(ctx), y.as_in_context(ctx)        acc += accuracy(net(X), y)    return acc.asscalar() / len(data_iter)def train(net, train_iter, test_iter, batch_size, trainer, ctx,              num_epochs):    print('training on', ctx)    loss = gloss.SoftmaxCrossEntropyLoss()    for epoch in range(num_epochs):        train_l_sum, train_acc_sum, start = 0, 0, time.time()        for X, y in train_iter:            X, y = X.as_in_context(ctx), y.as_in_context(ctx)            with autograd.record():                y_hat = net(X)                l = loss(y_hat, y)            l.backward()            trainer.step(batch_size)            train_l_sum += l.mean().asscalar()            train_acc_sum += accuracy(y_hat, y)        test_acc = evaluate_accuracy(test_iter, net, ctx)        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, '              'time %.1f sec' % (epoch + 1, train_l_sum / len(train_iter),                                 train_acc_sum / len(train_iter),                                 test_acc, time.time() - start))lr, num_epochs = 0.9, 200net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())trainer = gluon.Trainer(net.collect_params(), 'sgd', {
'learning_rate': lr})train(net, train_iter, test_iter, batch_size, trainer, ctx, num_epochs)

 

转载于:https://www.cnblogs.com/TreeDream/p/10044055.html

你可能感兴趣的文章
轻量高效的开源JavaScript插件和库
查看>>
CSS3-background-clip+background-origin
查看>>
linux yum 安装mysql
查看>>
种下一棵树:有旋Treap
查看>>
设计模式——(Abstract Factory)抽象工厂“改正为简单工厂”
查看>>
图灵热点之阅读篇——七月图书推荐
查看>>
【转载】acedSSGet()函数用法详解
查看>>
bzoj5407: girls
查看>>
BootStrap selectpicker后台动态绑定数据
查看>>
【转】正则基础之——贪婪与非贪婪模式
查看>>
关于 android.net.conn.CONNECTIVITY_CHANGE 7.0之后取消
查看>>
自动化测试的理解
查看>>
微信小程序事件
查看>>
Unity加载第三方C# DLL时,解析不能删除的问题。
查看>>
空间直角坐标系、大地坐标系、平面坐标系、高斯平面直角坐标系(转)
查看>>
Java并发编程-可重入锁
查看>>
MySQL5.7.9压缩包安装配置
查看>>
068、Calico的网络结构是什么?(2019-04-11 周四)
查看>>
rails文件夹介绍
查看>>
c#连接mysql
查看>>