NLP(8)--利用RNN实现多分类任务

前言

仅记录学习过程,有问题欢迎讨论

循环神经网络RNN(recurrent neural network):
  • 主要思想:将整个序列划分成多个时间步,将每一个时间步的信息依次输入模型,同时将模型输出的结果传给下一个时间步
  • 自带了tanh的激活函数

代码

发现RNN效率高很多

import json
import random

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.utils.data as Data

"""
构建一个 用RNN实现的 判断某个字符的位置 的任务

5 分类任务 判断 a出现的位置 返回index +1 or -1
"""


class TorchModel(nn.Module):
    def __init__(self, sentence_length, hidden_size, vocab, input_dim, output_size):
        super(TorchModel, self).__init__()
        #
        self.emb = nn.Embedding(len(vocab) + 1, input_dim)
        self.rnn = nn.RNN(input_dim, hidden_size, batch_first=True)

        self.pool = nn.MaxPool1d(sentence_length)
        self.leaner = nn.Linear(hidden_size, output_size)
        self.loss = nn.functional.cross_entropy

    def forward(self, x, y=None):
        # x = 15 * 4
        x = self.emb(x)  # output = 15 * 4 * 10
        x, h = self.rnn(x)  # output = 15 * 4 * 20 h = 1*15*20
        x = self.pool(x.transpose(1, 2)).squeeze()  # output = 15 * 20 * (1,被去除)
        y_pred = self.leaner(x)  # output = 15 * 5
        if y is not None:
            return self.loss(y_pred, y)
        else:
            return y_pred

    # 创建字符集 只有6个 希望a出现的概率大点


def build_vocab():
    chars = "abcdef"
    vocab = {}
    for index, char in enumerate(chars):
        vocab[char] = index + 1
    # vocab['unk'] = len(vocab) + 1
    return vocab


# 构建样本集
def build_dataset(vocab, data_size, sentence_length):
    dataset_x = []
    dataset_y = []
    for i in range(data_size):
        x, y = build_simple(vocab, sentence_length)
        dataset_x.append(x)
        dataset_y.append(y)
    return torch.LongTensor(dataset_x), torch.LongTensor(dataset_y)


# 构建样本
def build_simple(vocab, sentence_length):
    # 随机生成 长度为4的字符串
    x = [random.choice(list(vocab.keys())) for _ in range(sentence_length)]
    if x.count('a') != 0:
        y = x.index('a')
    else:
        y = 4

    # 转化为 数字
    x = [vocab[char] for char in list(x)]
    return x, y


def main():
    batch_size = 15
    simple_size = 500
    vocab = build_vocab()
    # 每个样本的长度为4
    sentence_length = 4
    # 样本的向量维度为10
    input_dim = 10
    # rnn的隐藏层 随便设置为20
    hidden_size = 20
    # 5 分类任务
    output_size = 5
    # 学习率
    lr = 0.02
    # 轮次
    epoch_size = 25
    model = TorchModel(sentence_length, hidden_size, vocab, input_dim, output_size)

    # 优化函数
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    # 样本
    x, y = build_dataset(vocab, simple_size, sentence_length)
    dataset = Data.TensorDataset(x, y)
    dataiter = Data.DataLoader(dataset, batch_size)
    for epoch in range(epoch_size):
        epoch_loss = []
        model.train()
        for x, y_true in dataiter:
            loss = model(x, y_true)
            loss.backward()
            optim.step()
            optim.zero_grad()
            epoch_loss.append(loss.item())
        print("第%d轮 loss = %f" % (epoch + 1, np.mean(epoch_loss)))
        # evaluate
        acc = evaluate(model, vocab, sentence_length)  # 测试本轮模型结果

    return


# 评估效果
def evaluate(model, vocab, sentence_length):
    model.eval()
    x, y = build_dataset(vocab, 200, sentence_length)
    correct, wrong = 0, 0
    with torch.no_grad():
        y_pred = model(x)
        for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比
            if int(torch.argmax(y_p)) == int(y_t):
                correct += 1  # 正样本判断正确
            else:
                wrong += 1
    print("正确预测个数:%d / %d, 正确率:%f" % (correct, correct + wrong, correct / (correct + wrong)))
    return correct / (correct + wrong)


if __name__ == '__main__':
    main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/573717.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++航海王:追寻罗杰的编程之路】C++11(二)

目录 C11(上) 1 -> STL中的一些变化 2 -> 右值引用和移动语义 2.1 -> 左值引用和右值引用 2.2 -> 左值引用与右值引用比较 2.3 -> 右值引用使用场景与意义 2.4 -> 右值引用引用左值及其更深入的使用场景分析 2.5 -> 完美转发 C11(上) 1 -> STL…

4 -25

1 100个英语单词两篇六级阅读 2 cf补题; 3 仿b站项目看源码 debug分析业务。 上了一天课,晚上去健身。 物理备课,周六去上课腻。 五一回来毛泽东思想期末考试,概率论期中考试。

轻松搭建MySQL 8.0:Ubuntu上的完美指南

欢迎来到我的博客,代码的世界里,每一行都是一个故事 轻松搭建MySQL 8.0:Ubuntu上的完美指南 前言脚本编写脚本实现部署过程参数成功页面 彩蛋坏蛋解决方法 前言 在数字化时代,数据就像是我们的宝藏,而MySQL数据库就是…

【Qt 学习笔记】Qt常用控件 | 输入类控件 | Text Edit的使用及说明

