Skip to content

Commit

Permalink
add Label Relaxation
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Dec 12, 2024
1 parent 00aa49d commit f18bde2
Show file tree
Hide file tree
Showing 46 changed files with 883 additions and 0 deletions.
707 changes: 707 additions & 0 deletions dicee.egg-info/PKG-INFO

Large diffs are not rendered by default.

95 changes: 95 additions & 0 deletions dicee.egg-info/SOURCES.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
LICENSE
MANIFEST.in
README.md
requirements.txt
setup.py
dicee/__init__.py
dicee/__main__.py
dicee/abstracts.py
dicee/analyse_experiments.py
dicee/callbacks.py
dicee/config.py
dicee/dataset_classes.py
dicee/eval_static_funcs.py
dicee/evaluator.py
dicee/executer.py
dicee/knowledge_graph.py
dicee/knowledge_graph_embeddings.py
dicee/lp.png
dicee/query_generator.py
dicee/sanity_checkers.py
dicee/static_funcs.py
dicee/static_funcs_training.py
dicee/static_preprocess_funcs.py
dicee.egg-info/PKG-INFO
dicee.egg-info/SOURCES.txt
dicee.egg-info/dependency_links.txt
dicee.egg-info/entry_points.txt
dicee.egg-info/requires.txt
dicee.egg-info/top_level.txt
dicee/figures/deploy_qmult_family.png
dicee/models/__init__.py
dicee/models/adopt.py
dicee/models/base_model.py
dicee/models/clifford.py
dicee/models/complex.py
dicee/models/dualE.py
dicee/models/ensemble.py
dicee/models/function_space.py
dicee/models/octonion.py
dicee/models/pykeen_models.py
dicee/models/quaternion.py
dicee/models/real.py
dicee/models/static_funcs.py
dicee/models/transformers.py
dicee/read_preprocess_save_load_kg/__init__.py
dicee/read_preprocess_save_load_kg/preprocess.py
dicee/read_preprocess_save_load_kg/read_from_disk.py
dicee/read_preprocess_save_load_kg/save_load_disk.py
dicee/read_preprocess_save_load_kg/util.py
dicee/scripts/__init__.py
dicee/scripts/index_serve.py
dicee/scripts/run.py
dicee/trainer/__init__.py
dicee/trainer/dice_trainer.py
dicee/trainer/model_parallelism.py
dicee/trainer/torch_trainer.py
dicee/trainer/torch_trainer_ddp.py
tests/__init__.py
tests/test_adaptive_swa.py
tests/test_answer_multi_hop_query.py
tests/test_auto_batch_finder.py
tests/test_continual_training.py
tests/test_custom_trainer.py
tests/test_deployment.py
tests/test_different_backends.py
tests/test_download_and_eval.py
tests/test_ensemble_construction.py
tests/test_execute_start.py
tests/test_gradient_accumulation.py
tests/test_inductive_regression.py
tests/test_k_fold_cv_1_vs_all.py
tests/test_k_fold_cv_k_vs_all.py
tests/test_k_fold_cv_neg_sample.py
tests/test_large_kg.py
tests/test_link_prediction_evaluation.py
tests/test_onevssample.py
tests/test_online_learning.py
tests/test_pickle.py
tests/test_pykeen.py
tests/test_read_few_only.py
tests/test_regression_DualE.py
tests/test_regression_aconex.py
tests/test_regression_all_vs_all.py
tests/test_regression_cl.py
tests/test_regression_clifford.py
tests/test_regression_complex.py
tests/test_regression_conex.py
tests/test_regression_convo.py
tests/test_regression_distmult.py
tests/test_regression_model_paralelisim.py
tests/test_regression_omult.py
tests/test_regression_pyke.py
tests/test_regression_qmult.py
tests/test_saving_embeddings.py
tests/test_trainers.py
1 change: 1 addition & 0 deletions dicee.egg-info/dependency_links.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

