手写系列——MoE网络

news/2025/2/26 5:48:10

参考:

MOE原理解释及从零实现一个MOE(专家混合模型)_moe代码-CSDN博客

MoE环游记:1、从几何意义出发 - 科学空间|Scientific Spaces 

深度学习之图像分类(二十八)-- Sparse-MLP(MoE)网络详解_sparse moe-CSDN博客

深度学习之图像分类(二十九)-- Sparse-MLP网络详解_sparse mlp-CSDN博客 

 

代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 超参数设置
num_experts = 4      # 专家数量
top_k = 2            # 激活专家数
# input_dim = 3072     # CIFAR-10图像展平后维度(32x32x3)
input_dim = 64 * 8 * 8
hidden_dim = 512     # 专家网络隐藏层维度
num_classes = 10     # 分类类别数

# MoE层实现(文献[5][7])
class SparseMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            ) for _ in range(num_experts)])
        
        self.gate = nn.Sequential(
            nn.Linear(input_dim, num_experts),
            nn.Softmax(dim=1)
        )
        
        # 负载均衡参数(文献[4][7])
        self.balance_loss_weight = 0.01
        self.register_buffer('expert_counts', torch.zeros(num_experts))

    def forward(self, x):
        # 门控计算
        gate_scores = self.gate(x)  # [B, num_experts]
        
        # Top-k选择(文献[5])
        topk_scores, topk_indices = torch.topk(gate_scores, top_k, dim=1)
        mask = F.one_hot(topk_indices, num_experts).float().sum(dim=1)
        
        # 专家输出聚合
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        selected_experts = expert_outputs.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, hidden_dim))  # [B, 2, H]
        # print(f"专家输出维度: {expert_outputs.shape}")
        # print(f"选择索引维度: {topk_indices.shape}")
        # print(f"选择专家维度: {selected_experts.shape}")
        weighted_outputs = (selected_experts  * topk_scores.unsqueeze(-1)).sum(dim=1)
        
        # 更新专家使用统计
        self.expert_counts += mask.sum(dim=0)
        
        return weighted_outputs

    def balance_loss(self):
        # 计算负载均衡损失(文献[4][7])
        expert_probs = self.expert_counts / self.expert_counts.sum()
        balance_loss = torch.std(expert_probs) * self.balance_loss_weight
        self.expert_counts.zero_()  # 重置计数器
        return balance_loss

# 完整模型架构(文献[2][6])
class MoEImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.moe_layer = SparseMoE()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # 展平特征
        x = self.moe_layer(x)
        return self.classifier(x)

# 数据预处理(文献[2])
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

# 训练流程
model = MoEImageClassifier()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(10):
    for images, labels in train_loader:
        optimizer.zero_grad()
        
        outputs = model(images)
        main_loss = criterion(outputs, labels)
        balance_loss = model.moe_layer.balance_loss()
        
        total_loss = main_loss + balance_loss
        total_loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/10], Loss: {total_loss.item():.4f}')


http://www.niftyadmin.cn/n/5868095.html

相关文章

qtcreator上使用opencv报错

发现是我选择opencv的版本有问题 右键桌面的qtcreator图标,进入Tools目录,可以看到mingw的版本是mingw730_64,因此编译opencv时也要用这个版本 下面是我网上随便找的别人编译好的,发现不行,这个所用的mingw版本也没提&#xff0c…

6. grafana的graph简介

1. Settings功能 2. Visualization功能 (可视化的方式,后续会写一些) 3. Display 功能(显示方面的设置) bars 柱状图方式显示 lines(不选不会出功能) line width 线条的粗细 staircase 会让折…

【IntelliJ IDEA】关于设置固定的文件格式(包括注释、版权信息等)的方法

在IntelliJ IDEA(简称IDEA)中,要设置固定的文件格式(包括注释、版权信息等),使得每次创建新文件时都能自动显示这些内容,可以通过以下步骤实现: 一、设置文件模板 打开IDEA并进入项…

c语言学习,归并排序

C语言,归并排序是经典的分治算法,基本思想是将,待排序的数组分成两个子数组,分别对这两个子数组进行排序,然后将排序好的子数组合并成一个有序的数组。归并排序的时间复杂度为O(n log n),且具有稳定性。 示…

MFC案例:利用双缓冲技术绘制顶点可移动三角形

案例目标:在屏幕上出现一个三角形,同时显示各顶点坐标,当用鼠标选择某顶点并拖动时,三角形随鼠标移动而变形。具体步骤为: 一、在VS2022上建立一个基于对话框的MFC应用,项目名称:DrawMovableTr…

IDEA-插件开发踩坑记录-第五坑-没有飞机场导致无法访问GITHUB导致的讨厌问题

背景 在JetBrains-intellij-idea 插件开发时,出现一个不影响运行,但影响心情的错误提示: Cannot resolve the latest Gradle IntelliJ Plugin version org.gradle.api.GradleException: Cannot resolve the latest Gradle IntelliJ Plugin v…

《OpenCV》——实例:答题卡识别

答题卡识别 实例内容: 该实例实现了一个基于计算机视觉技术的答题卡自动识别与评分系统,利用 OpenCV 库对答题卡图像进行处理和分析,最终得出答题卡的得分。 实例步骤: 导入必要的库 import numpy as np import cv2导入num…

hot100_108. 将有序数组转换为二叉搜索树

hot100_108. 将有序数组转换为二叉搜索树 思路 给你一个整数数组 nums ,其中元素已经按 升序 排列,请你将其转换为一棵 平衡 二叉搜索树。 示例 1: 输入:nums [-10,-3,0,5,9] 输出:[0,-3,9,-10,null,5] 解释&#…