Skip to content

Commit

Permalink
add pruned_struc_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
WDdeBWT committed Jul 2, 2020
1 parent 3d3a6a9 commit d1d2276
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 35 deletions.
12 changes: 8 additions & 4 deletions code/cf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,20 @@ def _build_interaction_graph(self):
g.ndata['sqrt_degree'] = 1 / torch.sqrt(g.out_degrees().float().unsqueeze(-1))
return g

def build_struc_graphs(self, mode=0):
    """Build structural-similarity graphs from the training interactions.

    A user-item interaction graph is built as an undirected NetworkX graph
    (item ids offset by n_users so users and items share one id space), then
    Struc2Vec derives per-layer structural graphs from it.

    Args:
        mode: selects which graphs to return —
            0: all per-layer structural graphs (general),
            1: a single summed graph,
            2: only the last (deepest) layer,
            3: a single pruned graph (low-weight edges dropped).

    Returns:
        A list of DGL graphs.

    Raises:
        ValueError: if mode is not one of 0, 1, 2, 3.
    """
    nx_rec_g = nx.Graph()
    nx_rec_g.add_nodes_from(range(self.n_users + self.n_items))
    # Offset item ids past the user id range so both live in one node-id space.
    edges = np.concatenate((self.train_data[0].reshape(-1, 1), self.train_data[1].reshape(-1, 1) + self.n_users), 1)
    nx_rec_g.add_edges_from(edges)
    s2v = Struc2Vec(nx_rec_g, workers=4, verbose=40, opt3_num_layers=3, reuse=True)
    if mode == 0:  # general: one graph per structural layer
        g_list = s2v.get_struc_graphs()
    elif mode == 1:  # sumed: all layers merged into one graph
        g_list = [s2v.get_sumed_struc_graph()]
    elif mode == 2:  # last: only the deepest layer
        g_list = s2v.get_struc_graphs()[-1:]
    elif mode == 3:  # prune: deepest layer with low-weight edges removed
        g_list = [s2v.get_pruned_struc_graph()]
    else:
        # Previously an unknown mode fell through to a NameError at return.
        raise ValueError('unknown mode: {}'.format(mode))
    return g_list

# def __len__(self): # first version
Expand Down
3 changes: 2 additions & 1 deletion code/gcn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,11 @@ def AggregateUnweighted(g, entity_embed):
def AggregateWeighted(g, entity_embed):
    """One weighted message-passing step with symmetric degree normalization.

    Scales each node embedding by its out-degree factor, sums edge-weighted
    messages from in-neighbors, then scales the result by the in-degree
    factor — GCN-style D^{-1/2} A D^{-1/2} aggregation on a weighted graph.

    Args:
        g: DGL graph with edata['weight'] and precomputed
           ndata['out_sqrt_degree'] / ndata['in_sqrt_degree']
           (zeroed for isolated nodes — TODO confirm upstream builder).
        entity_embed: node-embedding tensor aligned with g's node ids.

    Returns:
        Aggregated neighbor embeddings, one row per node.
    """
    # Static function instead of an object method: keeps aggregation stateless.
    g = g.local_var()  # scratch view; does not mutate the caller's graph data
    # Pre-scale by out-degree factor (removed a redundant unweighted
    # assignment that was immediately overwritten by this line).
    g.ndata['node'] = entity_embed * g.ndata['out_sqrt_degree']
    g.update_all(dgl.function.u_mul_e('node', 'weight', 'side'), dgl.function.sum(msg='side', out='N_h'))
    # Post-scale by in-degree factor to complete the symmetric normalization.
    g.ndata['N_h'] = g.ndata['N_h'] * g.ndata['in_sqrt_degree']
    return g.ndata['N_h']


Expand Down
62 changes: 40 additions & 22 deletions code/s2vec/struc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def get_struc_graphs(self):
n_nodes = len(self.idx)
struc_graphs = []
for layer in self.layers_adj:
# times = 0
g = dgl.DGLGraph()
g.add_nodes(n_nodes)
edge_list = []
Expand Down Expand Up @@ -81,7 +80,6 @@ def get_struc_graphs(self):
g.ndata['id'] = torch.arange(n_nodes, dtype=torch.long)
g.edata['weight'] = torch.tensor(edge_weight_list).float().unsqueeze(-1)
struc_graphs.append(g)
# print('times', times)
return struc_graphs

def get_sumed_struc_graph(self):
Expand Down Expand Up @@ -122,6 +120,46 @@ def get_sumed_struc_graph(self):
g.edata['weight'] = torch.tensor(edge_weight_list).float().unsqueeze(-1)
return g

