forked from JointCloud/JCC-DeepOD
💥 major update: add Ray (automatic hyper-parameter tuning) to DeepSVDD, COUTA, and TcnED.
parent bbc97239eb
commit 97fb54c7f1
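In short, each of these models now exposes a `fit_auto_hyper` entry point that wraps its training loop in Ray Tune and returns the best trial's configuration. A rough usage sketch (illustrative only: the import path and the `x_train`/`x_test`/`y_test` arrays are assumptions, and the call mirrors the testbed scripts further down in this commit):

from deepod.models import DeepSVDD  # import path assumed, adjust to your install

clf = DeepSVDD(random_state=42)

# searches the space returned by DeepSVDD.set_tuned_params() with ray.tune,
# reloads the best trial's checkpoint, and returns that trial's config
best_config = clf.fit_auto_hyper(X=x_train, X_test=x_test, y_test=y_test,
                                 n_ray_samples=3, time_budget_s=None)

# the tuned configuration can then seed ordinary training runs
clf = DeepSVDD(**best_config, random_state=42)
clf.fit(x_train)
scores = clf.decision_function(x_test)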
@@ -4,16 +4,21 @@ Base class for deep Anomaly detection models

some functions are adapted from the pyod library

@Author: Hongzuo Xu <hongzuoxu@126.com, xuhongzuo13@nudt.edu.cn>
"""

import sys
import warnings

import numpy as np
import torch
import random
import time
from abc import ABCMeta, abstractmethod
from scipy.stats import binom
from deepod.utils.utility import get_sub_seqs, get_sub_seqs_label
from deepod.core.networks.base_networks import sequential_net_name
from tqdm import tqdm
from scipy.stats import binom
from ray import tune
from ray.air import session, Checkpoint
from ray.tune.schedulers import ASHAScheduler
from functools import partial
from deepod.utils.utility import get_sub_seqs, get_sub_seqs_label


class BaseDeepAD(metaclass=ABCMeta):
@@ -132,11 +137,15 @@ class BaseDeepAD(metaclass=ABCMeta):

        self.train_data = None
        self.train_label = None
        self.val_data = None
        self.val_label = None

        self.decision_scores_ = None
        self.labels_ = None
        self.threshold_ = None

        self.checkpoint_data = {}

        self.random_state = random_state
        self.set_seed(random_state)
        return
@@ -192,6 +201,92 @@ class BaseDeepAD(metaclass=ABCMeta):

        return self

    def fit_auto_hyper(self, X, y=None, X_test=None, y_test=None,
                       n_ray_samples=5, time_budget_s=None):
        """
        Fit detector with automatically tuned hyper-parameters. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        y : numpy array of shape (n_samples, )
            Not used in unsupervised methods, present for API consistency by convention.
            Used in (semi-/weakly-) supervised methods.

        X_test : numpy array of shape (n_samples, n_features), default=None
            The input testing samples for hyper-parameter tuning.

        y_test : numpy array of shape (n_samples, ), default=None
            Labels of the input testing samples for hyper-parameter tuning.

        n_ray_samples: int, default=5
            Number of times to sample from the hyper-parameter space.

        time_budget_s: int, default=None
            Global time budget in seconds after which all Ray trials are stopped.

        Returns
        -------
        config : dict
            The tuned hyper-parameters.
        """
        if self.data_type == 'ts':
            self.train_data = get_sub_seqs(X, self.seq_len, self.stride)
            self.train_label = get_sub_seqs_label(y, self.seq_len, self.stride) if y is not None else None
            self.n_samples, self.n_features = self.train_data.shape[0], self.train_data.shape[2]

        elif self.data_type == 'tabular':
            self.train_data = X
            self.train_label = y
            self.n_samples, self.n_features = self.train_data.shape

        else:
            raise NotImplementedError('unsupported data_type')

        config = self.set_tuned_params()
        metric = "loss" if X_test is None else 'metric'
        mode = "min" if X_test is None else 'max'
        scheduler = ASHAScheduler(
            metric=metric,
            mode=mode,
            max_t=self.epochs,
            grace_period=1,
            reduction_factor=2,
        )

        size = sys.getsizeof(self.train_data)/(1024**2)
        if size >= 30:
            split = int(len(self.train_data) / (size / 30))
            self.train_data = self.train_data[:split]
            self.train_label = self.train_label[:split] if y is not None else None
            warnings.warn('split training data to meet the 95 MiB limit of ray ImplicitFunc')

        result = tune.run(
            partial(self._training_ray,
                    X_test=X_test, y_test=y_test),
            resources_per_trial={"cpu": 4, "gpu": 0 if self.device == 'cpu' else 1},
            config=config,
            num_samples=n_ray_samples,
            time_budget_s=time_budget_s,
            scheduler=scheduler,
        )

        best_trial = result.get_best_trial(metric=metric, mode=mode, scope="last")
        print(f"Best trial config: {best_trial.config}")
        print(f"Best trial final validation loss: {best_trial.last_result['loss']}")

        # tuned results
        best_checkpoint = best_trial.checkpoint.to_air_checkpoint().to_dict()
        best_config = best_trial.config
        self.load_ray_checkpoint(best_config=best_config, best_checkpoint=best_checkpoint)

        # testing on the input training data
        self.decision_scores_ = self.decision_function(X)
        self.labels_ = self._process_decision_scores()
        return best_config

    def decision_function(self, X, return_rep=False):
        """Predict raw anomaly scores of X using the fitted detector.
@@ -357,6 +452,9 @@ class BaseDeepAD(metaclass=ABCMeta):

        return

    def _training_ray(self, config, X_test, y_test):
        return

    def _inference(self):
        self.net.eval()
        with torch.no_grad():
@@ -406,6 +504,17 @@ class BaseDeepAD(metaclass=ABCMeta):

        """for any updating operation after decision function"""
        return z, scores

    def set_tuned_net(self, config):
        return

    @staticmethod
    def set_tuned_params():
        config = {}
        return config

    def load_ray_checkpoint(self, best_config, best_checkpoint):
        return

    @staticmethod
    def set_seed(seed):
        torch.manual_seed(seed)
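The base class above only ships no-op stubs for the tuning hooks; the per-model diffs below override them. As a reference point, a hypothetical subclass (names and bodies are schematic, not part of this commit) would have to fill in roughly this contract for `fit_auto_hyper` to drive Ray Tune:

# Schematic sketch only; the hook names follow the BaseDeepAD stubs above.
class MyDetector(BaseDeepAD):

    @staticmethod
    def set_tuned_params():
        # the search space that fit_auto_hyper passes to ray.tune.run as `config`
        return {'lr': tune.grid_search([1e-4, 1e-3])}

    def set_tuned_net(self, config):
        # build a fresh network from one sampled config
        raise NotImplementedError

    def _training_ray(self, config, X_test, y_test):
        # train with `config`; once per epoch report results so ASHA can prune, e.g.
        # session.report({"loss": val_loss, "metric": test_metric}, checkpoint=checkpoint)
        raise NotImplementedError

    def load_ray_checkpoint(self, best_config, best_checkpoint):
        # rebuild the net from the best config and restore its state dict
        raise NotImplementedError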
@@ -0,0 +1,421 @@
# -*- coding: utf-8 -*-
"""
Base class for deep Anomaly detection models
some functions are adapted from the pyod library
@Author: Hongzuo Xu <hongzuoxu@126.com, xuhongzuo13@nudt.edu.cn>
"""

import numpy as np
import torch
import random
import time
from abc import ABCMeta, abstractmethod
from tqdm import tqdm
from scipy.stats import binom
from ray import tune
from ray.air import Checkpoint, session
from ray.tune.schedulers import ASHAScheduler
from deepod.utils.utility import get_sub_seqs, get_sub_seqs_label


class BaseDeepAD(metaclass=ABCMeta):
    """
    Abstract class for deep outlier detection models

    Parameters
    ----------

    data_type: str, optional (default='tabular')
        Data type, choice = ['tabular', 'ts']

    network: str, optional (default='MLP')
        network structure for different data structures

    epochs: int, optional (default=100)
        Number of training epochs

    batch_size: int, optional (default=64)
        Number of samples in a mini-batch

    lr: float, optional (default=1e-3)
        Learning rate

    n_ensemble: int or str, optional (default=1)
        Size of the ensemble

    seq_len: int, optional (default=100)
        Size of window used to create subsequences from the data;
        ignored when handling tabular data (network=='MLP')

    stride: int, optional (default=1)
        number of time points the window moves between two subsequences;
        ignored when handling tabular data (network=='MLP')

    epoch_steps: int, optional (default=-1)
        Maximum steps in an epoch
        - If -1, all the batches will be processed

    prt_steps: int, optional (default=10)
        Number of epoch intervals per printing

    device: str, optional (default='cuda')
        torch device

    contamination : float in (0., 0.5), optional (default=0.1)
        The amount of contamination of the data set,
        i.e. the proportion of outliers in the data set. Used when fitting to
        define the threshold on the decision function.

    verbose: int, optional (default=1)
        Verbosity mode

    random_state: int, optional (default=42)
        the seed used by the random number generator

    Attributes
    ----------
    decision_scores_ : numpy array of shape (n_samples,)
        The outlier scores of the training data.
        The higher, the more abnormal. Outliers tend to have higher
        scores. This value is available once the detector is fitted.

    threshold_ : float
        The threshold is based on ``contamination``. It is the
        ``n_samples * contamination`` most abnormal samples in
        ``decision_scores_``. The threshold is calculated for generating
        binary outlier labels.

    labels_ : int, either 0 or 1
        The binary labels of the training data. 0 stands for inliers
        and 1 for outliers/anomalies. It is generated by applying
        ``threshold_`` on ``decision_scores_``.

    """
    def __init__(self, model_name, data_type='tabular', network='MLP',
                 epochs=100, batch_size=64, lr=1e-3,
                 n_ensemble=1, seq_len=100, stride=1,
                 epoch_steps=-1, prt_steps=10,
                 device='cuda', contamination=0.1,
                 verbose=1, random_state=42):
        self.model_name = model_name

        self.data_type = data_type
        self.network = network

        # if data_type == 'ts':
        #     assert self.network in sequential_net_name, \
        #         'Assigned network cannot handle time-series data'

        self.seq_len = seq_len
        self.stride = stride

        self.epochs = epochs
        self.batch_size = batch_size
        self.lr = lr

        self.device = device
        self.contamination = contamination

        self.epoch_steps = epoch_steps
        self.prt_steps = prt_steps
        self.verbose = verbose

        self.n_features = -1
        self.n_samples = -1
        self.criterion = None
        self.net = None

        self.n_ensemble = n_ensemble

        self.train_loader = None
        self.test_loader = None

        self.epoch_time = None

        self.train_data = None
        self.train_label = None

        self.decision_scores_ = None
        self.labels_ = None
        self.threshold_ = None

        self.checkpoint_data = {}

        self.random_state = random_state
        self.set_seed(random_state)
        return
    def fit(self, X, y=None):
        """
        Fit detector. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        y : numpy array of shape (n_samples, )
            Not used in unsupervised methods, present for API consistency by convention.
            Used in (semi-/weakly-) supervised methods.

        Returns
        -------
        self : object
            Fitted estimator.
        """

        if self.data_type == 'ts':
            X_seqs = get_sub_seqs(X, seq_len=self.seq_len, stride=self.stride)
            y_seqs = get_sub_seqs_label(y, seq_len=self.seq_len, stride=self.stride) if y is not None else None
            self.train_data = X_seqs
            self.train_label = y_seqs
            self.n_samples, self.n_features = X_seqs.shape[0], X_seqs.shape[2]
        else:
            self.train_data = X
            self.train_label = y
            self.n_samples, self.n_features = X.shape

        if self.verbose >= 1:
            print('Start Training...')

        if self.n_ensemble == 'auto':
            self.n_ensemble = int(np.floor(100 / (np.log(self.n_samples) + self.n_features)) + 1)
        if self.verbose >= 1:
            print(f'ensemble size: {self.n_ensemble}')

        for _ in range(self.n_ensemble):
            self.train_loader, self.net, self.criterion = self.training_prepare(self.train_data,
                                                                                y=self.train_label)
            self._training()

        if self.verbose >= 1:
            print('Start Inference on the training data...')

        self.decision_scores_ = self.decision_function(X)
        self.labels_ = self._process_decision_scores()

        return self
    def decision_function(self, X, return_rep=False):
        """Predict raw anomaly scores of X using the fitted detector.

        The anomaly score of an input sample is computed based on the fitted
        detector. For consistency, outliers are assigned with
        higher anomaly scores.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.

        return_rep: boolean, optional, default=False
            whether return representations

        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.
        """

        testing_n_samples = X.shape[0]

        if self.data_type == 'ts':
            X = get_sub_seqs(X, seq_len=self.seq_len, stride=1)

        representations = []
        s_final = np.zeros(testing_n_samples)
        for _ in range(self.n_ensemble):
            self.test_loader = self.inference_prepare(X)

            z, scores = self._inference()
            z, scores = self.decision_function_update(z, scores)

            if self.data_type == 'ts':
                padding = np.zeros(self.seq_len-1)
                scores = np.hstack((padding, scores))

            s_final += scores
            representations.extend(z)
        representations = np.array(representations)

        if return_rep:
            return s_final, representations
        else:
            return s_final
    def predict(self, X, return_confidence=False):
        """Predict if a particular sample is an outlier or not.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        return_confidence : boolean, optional(default=False)
            If True, also return the confidence of prediction.

        Returns
        -------
        outlier_labels : numpy array of shape (n_samples,)
            For each observation, tells whether
            it should be considered as an outlier according to the
            fitted model. 0 stands for inliers and 1 for outliers.
        confidence : numpy array of shape (n_samples,).
            Only if return_confidence is set to True.
        """

        pred_score = self.decision_function(X)
        prediction = (pred_score > self.threshold_).astype('int').ravel()

        if return_confidence:
            confidence = self._predict_confidence(pred_score)
            return prediction, confidence

        return prediction
    def _predict_confidence(self, test_scores):
        """Predict the model's confidence in making the same prediction
        under slightly different training sets.
        See :cite:`perini2020quantifying`.

        Parameters
        ----------
        test_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.

        Returns
        -------
        confidence : numpy array of shape (n_samples,)
            For each observation, tells how consistently the model would
            make the same prediction if the training set was perturbed.
            Return a probability, ranging in [0,1].

        """
        n = len(self.decision_scores_)

        count_instances = np.vectorize(lambda x: np.count_nonzero(self.decision_scores_ <= x))
        n_instances = count_instances(test_scores)

        # Derive the outlier probability using Bayesian approach
        posterior_prob = np.vectorize(lambda x: (1 + x) / (2 + n))(n_instances)

        # Transform the outlier probability into a confidence value
        confidence = np.vectorize(
            lambda p: 1 - binom.cdf(n - int(n * self.contamination), n, p)
        )(posterior_prob)
        prediction = (test_scores > self.threshold_).astype('int').ravel()
        np.place(confidence, prediction == 0, 1 - confidence[prediction == 0])
        return confidence
    def _process_decision_scores(self):
        """Internal function to calculate key attributes:

        - threshold_: used to decide the binary label
        - labels_: binary labels of training data

        Returns
        -------
        self
        """

        self.threshold_ = np.percentile(self.decision_scores_, 100 * (1 - self.contamination))
        self.labels_ = (self.decision_scores_ > self.threshold_).astype('int').ravel()

        self._mu = np.mean(self.decision_scores_)
        self._sigma = np.std(self.decision_scores_)

        return self
    def _training(self):
        optimizer = torch.optim.Adam(self.net.parameters(), lr=self.lr, eps=1e-6)

        self.net.train()
        for i in range(self.epochs):
            t1 = time.time()
            total_loss = 0
            cnt = 0
            for batch_x in self.train_loader:
                loss = self.training_forward(batch_x, self.net, self.criterion)
                self.net.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                cnt += 1

                # terminate this epoch when reaching assigned maximum steps per epoch
                if cnt > self.epoch_steps != -1:
                    break

            t = time.time() - t1
            if self.verbose >= 1 and (i == 0 or (i+1) % self.prt_steps == 0):
                print(f'epoch{i+1:3d}, '
                      f'training loss: {total_loss/cnt:.6f}, '
                      f'time: {t:.1f}s')

            if i == 0:
                self.epoch_time = t

            self.epoch_update()

        return
    def _inference(self):
        self.net.eval()
        with torch.no_grad():
            z_lst = []
            score_lst = []

            if self.verbose >= 2:
                _iter_ = tqdm(self.test_loader, desc='testing: ')
            else:
                _iter_ = self.test_loader

            for batch_x in _iter_:
                batch_z, s = self.inference_forward(batch_x, self.net, self.criterion)
                z_lst.append(batch_z)
                score_lst.append(s)

        z = torch.cat(z_lst).data.cpu().numpy()
        scores = torch.cat(score_lst).data.cpu().numpy()

        return z, scores

    @abstractmethod
    def training_forward(self, batch_x, net, criterion):
        """define forward step in training"""
        pass

    @abstractmethod
    def inference_forward(self, batch_x, net, criterion):
        """define forward step in inference"""
        pass

    @abstractmethod
    def training_prepare(self, X, y):
        """define train_loader, net, and criterion"""
        pass

    @abstractmethod
    def inference_prepare(self, X):
        """define test_loader"""
        pass

    def epoch_update(self):
        """for any updating operation after each training epoch"""
        return

    def decision_function_update(self, z, scores):
        """for any updating operation after decision function"""
        return z, scores

    @staticmethod
    def set_seed(seed):
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)
        # torch.backends.cudnn.benchmark = False
        # torch.backends.cudnn.deterministic = True
@@ -6,8 +6,14 @@ One-class classification

from deepod.core.base_model import BaseDeepAD
from deepod.core.networks.base_networks import MLPnet
from deepod.metrics import tabular_metrics
from torch.utils.data import DataLoader
import torch
import time
from ray import tune
from ray.air import session, Checkpoint
from ray.tune.schedulers import ASHAScheduler
from functools import partial


class DeepSVDD(BaseDeepAD):
@@ -16,9 +22,6 @@ class DeepSVDD(BaseDeepAD):

    Parameters
    ----------
    data_type: str, optional (default='tabular')
        Data type, choice=['tabular', 'ts']

    epochs: int, optional (default=100)
        Number of training epochs
@@ -103,7 +106,8 @@ class DeepSVDD(BaseDeepAD):

    def inference_prepare(self, X):
        test_loader = DataLoader(X, batch_size=self.batch_size,
                                 drop_last=False, shuffle=False)
        self.criterion.reduction = 'none'
        assert self.c is not None
        self.criterion = DSVDDLoss(c=self.c, reduction='none')
        return test_loader

    def training_forward(self, batch_x, net, criterion):
@@ -118,6 +122,98 @@ class DeepSVDD(BaseDeepAD):

        s = criterion(batch_z)
        return batch_z, s

    def _training_ray(self, config, X_test, y_test):
        train_data = self.train_data[:int(0.8 * len(self.train_data))]
        val_data = self.train_data[int(0.8 * len(self.train_data)):]

        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=True)

        self.net = self.set_tuned_net(config)

        self.c = self._set_c(self.net, train_loader)
        criterion = DSVDDLoss(c=self.c, reduction='mean')

        optimizer = torch.optim.Adam(self.net.parameters(), lr=config['lr'], eps=1e-6)

        self.net.train()
        for i in range(config['epochs']):
            t1 = time.time()
            total_loss = 0
            cnt = 0
            for batch_x in train_loader:
                loss = self.training_forward(batch_x, self.net, criterion)
                self.net.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                cnt += 1

                # terminate this epoch when reaching assigned maximum steps per epoch
                if cnt > self.epoch_steps != -1:
                    break

            # validation phase
            val_loss = []
            with torch.no_grad():
                for batch_x in val_loader:
                    loss = self.training_forward(batch_x, self.net, criterion)
                    val_loss.append(loss)
            val_loss = torch.mean(torch.stack(val_loss)).data.cpu().item()

            test_metric = -1
            if X_test is not None and y_test is not None:
                scores = self.decision_function(X_test)
                test_metric = tabular_metrics(y_test, scores)[0]  # AUC-ROC on the labeled test set

            t = time.time() - t1
            if self.verbose >= 1 and (i == 0 or (i+1) % self.prt_steps == 0):
                print(f'epoch{i+1:3d}, '
                      f'training loss: {total_loss/cnt:.6f}, '
                      f'validation loss: {val_loss:.6f}, '
                      f'test metric: {test_metric:.3f}, '
                      f'time: {t:.1f}s')

            checkpoint_data = {
                "epoch": i,
                "net_state_dict": self.net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                'c': self.c
            }
            checkpoint = Checkpoint.from_dict(checkpoint_data)
            session.report(
                {"loss": val_loss, "metric": test_metric},
                checkpoint=checkpoint,
            )

    def load_ray_checkpoint(self, best_config, best_checkpoint):
        self.net = self.set_tuned_net(best_config)
        self.net.load_state_dict(best_checkpoint['net_state_dict'])
        self.c = best_checkpoint['c']
        return

    def set_tuned_net(self, config):
        network_params = {
            'n_features': self.n_features,
            'n_hidden': config['hidden_dims'],
            'n_output': config['rep_dim'],
            'activation': self.act,
            'bias': self.bias
        }
        net = MLPnet(**network_params).to(self.device)
        return net

    @staticmethod
    def set_tuned_params():
        config = {
            'lr': tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]),
            'epochs': tune.grid_search([20, 50, 100]),
            'rep_dim': tune.grid_search([16, 64, 128, 512]),
            'hidden_dims': tune.choice(['100,100', '100'])
        }
        return config

    def _set_c(self, net, dataloader, eps=0.1):
        """Initializing the center for the hypersphere"""
        net.eval()
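One note on the `hidden_dims` entries in the search spaces above and below: they are comma-separated strings rather than lists, which the network constructors are expected to split into per-layer widths (the COUTANet change later in this commit shows that parsing explicitly). A tiny illustration of the convention:

# Illustrative only: how a '100,100'-style choice maps to layer widths,
# mirroring the str-handling added to COUTANet in this commit.
hidden_dims = '100,100'
if isinstance(hidden_dims, str):
    hidden_dims = [int(a) for a in hidden_dims.split(',')]
print(hidden_dims)  # -> [100, 100], i.e. two hidden layers of 100 units each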
@@ -5,12 +5,19 @@ Calibrated One-class classifier for Unsupervised Time series Anomaly detection (COUTA)

import numpy as np
import torch
import time
from torch.utils.data import Dataset
from numpy.random import RandomState
from torch.utils.data import DataLoader
from deepod.utils.utility import get_sub_seqs
from ray import tune, air
from ray.air import session, Checkpoint
from ray.tune.schedulers import ASHAScheduler
from functools import partial

from deepod.utils.utility import get_sub_seqs, get_sub_seqs_label
from deepod.core.networks.ts_network_tcn import TcnResidualBlock
from deepod.core.base_model import BaseDeepAD
from deepod.metrics import ts_metrics, point_adjustment


class COUTA(BaseDeepAD):
@@ -21,39 +28,55 @@ class COUTA(BaseDeepAD):

    ----------
    seq_len: integer, default=100
        sliding window length

    stride: integer, default=1
        sliding window stride

    epochs: integer, default=40
        the number of training epochs

    batch_size: integer, default=64
        the size of mini-batches

    lr: float, default=1e-4
        learning rate

    ss_type: string, default='FULL'
        type of perturbation operation, which can be 'FULL' (using all
        three anomaly types), 'point', 'contextual', or 'collective'.

    hidden_dims: integer or list of integer, default=16,
        the number of neural units in the hidden layer

    rep_dim: integer, default=16
        the dimensionality of the feature space

    rep_hidden: integer, default=16
        the number of neural units of the hidden layer

    pretext_hidden: integer, default=16
        the number of neural units of the hidden layer

    kernel_size: integer, default=2
        the size of the convolutional kernel in TCN

    dropout: float, default=0
        the dropout rate

    bias: bool, default=True
        the bias term of the linear layer

    alpha: float, default=0.1
        the weight of the classification head of NAC

    neg_batch_ratio: float, default=0.2
        the ratio of generated native anomaly examples
    es: bool, default=False
        early stopping
    seed: integer, default=42

    random_state: integer, default=42
        random state seed

    device: string, default='cuda'

    """
    def __init__(self, seq_len=100, stride=1,
                 epochs=40, batch_size=64, lr=1e-4, ss_type='FULL',
@@ -138,7 +161,7 @@ class COUTA(BaseDeepAD):

        )
        self.net.to(self.device)

        self.set_c(train_seqs)
        self.c = self._set_c(self.net, train_seqs)
        self.net = self.train(self.net, train_seqs, val_seqs)

        self.decision_scores_ = self.decision_function(X)
@@ -146,7 +169,7 @@ class COUTA(BaseDeepAD):

        return

    def decision_function(self, X):
    def decision_function(self, X, return_rep=False):
        """
        Predict raw anomaly score of X using the fitted detector.
        For consistency, outliers are assigned with larger anomaly scores.
@@ -157,6 +180,9 @@ class COUTA(BaseDeepAD):

            The input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.

        return_rep: boolean, optional, default=False
            whether return representations

        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
@@ -240,11 +266,9 @@ class COUTA(BaseDeepAD):

                loss_lst.append(loss)
                loss_oc_lst.append(loss_oc)
                # loss_ssl_lst.append(loss_ssl)

            epoch_loss = torch.mean(torch.stack(loss_lst)).data.cpu().item()
            epoch_loss_oc = torch.mean(torch.stack(loss_oc_lst)).data.cpu().item()
            # epoch_loss_ssl = torch.mean(torch.stack(loss_ssl_lst)).data.cpu().item()

            # validation phase
            val_loss = np.NAN
@@ -268,23 +292,161 @@ class COUTA(BaseDeepAD):

        return net

    def set_c(self, seqs, eps=0.1):
    def _training_ray(self, config, X_test, y_test):
        train_data = self.train_data[:int(0.8 * len(self.train_data))]
        val_data = self.train_data[int(0.8 * len(self.train_data)):]

        train_loader = DataLoader(dataset=SubseqData(train_data), batch_size=self.batch_size,
                                  drop_last=True, pin_memory=True, shuffle=True)
        val_loader = DataLoader(dataset=SubseqData(val_data), batch_size=self.batch_size,
                                drop_last=True, pin_memory=True, shuffle=True)

        self.net = self.set_tuned_net(config)
        self.c = self._set_c(self.net, train_data)
        criterion_oc_umc = DSVDDUncLoss(c=self.c, reduction='mean')
        criterion_mse = torch.nn.MSELoss(reduction='mean')
        optimizer = torch.optim.Adam(self.net.parameters(), lr=config['lr'], eps=1e-6)

        self.net.train()
        for i in range(config['epochs']):
            t1 = time.time()
            rng = RandomState(seed=self.random_state+i)
            epoch_seed = rng.randint(0, 1e+6, len(train_loader))
            loss_lst, loss_oc_lst, loss_ssl_lst = [], [], []
            for ii, x0 in enumerate(train_loader):
                x0 = x0.float().to(self.device)
                y0 = -1 * torch.ones(self.batch_size).float().to(self.device)

                x0_output = self.net(x0)

                rep_x0 = x0_output[0]
                rep_x0_dup = x0_output[1]
                loss_oc = criterion_oc_umc(rep_x0, rep_x0_dup)

                tmp_rng = RandomState(epoch_seed[ii])
                neg_batch_size = int(config['neg_batch_ratio'] * self.batch_size)
                neg_candidate_idx = tmp_rng.randint(0, self.batch_size, neg_batch_size)

                x1, y1 = create_batch_neg(
                    batch_seqs=x0[neg_candidate_idx],
                    max_cut_ratio=self.max_cut_ratio,
                    seed=epoch_seed[ii],
                    return_mul_label=False,
                    ss_type=self.ss_type
                )
                x1, y1 = x1.to(self.device), y1.to(self.device)
                y = torch.hstack([y0, y1])

                x1_output = self.net(x1)
                pred_x1 = x1_output[-1]
                pred_x0 = x0_output[-1]

                out = torch.cat([pred_x0, pred_x1]).view(-1)

                loss_ssl = criterion_mse(out, y)
                loss = loss_oc + config['alpha'] * loss_ssl

                self.net.zero_grad()
                loss.backward()
                optimizer.step()

                loss_lst.append(loss)
                loss_oc_lst.append(loss_oc)

            epoch_loss = torch.mean(torch.stack(loss_lst)).data.cpu().item()
            epoch_loss_oc = torch.mean(torch.stack(loss_oc_lst)).data.cpu().item()

            # validation phase
            val_loss = []
            with torch.no_grad():
                for x in val_loader:
                    x = x.float().to(self.device)
                    x_out = self.net(x)
                    loss = criterion_oc_umc(x_out[0], x_out[1])
                    loss = torch.mean(loss)
                    val_loss.append(loss)
            val_loss = torch.mean(torch.stack(val_loss)).data.cpu().item()

            test_metric = -1
            if X_test is not None and y_test is not None:
                scores = self.decision_function(X_test)
                adj_eval_metrics = ts_metrics(y_test, point_adjustment(y_test, scores))
                test_metric = adj_eval_metrics[2]  # use adjusted Best-F1

            t = time.time() - t1
            if self.verbose >= 1 and (i == 0 or (i+1) % self.prt_steps == 0):
                print(
                    f'epoch: {i+1:3d}, '
                    f'training loss: {epoch_loss:.6f}, '
                    f'training loss_oc: {epoch_loss_oc:.6f}, '
                    f'validation loss: {val_loss:.6f}, '
                    f'test F1: {test_metric:.3f}, '
                    f'time: {t:.1f}s'
                )

            checkpoint_data = {
                "epoch": i,
                "net_state_dict": self.net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                'c': self.c
            }
            checkpoint = Checkpoint.from_dict(checkpoint_data)
            session.report(
                {"loss": val_loss, "metric": test_metric},
                checkpoint=checkpoint,
            )

    @staticmethod
    def set_tuned_params():
        config = {
            'lr': tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]),
            'epochs': tune.grid_search([20, 50, 100]),
            'rep_dim': tune.choice([16, 64, 128, 512]),
            'hidden_dims': tune.choice(['16', '32,16']),
            'alpha': tune.choice([0.1, 0.2, 0.5, 0.8, 1.0]),
            'neg_batch_ratio': tune.choice([0.2, 0.5]),
        }
        return config

    def set_tuned_net(self, config):
        net = COUTANet(
            input_dim=self.n_features,
            hidden_dims=config['hidden_dims'],
            n_output=config['rep_dim'],
            pretext_hidden=self.pretext_hidden,
            rep_hidden=self.rep_hidden,
            out_dim=1,
            kernel_size=self.kernel_size,
            dropout=self.dropout,
            bias=self.bias,
            pretext=True,
            dup=True
        ).to(self.device)
        return net

    def load_ray_checkpoint(self, best_config, best_checkpoint):
        self.c = best_checkpoint['c']
        self.net = self.set_tuned_net(best_config)
        self.net.load_state_dict(best_checkpoint['net_state_dict'])
        return

    def _set_c(self, net, seqs, eps=0.1):
        """Initializing the center for the hypersphere"""
        dataloader = DataLoader(dataset=SubseqData(seqs), batch_size=self.batch_size,
                                drop_last=True, pin_memory=True, shuffle=True)
        z_ = []
        self.net.eval()
        net.eval()
        with torch.no_grad():
            for x in dataloader:
                x = x.float().to(self.device)
                x_output = self.net(x)
                x_output = net(x)
                rep = x_output[0]
                z_.append(rep.detach())
        z_ = torch.cat(z_)
        c = torch.mean(z_, dim=0)
        c[(abs(c) < eps) & (c < 0)] = -eps
        c[(abs(c) < eps) & (c > 0)] = eps
        self.c = c
        return c

    def training_forward(self, batch_x, net, criterion):
        """define forward step in training"""
@@ -378,7 +540,12 @@ class COUTANet(torch.nn.Module):

        self.layers = []

        if type(hidden_dims) == int: hidden_dims = [hidden_dims]
        if type(hidden_dims) == int:
            hidden_dims = [hidden_dims]
        elif type(hidden_dims) == str:
            hidden_dims = hidden_dims.split(',')
            hidden_dims = [int(a) for a in hidden_dims]

        num_layers = len(hidden_dims)
        for i in range(num_layers):
            dilation_size = 2 ** i
@@ -2,10 +2,16 @@
TCN is adapted from https://github.com/locuslab/TCN
"""
import numpy as np
import torch
from torch.utils.data import DataLoader

from deepod.core.base_model import BaseDeepAD
from deepod.core.networks.ts_network_tcn import TcnAE
from deepod.metrics import ts_metrics, point_adjustment

import time
import torch
from torch.utils.data import DataLoader
from ray import tune
from ray.air import session, Checkpoint


class TcnED(BaseDeepAD):
@@ -29,7 +35,7 @@ class TcnED(BaseDeepAD):

        return

    def training_prepare(self, X, y):
    def training_prepare(self, X, y=None):
        train_loader = DataLoader(X, batch_size=self.batch_size, shuffle=True)

        net = TcnAE(
@@ -52,7 +58,7 @@ class TcnED(BaseDeepAD):

    def inference_prepare(self, X):
        test_loader = DataLoader(X, batch_size=self.batch_size,
                                 drop_last=False, shuffle=False)
        self.criterion.reduction = 'none'
        self.criterion = torch.nn.MSELoss(reduction="none")
        return test_loader

    def training_forward(self, batch_x, net, criterion):
@@ -67,3 +73,94 @@ class TcnED(BaseDeepAD):

        error = torch.nn.L1Loss(reduction='none')(output[:, -1], batch_x[:, -1])
        error = torch.sum(error, dim=1)
        return output, error

    def _training_ray(self, config, X_test, y_test):
        train_data = self.train_data[:int(0.8 * len(self.train_data))]
        val_data = self.train_data[int(0.8 * len(self.train_data)):]

        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=self.batch_size, shuffle=True)

        criterion = torch.nn.MSELoss(reduction="mean")
        self.net = self.set_tuned_net(config)

        optimizer = torch.optim.Adam(self.net.parameters(), lr=config['lr'], eps=1e-6)

        self.net.train()
        for i in range(config['epochs']):
            t1 = time.time()
            total_loss = 0
            cnt = 0
            for batch_x in train_loader:
                loss = self.training_forward(batch_x, self.net, criterion)
                self.net.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                cnt += 1

                # terminate this epoch when reaching assigned maximum steps per epoch
                if cnt > self.epoch_steps != -1:
                    break

            # validation phase
            val_loss = []
            with torch.no_grad():
                for batch_x in val_loader:
                    loss = self.training_forward(batch_x, self.net, criterion)
                    val_loss.append(loss)
            val_loss = torch.mean(torch.stack(val_loss)).data.cpu().item()

            test_metric = -1
            if X_test is not None and y_test is not None:
                scores = self.decision_function(X_test)
                adj_eval_metrics = ts_metrics(y_test, point_adjustment(y_test, scores))
                test_metric = adj_eval_metrics[2]  # use adjusted Best-F1

            t = time.time() - t1
            if self.verbose >= 1 and (i == 0 or (i+1) % self.prt_steps == 0):
                print(f'epoch{i+1:3d}, '
                      f'training loss: {total_loss/cnt:.6f}, '
                      f'validation loss: {val_loss:.6f}, '
                      f'test F1: {test_metric:.3f}, '
                      f'time: {t:.1f}s')

            checkpoint_data = {
                "epoch": i,
                "net_state_dict": self.net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            }
            checkpoint = Checkpoint.from_dict(checkpoint_data)
            session.report(
                {"loss": val_loss, "metric": test_metric},
                checkpoint=checkpoint,
            )

    def load_ray_checkpoint(self, best_config, best_checkpoint):
        self.net = self.set_tuned_net(best_config)
        self.net.load_state_dict(best_checkpoint['net_state_dict'])
        return

    def set_tuned_net(self, config):
        net = TcnAE(
            n_features=self.n_features,
            n_hidden=config['hidden_dims'],
            n_emb=config['rep_dim'],
            activation=self.act,
            bias=self.bias,
            kernel_size=config['kernel_size'],
            dropout=self.dropout
        ).to(self.device)
        return net

    @staticmethod
    def set_tuned_params():
        config = {
            'lr': tune.grid_search([1e-5, 1e-4, 1e-3, 1e-2]),
            'epochs': tune.grid_search([20, 50, 100]),
            'rep_dim': tune.choice([16, 64, 128, 512]),
            'hidden_dims': tune.choice(['100,100', '100']),
            'kernel_size': tune.choice([2, 3, 5])
        }
        return config
@@ -6,4 +6,6 @@ dependencies:

  - scipy
  - pytorch
  - tqdm
  - ray
  - pyarrow
@@ -3,4 +3,6 @@ scipy>=1.5.1

scikit_learn>=0.20.0
pandas>=1.0.0
torch>=1.10.0,<1.13.1
tqdm>=4.62.3
ray==2.6.1
pyarrow>=11.0.0
@@ -5,7 +5,7 @@ testbed of unsupervised tabular anomaly detection

"""

import os
import pickle
import warnings
import argparse
import getpass
import time
@@ -25,10 +25,12 @@ parser.add_argument("--input_dir", type=str,

                    help="the path of the data sets")
parser.add_argument("--output_dir", type=str, default='@record/',
                    help="the output file path")
parser.add_argument("--dataset", type=str, default='*thyroid*',
parser.add_argument("--dataset", type=str, default='38_thyroid*',
                    help="FULL represents all the csv file in the folder, "
                         "or a list of data set names split by comma")
parser.add_argument("--model", type=str, default='SLAD', help="",)
parser.add_argument("--model", type=str, default='DeepSVDD', help="",)
parser.add_argument("--auto_hyper", default=True, action='store_true', help="")

parser.add_argument("--normalization", type=str, default='min-max', help="",)
parser.add_argument('--silent_header', action='store_true')
parser.add_argument("--flag", type=str, default='')
@@ -70,25 +72,50 @@ for file in data_lst:

        continue

    auc_lst, ap_lst, f1_lst = np.zeros(args.runs), np.zeros(args.runs), np.zeros(args.runs)
    t1_lst, t2_lst = np.zeros(args.runs), np.zeros(args.runs)
    clf = None
    for i in range(args.runs):
    t1_lst, t2_lst = [], []
    runs = args.runs

    model_configs = {}
    if args.auto_hyper:
        clf = model_class(random_state=42)

        # check whether the anomaly detection model supports ray tuning
        if not hasattr(clf, 'fit_auto_hyper'):
            warnings.warn(f'anomaly detection model {args.model} '
                          f'does not support auto tuning hyper-parameters currently.')
            break

        print(f'\nRunning [1/{args.runs}] of [{args.model}] on Dataset [{dataset_name}] (ray tune)')
        tuned_model_configs = clf.fit_auto_hyper(X=x_train,
                                                 X_test=x_test, y_test=y_test,
                                                 n_ray_samples=3, time_budget_s=None)
        model_configs = tuned_model_configs
        print(f'model parameter configure update to: {model_configs}')
        scores = clf.decision_function(x_test)

        auc, ap, f1 = tabular_metrics(y_test, scores)

        print(f'{dataset_name}, {auc:.4f}, {ap:.4f}, {f1:.4f}, '
              f'{args.model}')

    for i in range(runs):
        start_time = time.time()
        print(f'\nRunning [{i+1}/{args.runs}] of [{args.model}] on Dataset [{dataset_name}]')

        clf = model_class(epochs=50, random_state=42+i)
        clf = model_class(**model_configs, random_state=42+i)
        clf.fit(x_train)

        train_time = time.time()
        scores = clf.decision_function(x_test)
        done_time = time.time()

        auc, ap, f1 = tabular_metrics(y_test, scores)
        auc_lst[i], ap_lst[i], f1_lst[i] = auc, ap, f1
        t1_lst[i] = train_time - start_time
        t2_lst[i] = done_time - start_time
        t1_lst.append(train_time - start_time)
        t2_lst.append(done_time - start_time)

        print(f'{dataset_name}, {auc_lst[i]:.4f}, {ap_lst[i]:.4f}, {f1_lst[i]:.4f}, '
              f'{t1_lst[i]:.1f}/{t2_lst[i]:.1f}, {args.model}')
              f'{t1_lst[i]:.1f}/{t2_lst[i]:.1f}, {args.model}, {str(model_configs)}')

    avg_auc, avg_ap, avg_f1 = np.average(auc_lst), np.average(ap_lst), np.average(f1_lst)
    std_auc, std_ap, std_f1 = np.std(auc_lst), np.std(ap_lst), np.std(f1_lst)
@@ -100,7 +127,7 @@ for file in data_lst:

          f'{avg_auc:.4f}, {std_auc:.4f}, ' \
          f'{avg_ap:.4f}, {std_ap:.4f}, ' \
          f'{avg_f1:.4f}, {std_f1:.4f}, ' \
          f'{avg_time1:.1f}/{avg_time2:.1f}'
          f'{avg_time1:.1f}/{avg_time2:.1f}, {args.model}, {str(model_configs)}'
    print(txt, file=f)
    print(txt)
    f.close()
@@ -7,6 +7,7 @@ testbed of unsupervised time series anomaly detection

import os
import argparse
import getpass
import warnings
import yaml
import time
import importlib as imp
@@ -19,27 +20,27 @@ dataset_root = f'/home/{getpass.getuser()}/dataset/5-TSdata/_processed_data/'


parser = argparse.ArgumentParser()
parser.add_argument("--runs", type=int, default=5,
                    help="how many times we repeat the experiments to obtain the average performance")
parser.add_argument("--runs", type=int, default=1,
                    help="how many times we repeat the experiments to "
                         "obtain the average performance")
parser.add_argument("--output_dir", type=str, default='@records/',
                    help="the output file path")
parser.add_argument("--dataset", type=str,
                    default='SWaT_cut',
                    )
parser.add_argument("--dataset", type=str, default='ASD',
                    help='dataset name or a list of names split by comma')
parser.add_argument("--entities", type=str,
                    default='FULL',
                    help='FULL represents all the csv file in the folder, '
                         'or a list of entity names split by comma'
                    )
                         'or a list of entity names split by comma')
parser.add_argument("--entity_combined", type=int, default=1)
parser.add_argument("--model", type=str, default='TcnED', help="")
parser.add_argument("--model", type=str, default='COUTA', help="")
parser.add_argument("--auto_hyper", default=True, action='store_true', help="")

parser.add_argument('--silent_header', action='store_true')
parser.add_argument("--flag", type=str, default='')
parser.add_argument("--note", type=str, default='')

parser.add_argument('--seq_len', type=int, default=30)
parser.add_argument('--stride', type=int, default=1)
parser.add_argument('--stride', type=int, default=10)

args = parser.parse_args()
@@ -83,23 +84,75 @@ if not args.silent_header:

dataset_name_lst = args.dataset.split(',')

for dataset in dataset_name_lst:
    # # import data
    # # read data
    data_pkg = import_ts_data_unsupervised(dataset_root,
                                           dataset, entities=args.entities,
                                           combine=args.entity_combined)
                                           dataset, entities=args.entities,
                                           combine=args.entity_combined)
    train_lst, test_lst, label_lst, name_lst = data_pkg

    entity_metric_lst = []
    entity_metric_std_lst = []
    for train_data, test_data, labels, dataset_name in zip(train_lst, test_lst, label_lst, name_lst):

        entries = []
        t_lst = []
        for i in range(args.runs):
        runs = args.runs

        # using ray to tune hyper-parameters
        if args.auto_hyper:
            clf = model_class(**model_configs, random_state=42)

            # check whether the anomaly detection model supports ray tuning
            if not hasattr(clf, 'fit_auto_hyper'):
                warnings.warn(f'anomaly detection model {args.model} '
                              f'does not support auto tuning hyper-parameters currently.')
                break

            print(f'\nRunning [1/{args.runs}] of [{args.model}] on Dataset [{dataset_name}] (ray tune)')

            # config = {
            #     'lr': 1e-4,
            #     'epochs': 20,
            #     'rep_dim': 16,
            #     'hidden_dims': '100',
            #     'kernel_size': 3
            # }
            # from deepod.utils.utility import get_sub_seqs
            # train_data = get_sub_seqs(train_data, seq_len=30, stride=10)
            # clf.train_data = train_data
            # clf.n_features = train_data.shape[2]
            # clf._training_ray(config=config, X_test=test_data, y_test=labels)

            # # fit
            tuned_model_configs = clf.fit_auto_hyper(X=train_data,
                                                     X_test=test_data, y_test=labels,
                                                     n_ray_samples=3, time_budget_s=None)

            # default_keys = list(set(model_configs.keys()) - set(tuned_model_configs.keys()))
            # default_configs = {}
            # for k in default_keys:
            #     default_configs[k] = model_configs[k]
            model_configs = dict(model_configs, **tuned_model_configs)
            print(f'model parameter configure update to: {model_configs}')

            scores = clf.decision_function(test_data)

            eval_metrics = ts_metrics(labels, scores)
            adj_eval_metrics = ts_metrics(labels, point_adjustment(labels, scores))

            # print single results
            txt = f'{dataset_name},'
            txt += ', '.join(['%.4f' % a for a in eval_metrics]) + \
                   ', pa, ' + \
                   ', '.join(['%.4f' % a for a in adj_eval_metrics])
            txt += f', model, {args.model}, runs, 1/{args.runs}'
            print(txt)

        for i in range(runs):
            start_time = time.time()
            print(f'\nRunning [{i+1}/{args.runs}] of [{args.model}] on Dataset [{dataset_name}]')

            t1 = time.time()

            clf = model_class(**model_configs, random_state=42+i)
            clf.fit(train_data)
            scores = clf.decision_function(test_data)
@@ -127,12 +180,12 @@ for dataset in dataset_name_lst:

        f = open(result_file, 'a')
        txt = '%s, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, ' \
              '%.4f, %.4f, %.4f, %.4f, %.1f, %s ' % \
              '%.4f, %.4f, %.4f, %.4f, %.1f, %s, %s ' % \
              (dataset_name,
               avg_entry[0], std_entry[0], avg_entry[1], std_entry[1],
               avg_entry[2], std_entry[2], avg_entry[3], std_entry[3],
               avg_entry[4], std_entry[4],
               np.average(t_lst), args.model)
               np.average(t_lst), args.model, str(model_configs))
        print(txt)
        print(txt, file=f)
        f.close()