Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
peteryangms committed Jun 5, 2024
1 parent 9d617af commit 2545277
Show file tree
Hide file tree
Showing 20 changed files with 428 additions and 245 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ select = ["ALL"]
"test/*" = ["S101"]

[tool.setuptools]
py-modules = ["rdagent"]
packages = ["rdagent"]

[tool.setuptools.dynamic]
dependencies = {file = ["requirements.txt"]}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,143 +2,29 @@
import json
from pathlib import Path

from document_process.document_analysis import (
check_factor_dict_viability,
from rdagent.document_process.document_analysis import (
filter_factor_by_viability,
deduplicate_factors_several_times,
extract_factors_from_report_dict_and_classify_result,
)
from document_process.document_reader import (
classify_report_from_dict,
load_and_process_pdfs_by_langchain,
extract_factors_from_report_dict,
merge_file_to_factor_dict_to_factor_dict,
)
from rdagent.document_process.document_reader import load_and_process_pdfs_by_langchain
from rdagent.document_process.document_analysis import classify_report_from_dict
from dotenv import load_dotenv
from oai.llm_utils import APIBackend


def extract_factors_and_implement(report_file_path: str) -> None:
    """Run the report-to-factor extraction pipeline on one report path.

    The function chains the document helpers in this order:

    1. ``load_dotenv()`` to read environment variables from a ``.env`` file.
    2. ``load_and_process_pdfs_by_langchain`` on ``report_file_path``.
    3. ``classify_report_from_dict`` with ``vote_time=1``.
    4. ``extract_factors_from_report_dict`` followed by
       ``merge_file_to_factor_dict_to_factor_dict``.
    5. ``filter_factor_by_viability`` and
       ``deduplicate_factors_several_times``.

    Args:
        report_file_path: Path of the report to process; it is wrapped in
            ``pathlib.Path`` before being handed to the PDF loader.

    Raises:
        AssertionError: If ``load_dotenv()`` returns a falsy value.
    """
    # Call load_dotenv() outside of an ``assert`` statement: asserts are
    # stripped under ``python -O``, which would silently skip loading the
    # .env file altogether. AssertionError is kept so callers see the same
    # exception type as before.
    if not load_dotenv():
        raise AssertionError("load_dotenv() did not load any environment variables")

    docs_dict = load_and_process_pdfs_by_langchain(Path(report_file_path))

    selected_report_dict = classify_report_from_dict(report_dict=docs_dict, vote_time=1)
    file_to_factor_result = extract_factors_from_report_dict(docs_dict, selected_report_dict)
    factor_dict = merge_file_to_factor_dict_to_factor_dict(file_to_factor_result)

    factor_dict_viable, factor_viability = filter_factor_by_viability(factor_dict)

    # NOTE(review): the results below are computed but not yet returned or
    # persisted by this function; presumably the implementation step the
    # function name refers to is still to be added.
    factor_dict, duplication_names_list = deduplicate_factors_several_times(factor_dict, factor_viability)
if __name__ == "__main__":
    import sys

    # Allow the report path to be given on the command line; without an
    # argument, fall back to the path that was previously hard-coded here.
    extract_factors_and_implement(
        sys.argv[1] if len(sys.argv) > 1 else "/home/xuyang1/workspace/report.pdf",
    )
1 change: 1 addition & 0 deletions rdagent/core/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class FincoSettings(BaseSettings):
use_azure: bool = True
use_azure_token_provider: bool = False
max_retry: int = 10
retry_wait_seconds: int = 1
continuous_mode: bool = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
from typing import Dict

import yaml
from finco.utils import SingletonBaseClass
from rdagent.core.utils import SingletonBaseClass


class Prompts(Dict, SingletonBaseClass):
    """Dictionary of prompts loaded from a YAML file.

    Every top-level key/value pair of the YAML document is copied into this
    dict. The class also derives from ``SingletonBaseClass``.
    NOTE(review): the singleton machinery appears to accept keyword
    arguments only, so instances are presumably created as
    ``Prompts(file_path=...)`` - confirm against ``rdagent.core.utils``.
    """

    def __init__(self, file_path: Path):
        """Load the prompts stored in ``file_path``.

        Args:
            file_path: Location of the YAML file; it is opened as UTF-8.

        Raises:
            ValueError: If the YAML document parses to ``None`` (for example
                because the file is empty).
        """
        # A context manager closes the file handle deterministically instead
        # of leaving it to the garbage collector.
        with open(file_path, encoding="utf8") as prompt_file:
            # NOTE(review): FullLoader is kept to preserve behavior; prefer
            # yaml.safe_load if prompt files can come from untrusted sources.
            prompt_yaml_dict = yaml.load(prompt_file, Loader=yaml.FullLoader)

        if prompt_yaml_dict is None:
            raise ValueError(f"Failed to load prompts from {file_path}")

        for key, value in prompt_yaml_dict.items():
            self[key] = value
14 changes: 9 additions & 5 deletions rdagent/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,21 @@
from fuzzywuzzy import fuzz


class RDAgentException(Exception):
    """Exception type raised by RD-Agent utilities, e.g. on singleton misuse."""


class SingletonMeta(type):
    """Metaclass that caches one instance per (class, keyword arguments) pair.

    Calling a class that uses this metaclass twice with equal keyword
    arguments returns the same object; different keyword arguments, or a
    different class, yield a different object.
    """

    # The cache lives on the metaclass and is therefore shared by every class
    # that uses it, so each key must contain the class itself. Keying on the
    # keyword arguments alone would hand an instance of one class to another
    # class that happens to be called with equal kwargs.
    _instance_dict = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance for ``cls`` and ``kwargs``, creating it if needed.

        Raises:
            RDAgentException: If any positional argument is supplied.
            TypeError: If a keyword argument value is not hashable.
        """
        # Positional and keyword spellings of the same call are hard to
        # reconcile into one cache key, so only keyword arguments are accepted.
        if args:
            raise RDAgentException("Please only use kwargs in Singleton to avoid misunderstanding.")
        # Keyword names are unique, so sorting never has to compare the values.
        instance_key = (cls, tuple(sorted(kwargs.items())))
        if instance_key not in cls._instance_dict:
            cls._instance_dict[instance_key] = super().__call__(**kwargs)
        return cls._instance_dict[instance_key]


class SingletonBaseClass(metaclass=SingletonMeta):
Expand Down
Empty file.
Loading

0 comments on commit 2545277

Please sign in to comment.