def get_pruned_struc_graph(self):
    """Build a DGL graph from the deepest layer, pruning low-weight edges.

    For each node, edges whose similarity score does not exceed
    (max neighbor score) / 100 are dropped, keeping only the node's
    structurally closest neighbors. Also precomputes the in/out
    sqrt-degree normalization factors used by the weighted aggregation.

    Returns:
        A readonly DGLGraph with ndata 'id', 'out_sqrt_degree',
        'in_sqrt_degree' and edata 'weight'.
    """
    n_nodes = len(self.idx)
    layer = max(self.layers_adj.keys())  # deepest layer only
    g = dgl.DGLGraph()
    g.add_nodes(n_nodes)
    edge_list = []
    edge_weight_list = []
    neighbors_dict = self.layers_adj[layer]
    layer_sim_scores = self.layers_sim_scores[layer]

    def _score(v, n):
        # Scores are stored once per unordered pair; try both key orders.
        if (v, n) in layer_sim_scores:
            return layer_sim_scores[v, n]
        return layer_sim_scores[n, v]

    for v, neighbors in neighbors_dict.items():
        max_score = max((_score(v, n) for n in neighbors), default=0.0)
        threshold = max_score / 100  # cut the low-similarity edges
        if threshold == 0:
            continue  # no neighbors or all-zero scores: nothing to keep
        for n in neighbors:
            sim_score = _score(v, n)
            if sim_score > threshold:
                edge_list.append((n, v))  # edge direction: from n to v
                edge_weight_list.append(sim_score)
    # reshape(-1, 2) keeps column indexing valid even when no edge survives
    # (np.array([]) is 1-D and would make edge_list[:, 0] raise IndexError).
    edge_arr = np.array(edge_list, dtype=int).reshape(-1, 2)
    g.add_edges(edge_arr[:, 0], edge_arr[:, 1])
    g.readonly()
    g.ndata['id'] = torch.arange(n_nodes, dtype=torch.long)
    g.edata['weight'] = torch.tensor(edge_weight_list).float().unsqueeze(-1)
    # Symmetric-normalization factors; isolated nodes would get inf (1/sqrt(0)),
    # so those entries are zeroed out below.
    g.ndata['out_sqrt_degree'] = 1 / torch.sqrt(g.out_degrees().float().unsqueeze(-1))
    g.ndata['in_sqrt_degree'] = 1 / torch.sqrt(g.in_degrees().float().unsqueeze(-1))
    g.ndata['out_sqrt_degree'][torch.isinf(g.ndata['out_sqrt_degree'])] = 0
    g.ndata['in_sqrt_degree'][torch.isinf(g.ndata['in_sqrt_degree'])] = 0
    return g

def create_context_graph(self, max_num_layers, workers=1, verbose=0,):
print(str(time.asctime(time.localtime(time.time()))) + ' create_context_graph')
pair_distances = self._compute_structural_distance(max_num_layers, workers, verbose)
Expand Down Expand Up @@ -289,28 +327,8 @@ def _get_layer_rep(self, pair_distances):
layers_adj[layer][vx].append(vy)
layers_adj[layer][vy].append(vx)

# self.norm_sim_score(layers_adj, layers_sim_scores)
return layers_adj, layers_sim_scores

def norm_sim_score(self, layers_adj, layers_sim_scores):
    """Normalize each node's neighbor similarity scores to sum to 1, in place.

    Scores live in layers_sim_scores[layer], keyed by a node pair that may
    be stored in either order; lookups and writes try (v, n) first and fall
    back to (n, v).
    """
    print(str(time.asctime(time.localtime(time.time()))) + ' norm_sim_score')
    for layer, neighbors_dict in layers_adj.items():
        scores = layers_sim_scores[layer]

        def stored_key(v, n):
            # Each unordered pair appears once; pick whichever order exists.
            return (v, n) if (v, n) in scores else (n, v)

        for v, neighbors in neighbors_dict.items():
            total = sum(scores[stored_key(v, n)] for n in neighbors)
            for n in neighbors:
                key = stored_key(v, n)
                scores[key] = scores[key] / total


def cost(a, b):
ep = 0.5
Expand Down
22 changes: 14 additions & 8 deletions code/script_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from gcn_model import CFGCN
from metrics import precision_and_recall, ndcg, auc

EPOCH = 1000
LR = 0.001
EPOCH = 200
LR = 0.005
EDIM = 64
LAYERS = 3
LAM = 1e-4
Expand Down Expand Up @@ -101,11 +101,15 @@ def test(data_set, model, data_loader, show_auc = False):
itra_G = data_set.get_interaction_graph()
itra_G.ndata['id'] = itra_G.ndata['id'].to(device) # move graph data to target device
itra_G.ndata['sqrt_degree'] = itra_G.ndata['sqrt_degree'].to(device) # move graph data to target device
# struc_Gs = data_set.build_struc_graphs(sumed=True)
# for g in struc_Gs:
# g.ndata['id'] = g.ndata['id'].to(device)
# g.edata['weight'] = g.edata['weight'].to(device)
struc_Gs = None
struc_Gs = data_set.build_struc_graphs(mode=3)
for g in struc_Gs:
g.ndata['id'] = g.ndata['id'].to(device)
g.edata['weight'] = g.edata['weight'].to(device)
if 'out_sqrt_degree' in g.ndata and 'in_sqrt_degree' in g.ndata:
g.ndata['out_sqrt_degree'] = g.ndata['out_sqrt_degree'].to(device)
g.ndata['in_sqrt_degree'] = g.ndata['in_sqrt_degree'].to(device)
else:
assert False
n_users = data_set.get_user_num()
n_items = data_set.get_item_num()
model = CFGCN(n_users, n_items, itra_G, struc_Gs=struc_Gs, embed_dim=EDIM, n_layers=LAYERS, lam=LAM).to(device)
Expand Down Expand Up @@ -134,11 +138,13 @@ def test(data_set, model, data_loader, show_auc = False):
# train loss 0.015; evaluate loss 0.134
# test result: precision 0.047575934218717066; recall 0.16351703048292573; ndcg 0.13673274095554458
# max
# recall 0.172; ndcg 0.143
# precision 0.051 recall 0.174; ndcg 0.147

# Paper code at epoch 50 gowalla
# {'precision': array([0.04382075]), 'recall': array([0.14503336]), 'ndcg': array([0.12077126]), 'auc': 0.9587075653077938}
# Paper code at epoch 80 gowalla
# {'precision': array([0.0468166]), 'recall': array([0.15585551]), 'ndcg': array([0.13010746]), 'auc': 0.9582199598920466}
# Paper code at epoch 400 gowalla
# {'precision': array([0.05400898]), 'recall': array([0.17730673]), 'ndcg': array([0.15099276]), 'auc': 0.9508640011701547}
# max in paper
# precision 0.055 recall 0.182; ndcg 0.154

0 comments on commit d1d2276

Please sign in to comment.