Skip to content

Commit

Permalink
Feature/Add Neo4j GraphRag support (#3686)
Browse files Browse the repository at this point in the history
* added: Neo4j database connectivity, Neo4j credentials, supports the usage of the GraphCypherQaChain node and modifies the FewShotPromptTemplate node to handle variables from the prefix field.

* Merge branch 'main' of github.com:FlowiseAI/Flowise into feature/graphragsupport

* revert pnpm-lock.yaml

* add: neo4j package

* Refactor GraphCypherQAChain: Update version to 1.0, remove memory input, and enhance prompt handling

- Changed version from 2.0 to 1.0.
- Removed the 'Memory' input parameter from the GraphCypherQAChain.
- Made 'cypherPrompt' optional and improved error handling for prompt validation.
- Updated the 'init' and 'run' methods to streamline input processing and response handling.
- Enhanced streaming response logic based on the 'returnDirect' flag.

* Refactor GraphCypherQAChain: Simplify imports and update init method signature

- Consolidated import statements for better readability.
- Removed the 'input' and 'options' parameters from the 'init' method, streamlining its signature to only accept 'nodeData'.

* add output, format final response, fix optional inputs

---------

Co-authored-by: Henry <[email protected]>
  • Loading branch information
ghondar and HenryHengZJ authored Dec 23, 2024
1 parent 93f3a5d commit a7c1ab8
Show file tree
Hide file tree
Showing 8 changed files with 34,325 additions and 33,897 deletions.
39 changes: 39 additions & 0 deletions packages/components/credentials/Neo4jApi.credential.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { INodeParams, INodeCredential } from '../src/Interface'

class Neo4jApi implements INodeCredential {
label: string
name: string
version: number
description: string
inputs: INodeParams[]

constructor() {
this.label = 'Neo4j API'
this.name = 'neo4jApi'
this.version = 1.0
this.description =
'Refer to <a target="_blank" href="https://neo4j.com/docs/operations-manual/current/authentication-authorization/">official guide</a> on Neo4j authentication'
this.inputs = [
{
label: 'Neo4j URL',
name: 'url',
type: 'string',
description: 'Your Neo4j instance URL (e.g., neo4j://localhost:7687)'
},
{
label: 'Username',
name: 'username',
type: 'string',
description: 'Neo4j database username'
},
{
label: 'Password',
name: 'password',
type: 'password',
description: 'Neo4j database password'
}
]
}
}

module.exports = { credClass: Neo4jApi }
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface'
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher'
import { getBaseClasses } from '../../../src/utils'
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'

class GraphCypherQA_Chain implements INode {
label: string
name: string
version: number
type: string
icon: string
category: string
description: string
baseClasses: string[]
inputs: INodeParams[]
sessionId?: string
outputs: INodeOutputsValue[]

constructor(fields?: { sessionId?: string }) {
this.label = 'Graph Cypher QA Chain'
this.name = 'graphCypherQAChain'
this.version = 1.0
this.type = 'GraphCypherQAChain'
this.icon = 'graphqa.svg'
this.category = 'Chains'
this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements'
this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)]
this.sessionId = fields?.sessionId
this.inputs = [
{
label: 'Language Model',
name: 'model',
type: 'BaseLanguageModel',
description: 'Model for generating Cypher queries and answers.'
},
{
label: 'Neo4j Graph',
name: 'graph',
type: 'Neo4j'
},
{
label: 'Cypher Generation Prompt',
name: 'cypherPrompt',
optional: true,
type: 'BasePromptTemplate',
description: 'Prompt template for generating Cypher queries. Must include {schema} and {question} variables'
},
{
label: 'Cypher Generation Model',
name: 'cypherModel',
optional: true,
type: 'BaseLanguageModel',
description: 'Model for generating Cypher queries. If not provided, the main model will be used.'
},
{
label: 'QA Prompt',
name: 'qaPrompt',
optional: true,
type: 'BasePromptTemplate',
description: 'Prompt template for generating answers. Must include {context} and {question} variables'
},
{
label: 'QA Model',
name: 'qaModel',
optional: true,
type: 'BaseLanguageModel',
description: 'Model for generating answers. If not provided, the main model will be used.'
},
{
label: 'Input Moderation',
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
name: 'inputModeration',
type: 'Moderation',
optional: true,
list: true
},
{
label: 'Return Direct',
name: 'returnDirect',
type: 'boolean',
default: false,
optional: true,
description: 'If true, return the raw query results instead of using the QA chain'
}
]
this.outputs = [
{
label: 'Graph Cypher QA Chain',
name: 'graphCypherQAChain',
baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)]
},
{
label: 'Output Prediction',
name: 'outputPrediction',
baseClasses: ['string', 'json']
}
]
}