3 changes: 3 additions & 0 deletions dicee.egg-info/entry_points.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[console_scripts]
dicee = dicee.scripts.run:main
dicee_vector_db = dicee.scripts.index_serve:main
40 changes: 40 additions & 0 deletions dicee.egg-info/requires.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
pandas>=2.1.0
polars==1.9.0
pyarrow>=11.0.0
rdflib>=7.0.0
torch>=2.5.1
lightning>=2.1.3
tiktoken>=0.5.1
psutil>=5.9.4
matplotlib>=3.8.2
pykeen>=1.10.2
numpy==1.26.4

[dev]
pandas>=2.1.0
polars==1.9.0
pyarrow>=11.0.0
rdflib>=7.0.0
torch>=2.5.1
lightning>=2.1.3
tiktoken>=0.5.1
psutil>=5.9.4
matplotlib>=3.8.2
pykeen>=1.10.2
numpy==1.26.4
ruff>=0.0.284
pytest>=7.2.2
scikit-learn>=1.2.2

[min]
pandas>=2.1.0
polars==1.9.0
pyarrow>=11.0.0
rdflib>=7.0.0
torch>=2.5.1
lightning>=2.1.3
tiktoken>=0.5.1
psutil>=5.9.4
matplotlib>=3.8.2
pykeen>=1.10.2
numpy==1.26.4
2 changes: 2 additions & 0 deletions dicee.egg-info/top_level.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
dicee
tests
Binary file added dicee/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/abstracts.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/callbacks.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/dataset_classes.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/evaluator.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/executer.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/knowledge_graph.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file added dicee/__pycache__/query_generator.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/sanity_checkers.cpython-310.pyc
Binary file not shown.
Binary file added dicee/__pycache__/static_funcs.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added dicee/models/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added dicee/models/__pycache__/adopt.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file added dicee/models/__pycache__/clifford.cpython-310.pyc
Binary file not shown.
Binary file added dicee/models/__pycache__/complex.cpython-310.pyc
Binary file not shown.
Binary file added dicee/models/__pycache__/dualE.cpython-310.pyc
Binary file not shown.
Binary file added dicee/models/__pycache__/ensemble.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file added dicee/models/__pycache__/octonion.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added dicee/models/__pycache__/real.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
35 changes: 35 additions & 0 deletions dicee/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,38 @@
from torch.nn import functional as F
from .adopt import ADOPT

class LabelRelaxationLoss(nn.Module):
def __init__(self, alpha=0.1, dim=-1, logits_provided=True, one_hot_encode_trgts=False, num_classes=-1):
super(LabelRelaxationLoss, self).__init__()
self.alpha = alpha
self.dim = dim
self.gz_threshold = 0.1
self.logits_provided = logits_provided
self.one_hot_encode_trgts = one_hot_encode_trgts
self.num_classes = num_classes

def forward(self, pred, target):
if self.logits_provided:
pred = pred.softmax(dim=self.dim)

# Apply one-hot encoding to targets
if self.one_hot_encode_trgts:
target = F.one_hot(target, num_classes=self.num_classes)

# Construct credal set
with torch.no_grad():
sum_y_hat_prime = torch.sum((torch.ones_like(target) - target) * pred, dim=-1)
pred_hat = self.alpha * pred / torch.unsqueeze(sum_y_hat_prime, dim=-1)
target_credal = torch.where(target > self.gz_threshold, torch.ones_like(target) - self.alpha, pred_hat)

# Calculate divergence
divergence = torch.sum(F.kl_div(pred.log(), target_credal, log_target=False, reduction="none"), dim=-1)

pred = torch.sum(pred * target, dim=-1)

result = torch.where(torch.gt(pred, 1. - self.alpha), torch.zeros_like(divergence), divergence)
return torch.mean(result)

class BaseKGELightning(pl.LightningModule):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -135,7 +167,10 @@ def __init__(self, args: dict):
self.kernel_size = None
self.num_of_output_channels = None
self.weight_decay = None

self.loss = torch.nn.BCEWithLogitsLoss()
#self.loss = LabelRelaxationLoss()

self.selected_optimizer = None
self.normalizer_class = None
self.normalize_head_entity_embeddings = IdentityClass()
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added dicee/scripts/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file added dicee/scripts/__pycache__/run.cpython-310.pyc
Binary file not shown.
Binary file added dicee/trainer/__pycache__/__init__.cpython-310.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit f18bde2

Please sign in to comment.