This repository has been archived on 2025-09-14. You can view files and clone it, but cannot push or open issues or pull requests.
Files
2023-05-25 15:30:02 +08:00

50 lines
2.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch.nn as nn
import torch.utils.data as tud
import torch
class Model(nn.Module):
def __init__(self, customs: dict, dataloader: tud.DataLoader = None):
"""
:param customs: 自定义参数内容取自于config.ini文件的[CustomParameters]部分。
:param dataloader: 数据集初始化完成的dataloader。在自定义的预处理方法文件中可以增加内部变量或者方法提供给模型。
例如模型初始化需要数据的维度数量可通过n_features = dataloader.dataset.train_inputs.shape[-1]获取
或在预处理方法的MyDataset类中定义self.n_features = self.train_inputs.shape[-1]
通过n_features = dataloader.dataset.n_features获取
"""
super(Model, self).__init__()
def forward(self, x):
"""
:param x: 模型的输入在本工具中为MyDataset类中__getitem__方法返回的三个变量中的第一个变量。
:return: 模型的输出,可以自定义
"""
return None
def loss(self, x, y_true, epoch: int = None, device: str = "cpu"):
"""
计算loss。注意计算loss时如采用torch之外的库计算会造成梯度截断请全部使用torch的方法
:param x: 输入数据
:param y_true: 真实输出数据
:param epoch: 当前是第几个epoch
:param device: 设备cpu或者cuda
:return: loss值
"""
y_pred = self.forward(x) # 模型的输出
loss = torch.Tensor([1]) # 示例,请修改
loss.backward()
return loss.item()
def detection(self, x, y_true, epoch: int = None, device: str = "cpu"):
"""
检测方法,可以输出异常的分数,也可以输出具体的标签。
如输出异常分数,则后续会根据异常分数自动划分阈值,高于阈值的为异常,自动赋予标签;如输出标签,则直接进行评估。
:param x: 输入数据
:param y_true: 真实输出数据
:param epoch: 当前是第几个epoch
:param device: 设备cpu或者cuda
:return: scorelabel。如选择输出异常的分数则输出scorelabel为None如选择输出标签则输出labelscore为None。
score的格式为torch的Tensor格式尺寸为[batch_size]label的格式为torch的IntTensor格式尺寸为[batch_size]
"""
y_pred = self.forward(x) # 模型的输出
return None, None