async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
const model = nodeData.inputs?.model
const cypherModel = nodeData.inputs?.cypherModel
const qaModel = nodeData.inputs?.qaModel
const graph = nodeData.inputs?.graph
const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined
const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined
const returnDirect = nodeData.inputs?.returnDirect as boolean
const output = nodeData.outputs?.output as string

// Handle prompt values if they exist
let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined
let qaPromptTemplate: PromptTemplate | undefined

if (cypherPrompt) {
if (cypherPrompt instanceof PromptTemplate) {
cypherPromptTemplate = new PromptTemplate({
template: cypherPrompt.template as string,
inputVariables: cypherPrompt.inputVariables
})
if (!qaPrompt) {
throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template')
}
} else if (cypherPrompt instanceof FewShotPromptTemplate) {
const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate
cypherPromptTemplate = new FewShotPromptTemplate({
examples: cypherPrompt.examples,
examplePrompt: examplePrompt,
inputVariables: cypherPrompt.inputVariables,
prefix: cypherPrompt.prefix,
suffix: cypherPrompt.suffix,
exampleSeparator: cypherPrompt.exampleSeparator,
templateFormat: cypherPrompt.templateFormat
})
} else {
cypherPromptTemplate = cypherPrompt as PromptTemplate
}
}

if (qaPrompt instanceof PromptTemplate) {
qaPromptTemplate = new PromptTemplate({
template: qaPrompt.template as string,
inputVariables: qaPrompt.inputVariables
})
}

if ((!cypherModel || !qaModel) && !model) {
throw new Error('Language Model is required when Cypher Model or QA Model are not provided')
}

// Validate required variables in prompts
if (
cypherPromptTemplate &&
(!cypherPromptTemplate?.inputVariables.includes('schema') || !cypherPromptTemplate?.inputVariables.includes('question'))
) {
throw new Error('Cypher Generation Prompt must include {schema} and {question} variables')
}

const fromLLMInput: FromLLMInput = {
llm: model,
graph,
returnDirect
}

if (cypherModel && cypherPromptTemplate) {
fromLLMInput['cypherLLM'] = cypherModel
fromLLMInput['cypherPrompt'] = cypherPromptTemplate
}

if (qaModel && qaPromptTemplate) {
fromLLMInput['qaLLM'] = qaModel
fromLLMInput['qaPrompt'] = qaPromptTemplate
}

const chain = GraphCypherQAChain.fromLLM(fromLLMInput)

if (output === this.name) {
return chain
} else if (output === 'outputPrediction') {
nodeData.instance = chain
return await this.run(nodeData, input, options)
}

return chain
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
const chain = nodeData.instance as GraphCypherQAChain
const moderations = nodeData.inputs?.inputModeration as Moderation[]
const returnDirect = nodeData.inputs?.returnDirect as boolean

const shouldStreamResponse = options.shouldStreamResponse
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
const chatId = options.chatId

// Handle input moderation if configured
if (moderations && moderations.length > 0) {
try {
input = await checkInputs(moderations, input)
} catch (e) {
await new Promise((resolve) => setTimeout(resolve, 500))
if (shouldStreamResponse) {
streamResponse(sseStreamer, chatId, e.message)
}
return formatResponse(e.message)
}
}

const obj = {
query: input
}

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbackHandlers = await additionalCallbacks(nodeData, options)
let callbacks = [loggerHandler, ...callbackHandlers]

if (process.env.DEBUG === 'true') {
callbacks.push(new LCConsoleCallbackHandler())
}

try {
let response
if (shouldStreamResponse) {
if (returnDirect) {
response = await chain.invoke(obj, { callbacks })
let result = response?.result
if (typeof result === 'object') {
result = '```json\n' + JSON.stringify(result, null, 2)
}
if (result && typeof result === 'string') {
streamResponse(sseStreamer, chatId, result)
}
} else {
const handler = new CustomChainHandler(sseStreamer, chatId, 2)
callbacks.push(handler)
response = await chain.invoke(obj, { callbacks })
}
} else {
response = await chain.invoke(obj, { callbacks })
}

return formatResponse(response?.result)
} catch (error) {
console.error('Error in GraphCypherQAChain:', error)
if (shouldStreamResponse) {
streamResponse(sseStreamer, chatId, error.message)
}
return formatResponse(`Error: ${error.message}`)
}
}
}

module.exports = { nodeClass: GraphCypherQA_Chain }
22 changes: 22 additions & 0 deletions packages/components/nodes/chains/GraphCypherQAChain/graphqa.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit a7c1ab8

Please sign in to comment.