-
Notifications
You must be signed in to change notification settings - Fork 24
/
main.py
72 lines (61 loc) · 2.79 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import random
from data.E2E.reader import E2EDataReader
from data.WebNLG.reader import WebNLGDataReader
from data.reader import DataReader, DataSetType
from planner.naive_planner import NaivePlanner
from planner.neural_planner import NeuralPlanner
from planner.planner import Planner
from process.evaluation import EvaluationPipeline
from process.pre_process import TrainingPreProcessPipeline, TestingPreProcessPipeline
from process.reg import REGPipeline
from process.train_model import TrainModelPipeline
from process.train_planner import TrainPlannerPipeline
from process.translate import TranslatePipeline
from reg.bert import BertREG
from reg.naive import NaiveREG
from reg.base import REG
from scorer.global_direction import GlobalDirectionExpert
from scorer.product_of_experts import WeightedProductOfExperts
from scorer.relation_direction import RelationDirectionExpert
from scorer.relation_transitions import RelationTransitionsExpert
from scorer.splitting_tendencies import SplittingTendenciesExpert
from utils.pipeline import Pipeline
class Config:
    """Container for the data readers, planner and REG component of one run.

    Attributes:
        reader: dict keyed by ``DataSetType``. ``TRAIN`` and ``DEV`` both map
            to ``reader``; ``TEST`` maps to ``test_reader`` when one is given,
            otherwise to ``reader``.
        planner: the ``planner`` argument, stored unchanged.
        reg: the ``reg`` argument, stored unchanged.
    """

    def __init__(self, reader: DataReader = None, planner: Planner = None, reg: REG = None,
                 test_reader: DataReader = None):
        """Store the components, defaulting the test reader to ``reader``.

        Args:
            reader: reader used for the training and development splits.
            planner: planner component to use.
            reg: referring-expression-generation component to use.
            test_reader: optional reader for the test split only.
        """
        # Fall back only when test_reader was omitted (None). A truthiness
        # test would also discard an explicitly supplied reader that happens
        # to evaluate falsy (e.g. one defining __len__ or __bool__).
        test_split_reader = reader if test_reader is None else test_reader

        self.reader = {
            DataSetType.TRAIN: reader,
            DataSetType.DEV: reader,
            DataSetType.TEST: test_split_reader,
        }
        self.planner = planner
        self.reg = reg
def _build_main_pipeline():
    """Create the main pipeline and enqueue its stages in execution order."""
    # (key, description, stage pipeline) in the order the stages are enqueued.
    stages = (
        ("pre-process", "Pre-process training data", TrainingPreProcessPipeline),
        ("train-planner", "Train Planner", TrainPlannerPipeline),
        ("train-model", "Train Model", TrainModelPipeline),
        ("test-corpus", "Pre-process test data", TestingPreProcessPipeline),
        ("train-reg", "Train Referring Expressions Generator", REGPipeline),
        ("translate", "Translate Test", TranslatePipeline),
        ("evaluate", "Evaluate Translations", EvaluationPipeline),
    )
    pipeline = Pipeline()
    for key, description, stage in stages:
        pipeline.enqueue(key, description, stage)
    return pipeline


# Module-level pipeline object; the script entry point below mutates and runs it.
MainPipeline = _build_main_pipeline()
if __name__ == "__main__":
    # Alternative planner set-ups, kept for reference (currently disabled).
    # NOTE(review): CombinedPlanner does not appear among this file's imports;
    # it would need one before the combined variant can be enabled.
    #
    # naive_planner = NaivePlanner(WeightedProductOfExperts([
    #     RelationDirectionExpert,
    #     GlobalDirectionExpert,
    #     SplittingTendenciesExpert,
    #     RelationTransitionsExpert
    # ]))
    # combined_planner = CombinedPlanner((neural_planner, naive_planner))
    planner = NeuralPlanner()

    # WebNLG reader for every split, the neural planner, and the BERT REG.
    run_config = Config(
        reader=WebNLGDataReader,
        planner=planner,
        reg=BertREG,
    )

    # Inject the config into the main pipeline, then run it end to end.
    configured_pipeline = MainPipeline.mutate({"config": run_config})
    results = configured_pipeline.execute("WebNLG", cache_name="WebNLG")

    print()

    # Show one randomly chosen translated item
    # (random.choice raises IndexError if the sequence is empty).
    translated_items = results["translate"].data
    sample = random.choice(translated_items)

    print("Random Sample:")
    print("Graph:", sample.graph.as_rdf())
    print("Plan:", sample.plan)
    print("Translation:", sample.hyp)
    print("Reference: ", sample.text)

    print()

    # Report the BLEU value stored in the evaluation stage's result.
    bleu = results["evaluate"]["bleu"]
    print("BLEU", bleu)