import math
from math import sqrt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
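
# ---------------------------------------------------------------------------
# Anomaly Transformer for time-series anomaly detection. This file follows
# the structure of the reference implementation of "Anomaly Transformer:
# Time Series Anomaly Detection with Association Discrepancy" (Xu et al.,
# ICLR 2022): every attention layer outputs both the usual softmax attention
# ("series association") and a learned Gaussian prior over relative positions
# ("prior association"); their divergence is the association discrepancy used
# in training (Model.loss) and scoring (Model.detection).
# ---------------------------------------------------------------------------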


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the sinusoidal positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        # PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Encodings for the first x.size(1) positions: [1, L, d_model].
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Circular-padded 1D convolution lifts each c_in-dimensional time step
        # to a d_model-dimensional token. The padding needed for same-length
        # output differs across PyTorch versions, hence the version check.
        padding = 1 if torch.__version__ >= '1.5.0' else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding,
                                   padding_mode='circular', bias=False)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # [B, L, c_in] -> [B, c_in, L] -> conv -> [B, d_model, L] -> [B, L, d_model]
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, dropout=0.0):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # Sum of the token (value) embedding and the positional embedding.
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x)


class TriangularCausalMask():
    def __init__(self, B, L, device="cpu"):
        # Upper-triangular boolean mask (True above the diagonal) that blocks
        # attention to future positions: [B, 1, L, L].
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
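
# Note: Model below instantiates AnomalyAttention with mask_flag=False, so the
# causal mask is never applied in this configuration; TriangularCausalMask is
# kept as the fallback that AnomalyAttention builds when mask_flag is True and
# no mask is supplied.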


class AnomalyAttention(nn.Module):
    def __init__(self, win_size, mask_flag=True, scale=None, attention_dropout=0.0, output_attention=False):
        super(AnomalyAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        # Matrix of absolute relative distances |i - j| between window
        # positions, registered as a buffer so it follows the module's device
        # (the original hard-coded .cuda() calls broke CPU execution).
        indices = torch.arange(win_size).float()
        self.register_buffer('distances', torch.abs(indices.unsqueeze(0) - indices.unsqueeze(1)))

    def forward(self, queries, keys, values, sigma, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)

        # Series association: standard scaled dot-product attention.
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        attn = scale * scores

        # Prior association: a Gaussian kernel over relative distance, with a
        # per-position, per-head bandwidth sigma predicted by the network.
        sigma = sigma.transpose(1, 2)  # B L H -> B H L
        window_size = attn.shape[-1]
        sigma = torch.sigmoid(sigma * 5) + 1e-5
        sigma = torch.pow(3, sigma) - 1  # rescale into (0, 2]
        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L L
        prior = self.distances.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1)
        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / 2 / (sigma ** 2))

        series = self.dropout(torch.softmax(attn, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", series, values)

        if self.output_attention:
            return (V.contiguous(), series, prior, sigma)
        else:
            # Return a 4-tuple either way so that AttentionLayer can always
            # unpack (out, series, prior, sigma).
            return (V.contiguous(), None, None, None)
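
# ---------------------------------------------------------------------------
# Summary of the two associations produced by AnomalyAttention above, for a
# window of length L (P is row-normalized later, in Model.loss/detection):
#   series association  S[b, h, i, :] = softmax_j(q_i . k_j / sqrt(E))
#   prior association   P[b, h, i, j] ~ exp(-|i - j|^2 / (2 * sigma[b, h, i]^2))
#                                       / (sqrt(2 * pi) * sigma[b, h, i])
# Anomalies tend to concentrate their series association on adjacent
# positions, so their discrepancy from the Gaussian prior is small; normal
# points associate more broadly and show a larger discrepancy.
# ---------------------------------------------------------------------------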


class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)
        self.norm = nn.LayerNorm(d_model)  # unused in forward; kept for checkpoint compatibility
        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        # One Gaussian-prior bandwidth (sigma) per head and position.
        self.sigma_projection = nn.Linear(d_model, n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)

        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        x = queries
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)
        sigma = self.sigma_projection(x).view(B, L, H)

        out, series, prior, sigma = self.inner_attention(
            queries,
            keys,
            values,
            sigma,
            attn_mask
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), series, prior, sigma


class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        # Position-wise feed-forward network, implemented as two 1x1 convolutions.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        # Self-attention with a residual connection, then the FFN with another.
        new_x, series, prior, sigma = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), series, prior, sigma


class Encoder(nn.Module):
    def __init__(self, attn_layers, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x: [B, L, D]. Collect series/prior associations and sigmas from every
        # layer; the loss averages the association discrepancy across layers.
        series_list = []
        prior_list = []
        sigma_list = []
        for attn_layer in self.attn_layers:
            x, series, prior, sigma = attn_layer(x, attn_mask=attn_mask)
            series_list.append(series)
            prior_list.append(prior)
            sigma_list.append(sigma)

        if self.norm is not None:
            x = self.norm(x)

        return x, series_list, prior_list, sigma_list


class Model(nn.Module):
    def __init__(self, customs: dict, dataloader: tud.DataLoader):
        super(Model, self).__init__()
        win_size = int(customs["input_size"])
        # Reconstruct all input channels: enc_in == c_out.
        enc_in = c_out = dataloader.dataset.train_inputs.shape[-1]
        d_model = 512
        n_heads = 8
        e_layers = 3
        d_ff = 512
        dropout = 0.0
        activation = 'gelu'
        output_attention = True
        self.k = 3  # weight of the association-discrepancy terms in the loss
        self.win_size = win_size

        self.name = "AnomalyTransformer"
        # Encoding
        self.embedding = DataEmbedding(enc_in, d_model, dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        AnomalyAttention(win_size, False, attention_dropout=dropout,
                                         output_attention=output_attention),
                        d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for _ in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )

        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x):
        enc_out = self.embedding(x)
        enc_out, series, prior, sigmas = self.encoder(enc_out)
        enc_out = self.projection(enc_out)  # reconstruction: [B, L, c_out]
        return enc_out, series, prior, sigmas

    @staticmethod
    def my_kl_loss(p, q):
        # Smoothed pointwise KL(p || q), summed over the key dimension and
        # averaged over heads -> one value per (batch, position).
        res = p * (torch.log(p + 0.0001) - torch.log(q + 0.0001))
        return torch.mean(torch.sum(res, dim=-1), dim=1)
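
    # Association discrepancy between prior P and series S, per position:
    #   AssDis(P, S)[b, i] = mean over layers of [KL(P || S) + KL(S || P)][b, i]
    # loss() below detaches one side at a time (the paper's minimax strategy):
    # loss1 drives the series association away from the frozen prior, while
    # loss2 pulls the prior toward the frozen series association.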

    def loss(self, x, y_true, epoch: int = None, device: str = "cpu"):
        output, series, prior, _ = self.forward(x)
        series_loss = 0.0
        prior_loss = 0.0
        for u in range(len(prior)):
            # Normalize the Gaussian prior so that every row is a distribution.
            prior_norm = prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(
                1, 1, 1, self.win_size)
            # Maximize phase: push the series association away from the
            # (detached) prior.
            series_loss += (torch.mean(self.my_kl_loss(series[u], prior_norm.detach())) +
                            torch.mean(self.my_kl_loss(prior_norm.detach(), series[u])))
            # Minimize phase: pull the prior toward the (detached) series.
            prior_loss += (torch.mean(self.my_kl_loss(prior_norm, series[u].detach())) +
                           torch.mean(self.my_kl_loss(series[u].detach(), prior_norm)))
        series_loss = series_loss / len(prior)
        prior_loss = prior_loss / len(prior)
        rec_loss = nn.MSELoss()(output, x)

        loss1 = rec_loss - self.k * series_loss
        loss2 = rec_loss + self.k * prior_loss

        # Minimax strategy: two backward passes, one per phase.
        loss1.backward(retain_graph=True)
        loss2.backward()

        return loss1.item()

    def detection(self, x, y_true, epoch: int = None, device: str = "cpu"):
        temperature = 50
        output, series, prior, _ = self.forward(x)

        # Per-position reconstruction error: [B, L]. reduction='none' is
        # required here; the default (mean) would collapse to a scalar.
        loss = torch.mean(nn.MSELoss(reduction='none')(x, output), dim=-1)

        series_loss = 0.0
        prior_loss = 0.0
        for u in range(len(prior)):
            prior_norm = prior[u] / torch.unsqueeze(torch.sum(prior[u], dim=-1), dim=-1).repeat(
                1, 1, 1, self.win_size)
            series_loss += self.my_kl_loss(series[u], prior_norm.detach()) * temperature
            prior_loss += self.my_kl_loss(prior_norm, series[u].detach()) * temperature

        # Anomaly criterion: softmax of the negated association discrepancy,
        # weighted by the reconstruction error, averaged over the window.
        metric = torch.softmax((-series_loss - prior_loss), dim=-1)

        cri = metric * loss
        cri = cri.mean(dim=-1)
        return cri, None
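

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. `_ToyDataset` is a hypothetical stand-in for the
# project's real dataset class: Model.__init__ only relies on the dataset
# exposing a `train_inputs` array whose last dimension is the channel count,
# and on customs["input_size"] matching the window length.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyDataset(tud.Dataset):
        def __init__(self, n_windows=8, win_size=100, n_channels=8):
            self.train_inputs = torch.randn(n_windows, win_size, n_channels)

        def __len__(self):
            return self.train_inputs.shape[0]

        def __getitem__(self, idx):
            return self.train_inputs[idx]

    loader = tud.DataLoader(_ToyDataset(), batch_size=4)
    model = Model({"input_size": 100}, loader)

    batch = next(iter(loader))                # [4, 100, 8]
    recon, series, prior, sigmas = model(batch)
    print(recon.shape)                        # torch.Size([4, 100, 8])

    scores, _ = model.detection(batch, None)  # one anomaly score per window
    print(scores.shape)                       # torch.Size([4])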