-
Notifications
You must be signed in to change notification settings - Fork 0
/
arguments.py
122 lines (105 loc) · 4.06 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import List, Optional
from paddlenlp.trainer import TrainingArguments
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Attributes:
        model_name_or_path: Required. Path to a pretrained model or a model identifier.
        tokenizer_name: Optional tokenizer name or path; defaults to ``None``.
        normalized: Boolean flag, ``True`` by default (no help text is attached).
        use_flash_attention: Whether to use flash attention; ``False`` by default.
        is_batch_negative: Whether to use in batch negative training; ``False`` by default.
    """

    # The only field without a default, so it must be supplied by the caller.
    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    # NOTE(review): presumably toggles normalization of the output embeddings;
    # confirm against the model code before documenting it in the help text.
    normalized: bool = field(default=True)
    use_flash_attention: bool = field(
        default=False, metadata={"help": "Whether to use flash attention"}
    )
    is_batch_negative: bool = field(
        default=False, metadata={"help": "Whether to use in batch negative training"}
    )
@dataclass
class DataArguments:
    """
    Arguments describing the training data and the tokenization length limits.

    Attributes:
        train_data: Path to the training data; must exist when the object is built.
        train_group_size: Integer option, 8 by default (no help text is attached).
        query_max_len: Maximum tokenized length for queries; 32 by default.
        passage_max_len: Maximum tokenized length for passages; 128 by default.
        max_example_num_per_dataset: Cap on the number of examples per dataset.
        query_instruction_for_retrieval: Optional instruction for queries.
        passage_instruction_for_retrieval: Optional instruction for passages.

    Raises:
        FileNotFoundError: Raised by ``__post_init__`` when ``train_data`` is
            unset or does not point to an existing path.
    """

    train_data: Optional[str] = field(
        default=None, metadata={"help": "Path to train data"}
    )
    train_group_size: int = field(default=8)
    query_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    passage_max_len: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    max_example_num_per_dataset: int = field(
        default=100000000,
        metadata={"help": "the max number of examples for each dataset"},
    )
    query_instruction_for_retrieval: Optional[str] = field(
        default=None, metadata={"help": "instruction for query"}
    )
    passage_instruction_for_retrieval: Optional[str] = field(
        default=None, metadata={"help": "instruction for passage"}
    )

    def __post_init__(self):
        """Validate that ``train_data`` refers to an existing path."""
        # ``os.path.exists(None)`` raises TypeError, so the unset default is
        # checked explicitly and reported with the same error as a bad path.
        if self.train_data is None or not os.path.exists(self.train_data):
            raise FileNotFoundError(
                f"cannot find file: {self.train_data}, please set a true path"
            )
@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    """
    Retriever-specific training options layered on top of ``TrainingArguments``.

    Every field has a default, so the class can be built from the base
    arguments alone. The options cover negative sharing, pooling, the
    fine-tune type, and the matryoshka settings; see each field's ``help``
    metadata for its description.
    """

    negatives_cross_device: bool = field(
        metadata={"help": "share negatives across devices"},
        default=False,
    )
    # NOTE(review): no help text is attached to the next two options;
    # presumably loss hyper-parameters — confirm against the trainer.
    temperature: Optional[float] = 0.02
    margin: Optional[float] = 0.2
    fix_position_embedding: bool = field(
        metadata={"help": "Freeze the parameters of position embeddings"},
        default=False,
    )
    sentence_pooling_method: str = field(
        metadata={"help": "the pooling method, should be weighted_mean"},
        default="weighted_mean",
    )
    fine_tune_type: str = field(
        metadata={"help": "fine-tune type for retrieval,eg: sft, bitfit, lora"},
        default="sft",
    )
    use_inbatch_neg: bool = field(
        metadata={"help": "use passages in the same batch as negatives"},
        default=False,
    )
    use_matryoshka: bool = field(
        metadata={"help": "use matryoshka for flexible embedding size"},
        default=False,
    )
    # List defaults go through a factory so each instance gets its own list.
    matryoshka_dims: List[int] = field(
        metadata={"help": "matryoshka dims"},
        default_factory=lambda: [64, 128, 256, 512, 768],
    )
    matryoshka_loss_weights: List[float] = field(
        metadata={"help": "matryoshka loss weights"},
        default_factory=lambda: [1] * 5,
    )