diff --git a/src/options.tsx b/src/options.tsx index 4b7b3f5..9068e8f 100644 --- a/src/options.tsx +++ b/src/options.tsx @@ -4,7 +4,7 @@ import "./options.css"; import { getStorage, setStorage } from "./utils"; import Switch from "./components/Switch"; import FilterRules from "./components/FilterRules"; -import { FilterRuleItem } from "./types"; +import { FilterRuleItem, ServiceProvider } from "./types"; import { DEFAULT_PROMPT } from "./const"; const TABS = [ @@ -16,6 +16,9 @@ const TABS = [ function BasicSettings() { const [model, setModel] = useState("gpt-3.5-turbo"); + const [serviceProvider, setServiceProvider] = useState( + "GPT" + ); const [apiURL, setApiURL] = useState( "https://api.openai.com/v1/chat/completions" ); @@ -24,6 +27,11 @@ function BasicSettings() { ]); useEffect(() => { getStorage("model").then(setModel); + getStorage("serviceProvider").then((value) => { + if (value) { + setServiceProvider(value); + } + }); getStorage("apiURL").then(setApiURL); getStorage("filterRules").then(setFilterRules); }, []); @@ -33,6 +41,14 @@ function BasicSettings() { setStorage("model", e.target.value); }, []); + const updateServiceProvider = useCallback( + (e: ChangeEvent) => { + setServiceProvider(e.target.value as ServiceProvider); + setStorage("serviceProvider", e.target.value); + }, + [] + ); + const updateApiURL = useCallback((e: ChangeEvent) => { setApiURL(e.target.value); setStorage("apiURL", e.target.value); @@ -47,38 +63,58 @@ function BasicSettings() {
-
- - - +
+ + + +
+
+ + + +
+ + )}
@@ -212,7 +252,7 @@ const Popup = () => {
You can get your key from{" "} => { + const prompt: string = (await getStorage("prompt")) || DEFAULT_PROMPT; + return [ + { + role: "user", + parts: [ + { + text: "", + }, + ], + }, + { + role: "model", + parts: [ + { + text: "You are a brwoser tab group classificator", + }, + ], + }, + { + role: "user", + parts: [ + { + text: Mustache.render(prompt, { + tabURL: tab.url, + tabTitle: tab.title, + types: types.join(", "), + }), + }, + ], + }, + ]; +}; + +export const fetchGemini = async ( + apiKey: string, + tabInfo: TabInfo, + types: string[] +) => { + const response = await fetch( + "https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent?key=" + + apiKey, + { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + contents: await renderPromptForGemini(tabInfo, types), + }), + } + ); + + const data = await response.json(); + + const type = data.candidates[0].content.parts[0].text; + return type; +}; diff --git a/src/service-provider/gpt.ts b/src/service-provider/gpt.ts new file mode 100644 index 0000000..472637c --- /dev/null +++ b/src/service-provider/gpt.ts @@ -0,0 +1,56 @@ +import { TabInfo } from "../types"; +import Mustache from "mustache"; +import { getStorage } from "../utils"; +import { DEFAULT_PROMPT } from "../const"; + +const renderPromptForOpenAI = async ( + tab: TabInfo, + types: string[] +): Promise< + [{ role: string; content: string }, { role: string; content: string }] +> => { + const prompt: string = (await getStorage("prompt")) || DEFAULT_PROMPT; + return [ + { + role: "system", + content: "You are a brwoser tab group classificator", + }, + { + role: "user", + content: Mustache.render(prompt, { + tabURL: tab.url, + tabTitle: tab.title, + types: types.join(", "), + }), + }, + ]; +}; + +export const fetchGpt = async ( + apiKey: string, + tabInfo: TabInfo, + types: string[] +) => { + const apiURL = + (await getStorage("apiURL")) || + "https://api.openai.com/v1/chat/completions"; + + const model = (await getStorage("model")) || "gpt-3.5-turbo"; + + const response = await fetch(apiURL, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + "Api-Key": apiKey, + }, + body: JSON.stringify({ + messages: await renderPromptForOpenAI(tabInfo, types), + model, + }), + }); + + const data = await response.json(); + const type = data.choices[0].message.content; + return type; +}; diff --git a/src/service-provider/index.ts b/src/service-provider/index.ts new file mode 100644 index 0000000..941edf5 --- /dev/null +++ b/src/service-provider/index.ts @@ -0,0 +1,21 @@ +import { TabInfo } from "../types"; +import { getServiceProvider } from "../utils"; +import { fetchGemini } from "./gemini"; +import { fetchGpt } from "./gpt"; + +const fetchMap = { + GPT: fetchGpt, + Gemini: fetchGemini, +}; + +export const fetchType = async ( + apiKey: string, + tabInfo: TabInfo, + types: string[] +) => { + const serviceProvider = await getServiceProvider(); + if (!fetchMap[serviceProvider]) { + throw new Error("unexpected serviceProvider: " + serviceProvider); + } + return fetchMap[serviceProvider](apiKey, tabInfo, types); +}; diff --git a/src/services.ts b/src/services.ts index e38bd30..1dafef8 100644 --- a/src/services.ts +++ b/src/services.ts @@ -1,42 +1,12 @@ -import Mustache from "mustache"; import { getStorage, matchesRule } from "./utils"; -import { FilterRuleItem } from "./types"; -import { DEFAULT_PROMPT } from "./const"; +import { FilterRuleItem, TabInfo } from "./types"; +import { fetchType } from "./service-provider"; interface TabGroup { type: string; tabIds: (number | undefined)[]; } -interface TabInfo { - id: number | undefined; - title: string | undefined; - url: string | undefined; -} - -const renderPrompt = async ( - tab: TabInfo, - types: string[] -): Promise< - [{ role: string; content: string }, { role: string; content: string }] -> => { - const prompt: string = (await getStorage("prompt")) || DEFAULT_PROMPT; - return [ - { - role: "system", - content: "You are a brwoser tab group classificator", - }, - { - role: "user", - content: Mustache.render(prompt, { - tabURL: tab.url, - tabTitle: tab.title, - types: types.join(", "), - }), - }, - ]; -}; - const filterTabInfo = (tabInfo: TabInfo, filterRules: FilterRuleItem[]) => { if (!filterRules || !filterRules?.length) return true; const url = new URL(tabInfo.url ?? ""); @@ -48,7 +18,7 @@ const filterTabInfo = (tabInfo: TabInfo, filterRules: FilterRuleItem[]) => { export async function batchGroupTabs( tabs: chrome.tabs.Tab[], types: string[], - openAIKey: string + apiKey: string ) { const filterRules = (await getStorage("filterRules")) || []; const tabInfoList: TabInfo[] = tabs @@ -68,31 +38,11 @@ export async function batchGroupTabs( }; }); - const model = (await getStorage("model")) || "gpt-3.5-turbo"; - const apiURL = - (await getStorage("apiURL")) || - "https://api.openai.com/v1/chat/completions"; - try { await Promise.all( tabInfoList.map(async (tabInfo) => { if (!tabInfo.url) return; - const response = await fetch(apiURL, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${openAIKey}`, - "Api-Key": openAIKey, - }, - body: JSON.stringify({ - messages: await renderPrompt(tabInfo, types), - model, - }), - }); - - const data = await response.json(); - const type = data.choices[0].message.content; - + const type = await fetchType(apiKey, tabInfo, types); const index = types.indexOf(type); if (index === -1) return; result[index].tabIds.push(tabInfo.id); @@ -108,35 +58,16 @@ export async function batchGroupTabs( export async function handleOneTab( tab: chrome.tabs.Tab, types: string[], - openAIKey: string + apiKey: string ) { try { const tabInfo: TabInfo = { id: tab.id, title: tab.title, url: tab.url }; - const model = (await getStorage("model")) || "gpt-3.5-turbo"; - const apiURL = - (await getStorage("apiURL")) || - "https://api.openai.com/v1/chat/completions"; - const filterRules = (await getStorage("filterRules")) || []; const shouldFilter = !filterTabInfo(tabInfo, filterRules); if (shouldFilter) return; - const response = await fetch(apiURL, { - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: `Bearer ${openAIKey}`, - "Api-Key": openAIKey, - }, - body: JSON.stringify({ - messages: await renderPrompt(tabInfo, types), - model, - }), - }); - - const data = await response.json(); - const type = data.choices[0].message.content; + const type = await fetchType(apiKey, tabInfo, types); return type; } catch (error) { diff --git a/src/types.ts b/src/types.ts index b9de660..9825fec 100644 --- a/src/types.ts +++ b/src/types.ts @@ -5,3 +5,11 @@ export type FilterRuleItem = { type: RuleType; rule: string; }; + +export type ServiceProvider = "GPT" | "Gemini"; + +export interface TabInfo { + id: number | undefined; + title: string | undefined; + url: string | undefined; +} diff --git a/src/utils.ts b/src/utils.ts index f33e78d..b868b2a 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,4 +1,4 @@ -import { FilterRuleItem } from "./types"; +import { FilterRuleItem, ServiceProvider } from "./types"; export function setStorage(key: string, value: V) { return new Promise((resolve, reject) => { @@ -127,3 +127,9 @@ export const curryFilterManualGroups = async () => { return !manualGroupsTabs.map((tab) => tab.id).includes(tabId); }; }; + +export const getServiceProvider = async () => { + const serviceProvider = + (await getStorage("serviceProvider")) || "GPT"; + return serviceProvider; +};