STYLE: enforce PEP8 and linting
lorenzobalzani committed May 10, 2024
1 parent 5d0eabd commit c0d8e21
Showing 20 changed files with 298 additions and 129 deletions.
2 changes: 1 addition & 1 deletion .gitattributes
@@ -2,4 +2,4 @@ data/medqa/train.statement.jsonl filter=lfs diff=lfs merge=lfs -text
data/medqa/train.grounded.json filter=lfs diff=lfs merge=lfs -text
data/medqa/concept_names.tsv filter=lfs diff=lfs merge=lfs -text
data/umls/UMLS_KG_triplets_sample.txt filter=lfs diff=lfs merge=lfs -text
data/umls/*.pt filter=lfs diff=lfs merge=lfs -text
data/umls/*.pt filter=lfs diff=lfs merge=lfs -text
34 changes: 34 additions & 0 deletions .github/workflows/lint_and_style.yaml
@@ -0,0 +1,34 @@
name: lint_and_style

on:
pull_request:
push:
branches:
- main
- master

jobs:
pre-commit:
runs-on: ubuntu-latest
if: github.event_name == 'pull_request' || (github.ref == 'refs/heads/main' && github.event_name == 'push') || (github.ref == 'refs/heads/master' && github.event_name == 'push')
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.9'
- run: pip install pylint && pip install -r docker/requirements.txt
- uses: pre-commit/action@v3.0.1
pylint:
runs-on: ubuntu-latest
needs: pre-commit
continue-on-error: true
if: false #github.event_name == 'pull_request' || (github.ref == 'refs/heads/main' && github.event_name == 'push') || (github.ref == 'refs/heads/master' && github.event_name == 'push')
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.9'
- run: pip install pylint && pip install -r docker/requirements.txt
- uses: pre-commit/action@v3.0.1
with:
extra_args: --hook-stage manual pylint-all --all-files
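
Note that the pylint job above is intentionally switched off ("if: false"), so only the pre-commit job currently gates merges. Both jobs can be reproduced locally; a minimal sketch, assuming Python 3.9 and the repository root as the working directory, matching the workflow:

pip install pre-commit pylint
pip install -r docker/requirements.txt
pre-commit run --all-files                                  # what the pre-commit job runs
pre-commit run --hook-stage manual pylint-all --all-files   # what the disabled pylint job would run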
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,4 +2,4 @@
**/.bashrc
**/.DS_Store
**/__pycache__/
.idea/
.idea/
65 changes: 65 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,65 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-docstring-first
- id: check-toml
- id: check-yaml
- id: mixed-line-ending
args: [--fix=lf]
- id: end-of-file-fixer
- repo: https://gitlab.com/bmares/check-json5
rev: v1.0.0
hooks:
- id: check-json5
exclude: data/.*
- repo: https://github.com/hhatto/autopep8
rev: v2.1.0
hooks:
- id: autopep8
args: [--in-place, --aggressive, --exit-code]
types: [python]
- repo: local
hooks:
- id: pylint
name: pylint
entry: pylint
language: system
types: [python]
args:
[
"--max-line-length=120",
"--errors-only",
]
- id: pylint
alias: pylint-all
name: pylint-all
entry: pylint
language: system
types: [python]
args:
[
"--max-line-length=120",
]
stages: [manual]
- repo: https://github.com/pycqa/flake8
rev: 7.0.0
hooks:
- id: flake8
args: [--config=setup.cfg]
additional_dependencies:
- flake8-bugbear==22.10.27
- flake8-comprehensions==3.10.1
- torchfix==0.0.2
- repo: https://github.com/facebook/usort
rev: v1.0.7
hooks:
- id: usort
name: Sort imports with µsort
description: Safe, minimal import sorting
language: python
types_or:
- python
- pyi
entry: usort format
require_serial: true
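
To have these hooks run on every local commit rather than only in CI, the standard pre-commit workflow applies; a minimal sketch, using only stock pre-commit CLI commands:

pre-commit install          # wires the hooks into .git/hooks/pre-commit
pre-commit run --all-files  # one-off pass over the whole tree
pre-commit autoupdate       # optionally bump the pinned rev values above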
Empty file added .pylintrc
Empty file.
4 changes: 2 additions & 2 deletions data/medqa/train.grounded.json
Git LFS file not shown
21 changes: 21 additions & 0 deletions docker/requirements.txt
@@ -0,0 +1,21 @@
torch==2.2.0+cu118
torchvision==0.17.0+cu118
torchaudio==2.2.0+cu118
git+https://github.com/huggingface/transformers
pykeen==1.10.1
scikit-learn==1.3.1
networkx
huggingface_hub
datasets
evaluate
nvidia-ml-py3
accelerate
sentence-transformers
nltk
tqdm
wget
gdown
wrapt==1.12.1
-i https://pypi.nvidia.com
cudf-cu11
cugraph-cu11
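
The +cu118 wheel tags and the inline "-i https://pypi.nvidia.com" option mean this file pulls from more than one package index. A minimal install sketch; the extra index URL is an assumption, since the cu118 builds of torch, torchvision, and torchaudio are normally published on the PyTorch wheel index rather than on PyPI:

pip install -r docker/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118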
77 changes: 39 additions & 38 deletions link_prediction/link_predictor.py
@@ -3,20 +3,20 @@
import sys
from typing import Dict, Optional

import pandas as pd
import torch
from arg_parser import parse_arguments
from huggingface_hub import PyTorchModelHubMixin
from pykeen.models import DistMult

from pykeen.nn import TextRepresentation
from pykeen.pipeline import PipelineResult, pipeline_from_path
from pykeen.triples import TriplesFactory
from pykeen.pipeline import pipeline_from_path, PipelineResult
from pykeen.predict import predict_target, Predictions
from pykeen.models import DistMult
from huggingface_hub import PyTorchModelHubMixin
import torch
import pandas as pd
from pykeen.triples import TriplesFactory
from utils.sentence_transformer_encoder import SentenceTransformerEncoder

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from utils.sentence_transformer_encoder import SentenceTransformerEncoder
from arg_parser import parse_arguments


class LinkPredictor(torch.nn.Module, PyTorchModelHubMixin):
def __init__(
@@ -43,11 +43,11 @@ def __init__(
self._entity_representations = TextRepresentation.from_triples_factory(
triples_factory=self._triples_factory,
encoder=SentenceTransformerEncoder,
encoder_kwargs=dict(
encoder_model_name_or_path=encoder_model_name_or_path,
embedding_dim=encoder_embedding_dim,
device=self._device.type,
),
encoder_kwargs={
"encoder_model_name_or_path": encoder_model_name_or_path,
"embedding_dim": encoder_embedding_dim,
"device": self._device.type,
},
)

self._pykeen_model = DistMult(
@@ -92,7 +92,8 @@ def _load_umls_kg(
def print_metrics(cls, result: PipelineResult) -> Dict[str, str]:
# MR (Mean Rank): [1, num_entities]; the lower it is, the better the model results are. It's out of 9958
# MRR (Mean Reciprocal Rank): [0, 1]; the higher it is, the better the model results
# IGMR (Inverse Geometric Mean Rank): [0, 1]; the higher it is, the better the model results
# IGMR (Inverse Geometric Mean Rank): [0, 1]; the higher it is, the
# better the model results
metrics = {
metric: str(result.get_metric(metric)) for metric in ["mr", "mrr", "igmr"]
}
@@ -121,29 +122,29 @@ def train_pykeen_model(

print(f"Loading config from {pykeen_config}")

pipeline_params = dict(
model=self._pykeen_model,
device=self._device,
path=pykeen_config,
training=train_set,
validation=val_set,
testing=test_set,
stopper="early",
stopper_kwargs=dict(
frequency=es_frequency,
patience=es_patience,
relative_delta=es_relative_delta,
),
random_seed=random_state,
use_tqdm=True,
training_kwargs=dict(
num_epochs=n_epochs,
batch_size=batch_size,
checkpoint_directory=self._pykeen_checkpoint_folder,
checkpoint_name=self._pykeen_checkpoint_file,
checkpoint_frequency=checkpoint_frequency,
),
)
pipeline_params = {
"model": self._pykeen_model,
"device": self._device,
"path": pykeen_config,
"training": train_set,
"validation": val_set,
"testing": test_set,
"stopper": "early",
"stopper_kwargs": {
"frequency": es_frequency,
"patience": es_patience,
"relative_delta": es_relative_delta,
},
"random_seed": random_state,
"use_tqdm": True,
"training_kwargs": {
"num_epochs": n_epochs,
"batch_size": batch_size,
"checkpoint_directory": self._pykeen_checkpoint_folder,
"checkpoint_name=": self._pykeen_checkpoint_file,
"checkpoint_frequency": checkpoint_frequency,
},
}

pipeline_result = pipeline_from_path(**pipeline_params)
pipeline_result.save_to_directory("training_" + self._pykeen_checkpoint_file)
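
The dict(...)-to-literal rewrites in this file match rule C408 ("Unnecessary dict call - rewrite as a literal") from flake8-comprehensions, which is pinned in the pre-commit config above; the same rule covers no-argument list() and tuple() calls, which explains the list() -> [] changes in main.py further down. A sketch for checking a single file, assuming flake8 and the plugin are installed:

pip install flake8 flake8-comprehensions
flake8 --select C408 link_prediction/link_predictor.py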
2 changes: 1 addition & 1 deletion link_prediction/run_train_distmult.sh
@@ -4,4 +4,4 @@ python3.9 link_predictor.py \
--huggingface_hub_url neural-subgraph-retrieval/umls-link-predictor \
--pykeen_train_config pykeen_config.json \
--n_epochs 50 \
--pykeen_checkpoint_file distmult.pt
--pykeen_checkpoint_file distmult.pt
2 changes: 1 addition & 1 deletion link_prediction/run_train_pubmedbert.sh
@@ -5,4 +5,4 @@ python3.9 link_predictor.py \
--encoder_model_name_or_path NeuML/pubmedbert-base-embeddings-matryoshka \
--pykeen_train_config pykeen_config.json \
--n_epochs 50 \
--pykeen_checkpoint_file pubmedbert.pt
--pykeen_checkpoint_file pubmedbert.pt
35 changes: 17 additions & 18 deletions next_relation_prediction/dataset/main.py
@@ -1,21 +1,18 @@
import sys
import datetime
import os
import sys
from typing import Dict, List, Optional, Set, Union

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import torch
import networkx as nx
from typing import Optional, List, Dict, Union, Set
import datetime
from utils.coder import Coder
from arg_parser import parse_arguments
from utils.loaders import load_medqa_dataset, load_umls
from utils.general import (
split_and_push_to_hf_hub,
check_device,
)
import pandas as pd
import torch
from arg_parser import parse_arguments
from tqdm import tqdm
from utils.coder import Coder
from utils.general import check_device, split_and_push_to_hf_hub
from utils.loaders import load_medqa_dataset, load_umls

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))


def find_shortest_path_per_question(
@@ -38,7 +35,7 @@ def find_shortest_path_per_question(
medqa_answers (pd.DataFrame): A Pandas DataFrame containing medical question answers and choices.
concept_names (pd.DataFrame): A Pandas DataFrame mapping concept IDs to concept names.
question_idx (int): Index of the question to process.
normalize_nodes (bool, optional): Whether to use CODER model to perform term-normalization on nodes. Default to True.
normalize_nodes (bool, optional): Whether to use the CODER model to perform term normalization on nodes. Defaults to True.
debug_print (bool, optional): Whether to print debug information. Default is False.
Returns:
@@ -53,12 +50,13 @@
not_found_nodes: int = 0
not_found_paths: int = 0
all_topic_entities: Set[str] = set()
dataset_rows: List[Dict[str, str]] = list()
dataset_rows: List[Dict[str, str]] = []

# Extract the dataset row and the question
row = medqa_answers.iloc[question_idx]
question = row["question"]["stem"]
# Get the 4 possible answers, with index and text, e.g. 3 and "Nitrofurantoin" (idx ranges from 0 to 3).
# Get the 4 possible answers, with index and text, e.g. 3 and
# "Nitrofurantoin" (idx ranges from 0 to 3).
possible_choices = [
{"idx": idx, "text": choice["text"]}
for idx, choice in enumerate(row["question"]["choices"])
@@ -105,7 +103,8 @@ def find_shortest_path_per_question(
try:
all_topic_entities.add(topic_entity)

# Extract all edges that form the path between the source and the target nodes
# Extract all edges that form the path between the source and
# the target nodes
for path in [paths_to_answer[topic_entity]]:
entity_2_answer_relations = [
kg[path[i]][path[i + 1]]["relation"]
@@ -183,7 +182,7 @@ def create_dataset(
not_found_paths: int = 0
n_errors: int = 0

dataset: List[Dict[str, str]] = list()
dataset: List[Dict[str, str]] = []

for question_idx in tqdm(
range(len(medqa_answers[:until_idx])), leave=True, desc="Creating dataset"
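
The reshuffled import block above is characteristic µsort output: imports grouped into stdlib, third-party, and first-party sections, each sorted alphabetically. The hook can also be invoked outside pre-commit; a minimal sketch using the stock usort CLI:

pip install usort
usort check next_relation_prediction/dataset/main.py    # report unsorted imports, non-zero exit on failure
usort format next_relation_prediction/dataset/main.py   # rewrite the file in place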
2 changes: 1 addition & 1 deletion next_relation_prediction/dataset/run_create_dataset.sh
@@ -6,4 +6,4 @@ python3.9 main.py \
--coder_path balzanilo/UMLSBert_ENG \
--output_path ../../data/paths_dataset.tsv \
--huggingface_hub_url neural-subgraph-retrieval/umls-nrp-dataset \
--stats_folder ../../data/umls-next-relation-prediction
--stats_folder ../../data/umls-next-relation-prediction
2 changes: 1 addition & 1 deletion next_relation_prediction/model/accelerate_config.yaml
@@ -21,4 +21,4 @@ same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
use_cpu: false