博客主页:Duck Bro 博客主页系列专栏:Qt 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ Qt常用控件 | 输入类控件 | Text Edit的使用及说明 文章编号&#xff…

【题解】牛客挑战赛 71 - A 和的期望

原题链接 https://ac.nowcoder.com/acm/problem/264714 思路分析 快速幂求逆元 费马小定理: a MOD − 1 ≡ 1 ( m o d M O D ) a^{\text{MOD}-1} \equiv 1 \pmod{MOD} aMOD−1≡1(modMOD),可以转换为 a ⋅ a MOD − 2 ≡ 1 ( m o d M O D ) ① a \cd…

4.24总结

对部分代码进行了修改,将一些代码封装成方法,实现了头像功能,通过FileInputStream将本地的图片写入,再通过FileOutputStream拷贝到服务端的文件夹中,并将服务端的文件路径存入数据库中

Linear Blend Skinning (LBS)线性混合蒙皮

LBS是CG的基础概念之一。 Linear Blend Skinning: linearly blend the results of the vertex transformed rigidly with each bone. LBS:线性地混合顶点根据每个骨骼的刚性变形结果。 这个场景应用在哪里呢? 假如我们重建好一个人体,现在用…

水位监测识别摄像机

水位监测识别摄像机是一种利用人工智能技术进行水位监测的智能设备,其作用是监测水体的水位变化并识别潜在的水灾危险,以提供准确数据和及时预警,帮助保护人民生命财产安全。这种摄像机通过高清摄像头实时捕捉水体的图像,然后利用…

Coursera: An Introduction to American Law 学习笔记 Week 03: Property Law

An Introduction to American Law 本文是 https://www.coursera.org/programs/career-training-for-nevadans-k7yhc/learn/american-law 这门课的学习笔记。 文章目录 An Introduction to American LawInstructors Week 03: Property LawKey Property Law TermsSupplemental Re…

【yolo算法道路井盖检测】

yolo算法道路井盖检测 数据集和模型yolov8道路井盖-下水道井盖检测训练模型数据集pyqt界面yolov8道路井盖-下水道井盖检测训练模型数据集 算法原理 1. 数据集准备与增强 数据采集:使用行车记录仪或其他设备收集道路井盖的图像数据。数据标注:对收集到…

如何提交已暂存的更改到本地仓库?

文章目录 如何提交已暂存的更改到本地Git仓库?步骤1:确认并暂存更改步骤2:提交暂存的更改到本地仓库 如何提交已暂存的更改到本地Git仓库? 在Git版本控制系统中,当你对项目文件进行修改后,首先需要将这些更…

大学生在线考试|基于SprinBoot+vue的在线试题库系统系统(源码+数据库+文档)

大学生在线考试目录 基于SprinBootvue的在线试题库系统系统 一、前言 二、系统设计 三、系统功能设计 试卷管理 试题管理 考试管理 错题本 考试记录 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取: 博主介绍&#…

valgrind,memcheck的使用

一,valgrind介绍 ​ valgrind是一个开源的,检测内存泄漏的工具,通常在linux下使用,除此之外,他还能检测内存管理错误,线程bug等错误。粗浅的来讲,valgrind由两部分构成,一部分用来模…

每日OJ题_BFS解决拓扑排序③_力扣LCR 114. 火星词典

目录 力扣LCR 114. 火星词典 解析代码 力扣LCR 114. 火星词典 LCR 114. 火星词典 难度 困难 现有一种使用英语字母的外星文语言,这门语言的字母顺序与英语顺序不同。 给定一个字符串列表 words ,作为这门语言的词典,words 中的字符串已…

SpringBoot-无法从static上下文引用同非static方法

1.问题 说明:无法从static上下文引用同非static方法。 2.解决 说明:return后面的语句中,调用的是变量的方法,而不是类型的方法!

Pytorch学习之路 - CNN

目录 理论预热 实践 构建卷积神经网络 卷积网络模块构建 实战:基于经典网络架构训练图像分类模型 数据预处理部分: 网络模块设置: 网络模型保存与测试 实践 制作好数据源: 图片 标签 展示下数据 加载models中提供的模…

CMake:相关概念与使用入门(一)

1、Cmake概述 Cmake是一个项目构建工具,并且是跨平台的。 关于项目构建我们所熟知的有Makefile,然后通过make命令进行项目的构建,并且大多数是IDE都继承了make,比如:VS的nmake,Linux下的GNU make、Qt的qma…

OpenCV与AI深度学习 | 如何使用YOLOv9分割图像中的对象

本文来源公众号“OpenCV与AI深度学习”,仅用于学术分享,侵权删,干货满满。 原文链接:如何使用YOLOv9分割图像中的对象 1 介绍 在我们之前的文章中,我们使用 YOLOv8 探索了令人兴奋的对象分割世界。分割使计算机视觉比…

Linux进程详解:进程优先级,调度算法,进程特性

文章目录 进程优先级Linux下的调度算法进程特性 进程优先级 进程要访问某种软硬件资源,此时进程需要通过一定的方式(排队),来确认享受某种资源的先后顺序。 优先级是确认先后问题,权限是确认能不能的问题。 资源有限…

5个常见的前端手写功能:浅拷贝与深拷贝、函数柯里化、数组扁平化、数组去重、手写类型判断函数

浅拷贝与深拷贝 浅拷贝 浅拷贝是创建一个新对象,这个对象有着原始对象属性值的一份精确拷贝。如果属性是基本类型,拷贝的就是基本类型的值,如果属性是引用类型,拷贝的就是内存地址,所以如果其中一个对象改变了这个地…