-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit b7a27c0
Showing
111 changed files
with
11,719 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Custom | ||
*.idea | ||
tmp/ | ||
*.txt | ||
!requirements.txt | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*,cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
# docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# IPython Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
*.pickle | ||
.vscode | ||
|
||
# checkpoint | ||
*.h5 | ||
*.pth | ||
|
||
# Experimental | ||
experimental/ | ||
|
||
# Mac files | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from .config import Config | ||
from .utils.seed import set_seed | ||
from .version import __version__ | ||
from .surrogater import Surrogater | ||
from . import attack, data, defense, functional, models, nn, training, utils | ||
|
||
__all__ = ['__version__', 'Config', 'set_seed', 'Surrogater', | ||
'data', 'attack', 'defense', | ||
'models', 'training', 'nn', 'functional', 'utils'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .attacker import Attacker | ||
from .flip_attacker import FlipAttacker | ||
|
||
classes = __all__ = ['Attacker', 'FlipAttacker'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import abc | ||
from numbers import Number | ||
from typing import Optional, Union | ||
|
||
import dgl | ||
import numpy as np | ||
import scipy.sparse as sp | ||
import torch | ||
|
||
from graphattack import Config, set_seed | ||
|
||
_FEATURE = Config.feat | ||
_LABEL = Config.label | ||
|
||
|
||
class Attacker(torch.nn.Module): | ||
"""Adversarial attacker for graph data. | ||
For example, the attacker model should be defined as follows: | ||
>>> attacker = Attacker(graph, device='cuda') | ||
>>> attacker.reset() # reset states | ||
>>> attacker.attack(attack_arguments) | ||
""" | ||
_max_perturbations: Union[float, int] = 0 | ||
_allow_feature_attack: bool = False | ||
_allow_structure_attack: bool = True | ||
_allow_singleton: bool = True | ||
|
||
def __init__(self, graph: dgl.DGLGraph, device: str = "cpu", | ||
seed: Optional[int] = None, name: Optional[str] = None, **kwargs): | ||
f"""Initialization of an attacker model. | ||
Parameters | ||
---------- | ||
graph : dgl.DGLGraph | ||
the DGL graph. If the attack requires node features, | ||
`graph.ndata[{_FEATURE}]` should be specified. | ||
If the attack requires node labels, | ||
`graph.ndata[{_LABEL}]` should be specified | ||
device : str, optional | ||
the device of the attack running on, by default "cpu" | ||
seed : Optional[int], optional | ||
the random seed of reproduce the attack, by default None | ||
name : Optional[str], optional | ||
name of the attacker, if None, it would be `__class__.__name__`, | ||
by default None | ||
kwargs : optional | ||
additional arguments of :class:`graphattack.attack.Attacker`, | ||
including (`{_FEATURE}`, `{_LABEL}`) to specify the node features | ||
and the node labels, if they are not in `graph.ndata` | ||
Note | ||
---- | ||
* If the attack requires node features, | ||
`graph.ndata[{_FEATURE}]` should be specified. | ||
* If the attack requires node labels, | ||
`graph.ndata[{_LABEL}]` should be specified. | ||
""" | ||
super().__init__() | ||
feat = kwargs.pop(_FEATURE, None) | ||
label = kwargs.pop(_LABEL, None) | ||
|
||
if kwargs: | ||
raise TypeError( | ||
f"Got an unexpected keyword argument '{next(iter(kwargs.keys()))}' " | ||
f"expected ({_FEATURE}, {_LABEL})." | ||
) | ||
|
||
self.device = torch.device(device) | ||
self._graph = graph.to(self.device) | ||
|
||
if feat is not None: | ||
feat = torch.as_tensor(feat, dtype=torch.float32, device=self.device) | ||
assert feat.size(0) == graph.num_nodes() | ||
else: | ||
feat = self.graph.ndata.get(_FEATURE, None) | ||
|
||
if label is not None: | ||
label = torch.as_tensor(label, dtype=torch.long, device=self.device) | ||
else: | ||
label = self.graph.ndata.get(_LABEL, None) | ||
|
||
setattr(self, '_' + _FEATURE, feat) | ||
setattr(self, '_' + _LABEL, label) | ||
|
||
self.adjacency_matrix: sp.csr_matrix = graph.adjacency_matrix(scipy_fmt='csr') | ||
self.name = name or self.__class__.__name__ | ||
self.seed = seed | ||
|
||
self._degree = self._graph.in_degrees() | ||
|
||
self.edges = self._graph.edges() | ||
self.nodes = self._graph.nodes() | ||
self.num_nodes = self._graph.num_nodes() | ||
self.num_edges = self._graph.num_edges() // 2 | ||
self.num_feats = feat.size(-1) if feat is not None else None | ||
|
||
set_seed(seed) | ||
|
||
self.is_reseted = False | ||
|
||
def reset(self): | ||
self.is_reseted = True | ||
return self | ||
|
||
def g(self): | ||
raise NotImplementedError | ||
|
||
@property | ||
def feat(self): | ||
return getattr(self, '_' + _FEATURE, None) | ||
|
||
@property | ||
def label(self): | ||
return getattr(self, '_' + _LABEL, None) | ||
|
||
@property | ||
def graph(self): | ||
return self._graph | ||
|
||
@abc.abstractmethod | ||
def attack(self) -> "Attacker": | ||
"""defined for attacker model.""" | ||
raise NotImplementedError | ||
|
||
def _check_budget(self, num_budgets: Union[float, int], | ||
max_perturbations: Union[float, int]) -> int: | ||
|
||
max_perturbations = max(max_perturbations, self.max_perturbations) | ||
|
||
if not isinstance(num_budgets, Number) or num_budgets <= 0: | ||
raise ValueError( | ||
f"'num_budgets' must be a postive scalar. but got '{num_budgets}'." | ||
) | ||
|
||
if num_budgets > max_perturbations: | ||
raise ValueError( | ||
f"'num_budgets' should be less than or equal the maximum allowed perturbations: {max_perturbations}." | ||
"if you want to use larger budgets, you could set 'attacker.set_max_perturbations(a_larger_budget)'." | ||
) | ||
|
||
if num_budgets < 1.: | ||
assert self._max_perturbations != np.inf | ||
num_budgets = max_perturbations * num_budgets | ||
|
||
return int(num_budgets) | ||
|
||
def set_max_perturbations(self, max_perturbations: Union[float, int] = np.inf, | ||
verbose: bool = True): | ||
assert isinstance(max_perturbations, Number), max_perturbations | ||
self._max_perturbations = max_perturbations | ||
if verbose: | ||
print(f"Set maximum perturbations: {max_perturbations}") | ||
|
||
@property | ||
def max_perturbations(self) -> Union[float, int]: | ||
return self._max_perturbations | ||
|
||
def _check_feature_matrix_exists(self): | ||
if self.feat is None: | ||
raise RuntimeError("Node feature matrix does not exist" | ||
f", please add node feature data externally via `g.ndata['{_FEATURE}'] = {_FEATURE}` " | ||
f"or initialize via `attacker = {self.__class__.__name__}(g, {_FEATURE}={_FEATURE})`.") | ||
|
||
def _check_node_label_exists(self): | ||
if self.label is None: | ||
raise RuntimeError("Node labels does not exist" | ||
f", please add node labels externally via `g.ndata['{_LABEL}'] = {_LABEL}` " | ||
f"or initialize via `attacker = {self.__class__.__name__}(g, {_LABEL}={_LABEL})`.") | ||
|
||
def _check_feature_matrix_binary(self): | ||
self._check_feature_matrix_exists() | ||
feat = self.feat | ||
feat = feat[torch.randint(0, feat.size(0), size=(10,))] | ||
if not torch.unique(feat).tolist() == [0, 1]: | ||
raise RuntimeError("Node feature matrix is required to be a 0-1 binary matrix.") | ||
|
||
def extra_repr(self) -> str: | ||
return f"device={self.device}, seed={self.seed}," |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .backdoor_attacker import BackdoorAttacker | ||
from .fg_backdoor import FGBackdoor | ||
from .lgc_backdoor import LGCBackdoor |
85 changes: 85 additions & 0 deletions
85
GraphAttack/graphattack/attack/backdoor/backdoor_attacker.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from functools import lru_cache | ||
from typing import Optional, Union | ||
|
||
import dgl | ||
import torch | ||
from torch import Tensor | ||
|
||
from graphattack import Config | ||
|
||
from ..attacker import Attacker | ||
|
||
_FEATURE = Config.feat | ||
|
||
|
||
class BackdoorAttacker(Attacker): | ||
|
||
def reset(self) -> "BackdoorAttacker": | ||
"""Reset the state of the Attacker | ||
Returns | ||
------- | ||
BackdoorAttacker | ||
the attacker itself | ||
""" | ||
self.num_budgets = None | ||
self._trigger = None | ||
self.is_reseted = True | ||
|
||
return self | ||
|
||
def attack(self, num_budgets: Union[int, float], targets_class: int) -> "BackdoorAttacker": | ||
"""Base method that describes the adversarial backdoor attack | ||
""" | ||
|
||
_is_setup = getattr(self, "_is_setup", True) | ||
|
||
if not _is_setup: | ||
raise RuntimeError( | ||
f'{self.__class__.__name__} requires a surrogate model to conduct attack. ' | ||
'Use `attacker.setup_surrogate(surrogate_model)`.') | ||
|
||
if not self.is_reseted: | ||
raise RuntimeError( | ||
'Before calling attack, you must reset your attacker. Use `attacker.reset()`.' | ||
) | ||
|
||
num_budgets = self._check_budget( | ||
num_budgets, max_perturbations=self.num_feats) | ||
|
||
self.num_budgets = num_budgets | ||
self.targets_class = torch.LongTensor([targets_class]).view(-1).to(self.device) | ||
self.is_reseted = False | ||
|
||
return self | ||
|
||
def trigger(self,): | ||
return self._trigger | ||
|
||
def g(self, target_node: int, symmetric: bool = True) -> dgl.DGLGraph: | ||
"""return the attacked graph | ||
Parameters | ||
---------- | ||
target_node : int | ||
the target node that the attack performed | ||
symmetric : bool | ||
determine whether the resulting graph is forcibly symmetric, | ||
by default True | ||
Returns | ||
------- | ||
dgl.DGLGraph | ||
the attacked graph with backdoor attack performed on the target node | ||
""" | ||
graph = self.graph.local_var() | ||
num_nodes = self.num_nodes | ||
data = self.trigger().view(1, -1) | ||
|
||
graph.add_nodes(1, data={_FEATURE: data}) | ||
graph.add_edges(num_nodes, target_node) | ||
|
||
if symmetric: | ||
graph.add_edges(target_node, num_nodes) | ||
|
||
return graph |
Oops, something went wrong.