-
Notifications
You must be signed in to change notification settings - Fork 661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WebGPU-based successor for GPUjs #822
Comments
Hey all, so GPT-4 and I got to talking, and I got a start on this. Here is a sample that converts a DFT function. It needs a lot of work (for example, inferring inputs/outputs) and it's still a little janky, but here's to a start :-). Give me a couple more days to think, and then I'm parking this. Just save this as an HTML file, open it, and check the console to see the readout; scroll to the bottom of the code to see the JavaScript function being converted. It turns out WebGPU is really not that challenging to transpile to. There are some challenges in how to infer types and so on, but I will have something in place shortly, just out of my own curiosity. <html>
<head></head>
<body>
<script>
/**
 * Experimental JavaScript-to-WGSL transpiler with a minimal WebGPU compute runner.
 *
 * A "compute function" is JavaScript-shaped source whose parameter defaults act as type hints:
 * array defaults (`[]`, `new Array(...)`) become storage buffers, other defaults become fields of
 * one shared `UniformsStruct`, and variables named in `return` statements become read_write
 * buffers that `process()` reads back.
 */
class WebGPUjs {
  constructor() {
    // Names of the bindings emitted by the last generateDataStructures() call (informational).
    this.bindings = [];
  }

  /**
   * Create a processor and compile its compute pipeline.
   * @param {Function|string} computeFunction - JS-style compute function to transpile, or raw WGSL.
   * @param {GPUDevice|null} [device] - Device to use; one is requested from navigator.gpu when omitted.
   * @returns {Promise<WebGPUjs>} The initialised processor.
   * @throws {Error} When WebGPU or a GPU adapter is unavailable.
   */
  static async createPipeline(computeFunction, device = null) {
    let gpuDevice = device;
    if (!gpuDevice) {
      const gpu = typeof navigator === 'undefined' ? undefined : navigator.gpu;
      if (!gpu) throw new Error('WebGPU is not available (navigator.gpu is undefined)');
      const adapter = await gpu.requestAdapter();
      if (!adapter) throw new Error('No WebGPU adapter is available');
      gpuDevice = await adapter.requestDevice();
    }
    const processor = new WebGPUjs();
    let shader = computeFunction;
    let ast;
    if (typeof computeFunction === 'function') {
      const result = processor.convertToWebGPU(computeFunction);
      shader = result.shader;
      ast = result.ast;
    }
    await processor.init(shader, ast, undefined, gpuDevice);
    return processor;
  }

  /**
   * Build the bind group layout and compute pipeline for `computeShader`.
   * @param {string} computeShader - WGSL source with a `main` compute entry point.
   * @param {object[]} [ast] - Nodes from parse(); when given, the layout is derived from them.
   * @param {GPUBindGroupLayoutDescriptor} [bindGroupLayoutSettings] - Explicit layout used when no ast is given.
   * @param {GPUDevice} [device] - Defaults to the device of a previous init().
   * @returns {Promise<GPUComputePipeline>}
   */
  async init(computeShader, ast, bindGroupLayoutSettings, device = this.device) {
    this.device = device;
    let bindGroupLayout;
    if (ast) {
      // Variables named in return statements need writable ('storage') bindings.
      const returnedVars = this.extractReturnedVars(this.fstr);
      let bufferIncr = 0;
      let uniformBufferIdx;
      const entries = ast
        .filter((v) => v.isInput || returnedVars.includes(v.name))
        .map((node) => {
          if (node.isUniform) {
            // All uniforms share one struct, so only the first one creates a binding.
            if (uniformBufferIdx !== undefined) return undefined;
            uniformBufferIdx = bufferIncr;
            bufferIncr++;
            return { binding: uniformBufferIdx, visibility: GPUShaderStage.COMPUTE, buffer: { type: 'uniform' } };
          }
          const writable = returnedVars.includes(node.name) || node.isModified;
          const entry = {
            binding: bufferIncr,
            visibility: GPUShaderStage.COMPUTE,
            buffer: { type: writable ? 'storage' : 'read-only-storage' },
          };
          bufferIncr++;
          return entry;
        })
        .filter((entry) => entry);
      bindGroupLayout = this.device.createBindGroupLayout({ entries });
    } else if (bindGroupLayoutSettings) {
      bindGroupLayout = this.device.createBindGroupLayout(bindGroupLayoutSettings);
    }
    this.bindGroupLayout = bindGroupLayout;
    this.shader = computeShader;
    this.shaderModule = this.device.createShaderModule({ code: computeShader });
    // Without an explicit layout (raw WGSL input), let WebGPU derive one from the shader.
    this.pipelineLayout = bindGroupLayout
      ? this.device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] })
      : 'auto';
    this.computePipeline = this.device.createComputePipeline({
      layout: this.pipelineLayout,
      compute: { module: this.shaderModule, entryPoint: 'main' },
    });
    if (!bindGroupLayout) this.bindGroupLayout = this.computePipeline.getBindGroupLayout(0);
    return this.computePipeline;
  }

  /**
   * Determine the WGSL type and byte size of argument `idx` passed to process().
   * Declared vec/mat parameter types and declared scalar uniform types take precedence, so the
   * bytes written match the shader's declarations.
   * @returns {{type: string, byteSize: number}} `{type: 'unknown', byteSize: 0}` if unrecognised.
   */
  getInputTypeInfo(input, idx) {
    const param = this.params?.[idx];
    if (param && (param.type.startsWith('mat') || param.type.startsWith('vec'))) {
      const layout = wgslTypeSizes[param.type];
      if (layout) return { type: param.type, byteSize: layout.size };
    }
    switch (input?.constructor?.name) {
      case 'Float32Array':
        return { type: 'f32', byteSize: 4 };
      case 'Int32Array':
        return { type: 'i32', byteSize: 4 };
      case 'Uint32Array':
        return { type: 'u32', byteSize: 4 };
      // None of the types below are usable in WGSL storage/uniform buffers here.
      case 'Float64Array':
        return { type: 'f64', byteSize: 8 };
      case 'Float16Array':
        return { type: 'f16', byteSize: 2 };
      case 'Int16Array':
        return { type: 'i16', byteSize: 2 };
      case 'Uint16Array':
        return { type: 'u16', byteSize: 2 };
      case 'Int8Array':
        return { type: 'i8', byteSize: 1 };
      case 'Uint8Array':
        return { type: 'u8', byteSize: 1 };
      default:
        break;
    }
    if (typeof input === 'number') {
      // A JS number does not say whether it is float or int; trust the declared uniform type.
      if (param && ['f32', 'i32', 'u32'].includes(param.type)) return { type: param.type, byteSize: 4 };
      return Number.isInteger(input) ? { type: 'i32', byteSize: 4 } : { type: 'f32', byteSize: 4 };
    }
    return { type: 'unknown', byteSize: 0 };
  }

  /** Recursively flatten nested plain arrays into a single array (typed arrays are iterated as-is). */
  flattenArray(arr) {
    let result = [];
    for (let i = 0; i < arr.length; i++) {
      if (Array.isArray(arr[i])) {
        result = result.concat(this.flattenArray(arr[i]));
      } else {
        result.push(arr[i]);
      }
    }
    return result;
  }

  /**
   * Compute the WGSL memory layout of the shared uniform struct for the given argument types.
   * Fields follow `this.params` order; each is placed at the next multiple of its alignment and the
   * struct size is rounded up to its largest alignment, then to 16 bytes for uniform binding.
   * @param {{type: string, byteSize: number}[]} inputTypes - From getInputTypeInfo(), by argument index.
   * @returns {{fields: {index: number, offset: number, type: string, layout: object}[], size: number}}
   * @throws {Error} If a uniform argument is missing or its type has no known WGSL layout.
   */
  computeUniformLayout(inputTypes) {
    const fields = [];
    let offset = 0;
    let structAlign = 1;
    this.params.forEach((node, index) => {
      if (!node.isUniform) return;
      const typeInfo = inputTypes[index];
      const layout = typeInfo && wgslTypeSizes[typeInfo.type];
      if (!layout) {
        throw new Error(`Unsupported or missing uniform input at argument ${index} (type: ${typeInfo ? typeInfo.type : 'missing'})`);
      }
      offset = Math.ceil(offset / layout.alignment) * layout.alignment;
      fields.push({ index, offset, type: typeInfo.type, layout });
      offset += layout.size;
      structAlign = Math.max(structAlign, layout.alignment);
    });
    const structSize = Math.ceil(offset / structAlign) * structAlign;
    return { fields, size: Math.ceil(structSize / 16) * 16 };
  }

  /**
   * Write one uniform value (scalar, vector or matrix) little-endian at `offset` in `dataView`.
   * Element types other than f32/i32/u32 (e.g. 16-bit types) are skipped, as before.
   */
  writeUniformValue(dataView, offset, type, layout, value) {
    const elementType = type.includes('<') ? type.slice(type.indexOf('<') + 1, type.lastIndexOf('>')) : type;
    const setter = { f32: 'setFloat32', i32: 'setInt32', u32: 'setUint32' }[elementType];
    if (!setter) return;
    const write = (byteOffset, v) => dataView[setter](byteOffset, v, true);
    if (type.startsWith('vec')) {
      const count = parseInt(type.slice(3), 10);
      for (let j = 0; j < count; j++) write(offset + j * 4, value[j]);
    } else if (type.startsWith('mat')) {
      const [, cols, rows] = type.match(/^mat(\d)x(\d)/).map(Number);
      const flat = this.flattenArray(value);
      // WGSL pads each matrix column to its vector alignment (matCx3 columns are 16 bytes apart).
      // A tightly packed cols*rows input is spread over the padded columns; anything else is
      // written sequentially as given.
      const padded = flat.length === cols * rows && layout.alignment !== rows * 4;
      flat.forEach((v, j) => {
        const byteOffset = padded
          ? offset + Math.floor(j / rows) * layout.alignment + (j % rows) * 4
          : offset + j * 4;
        write(byteOffset, v);
      });
    } else {
      write(offset, value);
    }
  }

  /** Create a mapped-at-creation storage buffer holding `data` (typed array or array of numbers). */
  createStorageBuffer(data) {
    const byteLength = data.byteLength ? data.byteLength : data.length * 4;
    const buffer = this.device.createBuffer({
      size: Math.ceil(byteLength / 4) * 4, // mappedAtCreation requires a multiple of 4
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
      mappedAtCreation: true,
    });
    const mapped = buffer.getMappedRange();
    if (data instanceof Int32Array || data instanceof Uint32Array) {
      // Copy raw bytes so integer data is not reinterpreted as floats.
      new Uint8Array(mapped).set(new Uint8Array(data.buffer, data.byteOffset, data.byteLength));
    } else {
      new Float32Array(mapped).set(data);
    }
    buffer.unmap();
    return buffer;
  }

  /**
   * Upload `inputs` (one per compute-function parameter, in order), run the pipeline and read back
   * every returned buffer. Only supported for pipelines created from a JS compute function.
   * @returns {Promise<Float32Array>|Promise<Float32Array[]>} One array, or an array of them when
   *   several buffers are returned.
   * @throws {Error} When a storage-buffer argument is missing or a uniform type is unsupported.
   */
  process(...inputs) {
    const inputTypes = inputs.map((input, idx) => this.getInputTypeInfo(input, idx));

    // Buffers are rebuilt on every call (sizes may change), so release the previous call's GPU
    // memory first; work that was already submitted is not affected by destroy().
    this.inputBuffers?.forEach((buffer) => buffer.destroy());
    this.inputBuffers = [];
    this.uniformBuffer = undefined;
    this.outputBuffers = [];

    let inputBufferIndex = 0;
    let uniformReturned = false;
    this.params.forEach((node, i) => {
      if (node.isUniform) {
        // Every uniform lives in one shared struct, so only the first one creates the buffer.
        if (!this.uniformBuffer) {
          this.uniformBuffer = this.device.createBuffer({
            size: this.computeUniformLayout(inputTypes).size,
            usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_SRC,
            mappedAtCreation: true,
          });
          this.inputBuffers.push(this.uniformBuffer);
          inputBufferIndex++;
        }
      } else {
        if (!inputs[i]) {
          throw new Error(`Missing Input at argument ${i}. Type: ${this.params[i].type}`);
        }
        if (!inputs[i].byteLength && Array.isArray(inputs[i][0])) inputs[i] = this.flattenArray(inputs[i]);
        this.inputBuffers.push(this.createStorageBuffer(inputs[i]));
        inputBufferIndex++;
      }
      if (node.isReturned) {
        if (!node.isUniform) {
          this.outputBuffers.push(this.inputBuffers[inputBufferIndex - 1]);
        } else if (!uniformReturned) {
          uniformReturned = true;
          this.outputBuffers.push(this.uniformBuffer);
        }
      }
    });

    if (this.uniformBuffer) {
      const dataView = new DataView(this.uniformBuffer.getMappedRange());
      for (const { index, offset, type, layout } of this.computeUniformLayout(inputTypes).fields) {
        this.writeUniformValue(dataView, offset, type, layout, inputs[index]);
      }
      this.uniformBuffer.unmap();
    }

    // Binding indices follow buffer creation order, matching generateDataStructures().
    this.bindGroup = this.device.createBindGroup({
      layout: this.bindGroupLayout,
      entries: this.inputBuffers.map((buffer, index) => ({ binding: index, resource: { buffer } })),
    });

    const commandEncoder = this.device.createCommandEncoder();
    const passEncoder = commandEncoder.beginComputePass();
    passEncoder.setPipeline(this.computePipeline);
    passEncoder.setBindGroup(0, this.bindGroup);
    // One invocation per element of the first input (all inputs are assumed to be the same size).
    passEncoder.dispatchWorkgroups(Math.ceil(inputs[0].length / (this.workgroupSize ?? 64)));
    passEncoder.end();

    // Returned buffers are copied into mappable staging buffers for readback.
    const stagingBuffers = this.outputBuffers.map((outputBuffer) => this.device.createBuffer({
      size: outputBuffer.size,
      usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
    }));
    this.outputBuffers.forEach((outputBuffer, index) => {
      commandEncoder.copyBufferToBuffer(outputBuffer, 0, stagingBuffers[index], 0, outputBuffer.size);
    });
    this.device.queue.submit([commandEncoder.finish()]);

    const promises = stagingBuffers.map((buffer) => buffer.mapAsync(GPUMapMode.READ).then(() => {
      const results = new Float32Array(buffer.getMappedRange()).slice(); // copy before unmapping
      buffer.unmap();
      buffer.destroy();
      return results;
    }));
    return promises.length === 1 ? promises[0] : Promise.all(promises);
  }

  /** Return the function's head up to and including the `{` that opens its body. */
  getFunctionHead = (methodString) => {
    let startindex = methodString.indexOf('=>') + 1;
    if (startindex <= 0) startindex = methodString.indexOf('){');
    if (startindex <= 0) startindex = methodString.indexOf(') {');
    return methodString.slice(0, methodString.indexOf('{', startindex) + 1);
  };

  /** Split `str` on commas that are not nested inside (), [] or {}. */
  splitIgnoringBrackets = (str) => {
    const result = [];
    let depth = 0;
    let currentToken = '';
    for (let i = 0; i < str.length; i++) {
      const char = str[i];
      if (char === ',' && depth === 0) {
        result.push(currentToken);
        currentToken = '';
      } else {
        currentToken += char;
        if (char === '(' || char === '[' || char === '{') {
          depth++;
        } else if (char === ')' || char === ']' || char === '}') {
          depth--;
        }
      }
    }
    if (currentToken) result.push(currentToken);
    return result;
  };

  /**
   * Tokenize a compute function: one token per parameter (isInput: true), followed by one per
   * `const|let|var name = value` declaration found anywhere in the source (isInput: false).
   */
  tokenize(funcStr) {
    const head = this.getFunctionHead(funcStr);
    const paramString = head.substring(head.indexOf('(') + 1, head.lastIndexOf(')'));
    const params = this.splitIgnoringBrackets(paramString).map((param) => ({ token: param, isInput: true }));
    const assignmentTokens = (funcStr.match(/(const|let|var)\s+(\w+)\s*=\s*([^;]+)/g) || []).map((token) => ({
      token,
      isInput: false,
    }));
    return params.concat(assignmentTokens);
  }

  /**
   * Variable names mentioned in non-comment `return ...;` lines of `source`; `return [a, b]`
   * yields both names.
   * @throws {Error} If `source` has no return statement.
   */
  extractReturnedVars(source) {
    const returnMatches = source.match(/^(?![ \t]*\/\/).*\breturn .*;/gm);
    const returnedVars = returnMatches
      ? returnMatches.map((match) => match.replace(/^[ \t]*return /, '').replace(';', ''))
      : undefined;
    return this.flattenStrings(returnedVars);
  }

  /**
   * Whether the variable declared by `token` is assigned (`x =`, `x[i] =`, `x += ...`) in
   * `functionBody`. Comparisons such as `x == y` or `x <= y` do not count.
   */
  isAssignedIn(token, functionBody) {
    const name = token.split('=')[0].trim().replace(/^(const|let|var)\s+/, '');
    if (!/^[A-Za-z_$][\w$]*$/.test(name)) return false;
    const escaped = name.replace(/\$/g, '\\$');
    const assignment = new RegExp(
      `(?<![\\w$.])${escaped}(?![\\w$])(?:\\[[^\\]]+\\])?\\s*(?:\\*\\*|<<|>>>?|[-+*/%&|^])?=(?![=>])`,
    );
    return assignment.test(functionBody);
  }

  /**
   * Turn tokens into AST nodes `{type, name, value, isInput, isReturned, isModified, isUniform?}`.
   * Array defaults become 'array' nodes; other `name = value` tokens become uniform 'variable'
   * nodes. Object parameters (e.g. {x, y, z}) are not supported yet.
   */
  parse = (tokens) => {
    const ast = [];
    const returnedVars = this.extractReturnedVars(this.fstr);
    const functionBody = this.fstr.substring(this.fstr.indexOf('{'));
    const pushVecMatNode = (token, isInput, isModified) => {
      const typeMatch = token.match(/(vec\d|mat\d+x\d+)\(([^)]+)\)/);
      if (!typeMatch) return;
      const name = token.split('=')[0];
      ast.push({ type: typeMatch[1], name, value: typeMatch[2], isInput, isReturned: returnedVars.includes(name), isModified });
    };
    tokens.forEach(({ token, isInput }) => {
      const isReturned = returnedVars.some((v) => token.includes(v));
      const isModified = this.isAssignedIn(token, functionBody);
      if (token.includes('=')) {
        const [, , name, value] = token.match(/(const|let|var)?\s*(\w+)\s*=\s*(.+)/);
        if (value.startsWith('new Array') || value.startsWith('[')) {
          ast.push({ type: 'array', name, value, isInput, isReturned: returnedVars.includes(name), isModified });
        } else if (token.startsWith('vec') || token.startsWith('mat')) {
          pushVecMatNode(token, isInput, isModified);
        } else {
          ast.push({ type: 'variable', name, value, isUniform: true, isInput, isReturned: returnedVars.includes(name), isModified });
        }
      } else if (token.includes('new Array') || token.includes('[')) {
        // A parameter without a default whose text still looks like an array.
        ast.push({ type: 'array', name: token, value: token, isInput, isReturned, isModified });
      } else if (token.startsWith('vec') || token.startsWith('mat')) {
        pushVecMatNode(token, isInput, isModified);
      } else {
        // A parameter without a default value.
        ast.push({ type: 'variable', name: token, value: 'unknown', isUniform: true, isInput, isReturned, isModified });
      }
    });
    this.ast = ast;
    return ast;
  };

  /**
   * Infer a WGSL type from a JS default-value expression; unrecognised values fall back to 'f32'.
   * Values containing '.' are f32, bare integers i32, `[...]` literals `array<...>`, and names of
   * other AST nodes are resolved through those nodes.
   */
  inferTypeFromValue(value, funcStr, ast) {
    if (typeof value !== 'string') return 'f32';
    if (value.startsWith('vec')) {
      const type = value.includes('.') ? '<f32>' : '<i32>';
      return value.match(/vec(\d)/)[0] + type;
    }
    if (value.startsWith('mat')) {
      const type = value.includes('.') ? '<f32>' : '<i32>';
      return value.match(/mat(\d)x(\d)/)[0] + type;
    }
    if (value.startsWith('[')) {
      // Infer the element type from the first element of the literal.
      const firstElement = value.split(',')[0].substring(1);
      if (firstElement === ']') return 'array<f32>';
      if (firstElement.startsWith('[') && !firstElement.endsWith(']')) {
        return this.inferTypeFromValue(firstElement, funcStr, ast);
      }
      if (firstElement.startsWith('vec') || firstElement.startsWith('mat')) {
        return `array<${this.inferTypeFromValue(firstElement, funcStr, ast)}>`;
      }
      if (firstElement.includes('.')) return 'array<f32>';
      if (!Number.isNaN(Number(firstElement))) return 'array<i32>';
      return 'f32';
    }
    if (value.startsWith('new Array')) {
      // Look for an element assignment in the function body to infer the element type.
      const arrayNameMatch = value.match(/let\s+(\w+)\s*=/);
      if (!arrayNameMatch) return 'f32';
      const assignmentMatch = funcStr.match(new RegExp(`${arrayNameMatch[1]}\\[\\d+\\]\\s*=\\s*(.+?);`));
      return assignmentMatch ? this.inferTypeFromValue(assignmentMatch[1], funcStr, ast) : 'f32';
    }
    if (value.includes('.')) return 'f32';
    if (!Number.isNaN(Number(value))) return 'i32';
    const astNode = ast?.find((node) => node.name === value);
    if (astNode?.type === 'array') return 'f32'; // arrays are assumed to hold f32
    if (astNode?.type === 'variable') return this.inferTypeFromValue(astNode.value, funcStr, ast);
    return 'f32';
  }

  /**
   * Flatten returned expressions, splitting `[a, b]` literals into their trimmed members.
   * @throws {Error} If `arr` is missing (no return statement was found).
   */
  flattenStrings(arr) {
    if (!arr) {
      throw new Error('WebGPUjs: the compute function needs at least one return statement naming its outputs');
    }
    return arr.flatMap((item) => {
      const trimmed = item.trim();
      return trimmed.startsWith('[') && trimmed.endsWith(']')
        ? trimmed.slice(1, -1).split(',').map((s) => s.trim())
        : trimmed;
    });
  }

  /**
   * Fresh global regex matching `function name(params) { body }` up to the first `}`; parameter
   * lists may nest parentheses two levels deep. Groups: 1 name, 2 params, 3 body. Each
   * alternative consumes whole characters, so a failed match cannot backtrack exponentially.
   */
  nestedFunctionRegex() {
    return /function (\w+)\(((?:[^()]|\((?:[^()]|\([^()]*\))*\))*)\) \{([\s\S]*?)\}/g;
  }

  /**
   * Emit WGSL binding declarations: one struct + storage binding per array input and one shared
   * `UniformsStruct` binding for all other inputs. Also sets `this.params` (input nodes in binding
   * order) and `this.bindings`.
   */
  generateDataStructures(funcStr, ast) {
    let code = '//Bindings (data passed to/from CPU) \n';
    // Blank out nested function bodies so their `return`s are not taken as outputs. Only the outer
    // body is searched; the regex would otherwise match (and blank) the outer function itself.
    const bodyStart = this.fstr.indexOf('{');
    const outerBody = bodyStart >= 0 ? this.fstr.slice(bodyStart + 1) : this.fstr;
    const functionRegex = this.nestedFunctionRegex();
    let modifiedStr = outerBody;
    let match;
    while ((match = functionRegex.exec(outerBody)) !== null) {
      modifiedStr = modifiedStr.replace(match[3], 'PLACEHOLDER');
    }
    const returnedVars = this.extractReturnedVars(modifiedStr);

    let uniformsStruct = 'struct UniformsStruct {\n';
    let uniformBinding; // binding index of the shared uniform struct, once a uniform is seen
    let bindingIncr = 0;
    this.bindings = [];
    this.params = [];
    ast.forEach((node) => {
      // Returned variables that are not parameters still need a buffer.
      if (returnedVars.includes(node.name)) node.isInput = true;
      if (!node.isInput) return;
      if (node.type === 'array') {
        this.bindings.push(node.name);
        const elementType = this.inferTypeFromValue(node.value.split(',')[0], funcStr, ast);
        node.type = elementType;
        this.params.push(node);
        const structName = `${capitalizeFirstLetter(node.name)}Struct`;
        // Written buffers must be read_write for the shader to compile (and to match init()).
        const access = returnedVars.includes(node.name) || node.isModified ? 'read_write' : 'read';
        code += `struct ${structName} {\n values: ${elementType}\n};\n\n`;
        code += `@group(0) @binding(${bindingIncr})\n`;
        code += `var<storage, ${access}> ${node.name}: ${structName};\n\n`;
        bindingIncr++;
      } else if (node.isUniform) {
        if (uniformBinding === undefined) {
          uniformBinding = bindingIncr;
          bindingIncr++;
        }
        this.bindings.push(node.name);
        const uniformType = this.inferTypeFromValue(node.value, funcStr, ast);
        node.type = uniformType;
        this.params.push(node);
        uniformsStruct += ` ${node.name}: ${uniformType},\n`;
      }
    });
    uniformsStruct += '};\n\n';
    if (uniformBinding !== undefined) {
      code += uniformsStruct;
      code += `@group(0) @binding(${uniformBinding}) var<uniform> uniforms: UniformsStruct;\n\n`;
    }
    return code;
  }

  /**
   * Transpile `function name(a = default, ...) { ... }` definitions inside `body` into WGSL `fn`s.
   * Parameter types come from their defaults (f32 when absent); the return type is assumed to be
   * the first parameter's type.
   * @param {boolean} [extract=true] - Also remove the definitions from the returned body.
   * @returns {{body: string, extractedFunctions: string}}
   */
  extractAndTransposeInnerFunctions = (body, extract = true) => {
    const functionRegex = this.nestedFunctionRegex();
    let match;
    let extractedFunctions = '';
    while ((match = functionRegex.exec(body)) !== null) {
      const [, funcName, paramString, funcBody] = match;
      let outputParam;
      const params = this.splitIgnoringBrackets(paramString).map((p) => {
        const eq = p.indexOf('=');
        const vname = (eq === -1 ? p : p.slice(0, eq)).trim();
        const inferredType = eq === -1 ? 'f32' : this.inferTypeFromValue(p.slice(eq + 1).trim(), body, this.ast);
        if (!outputParam) outputParam = inferredType;
        return `${vname}: ${inferredType}`;
      });
      const transposedBody = this.transposeBody(funcBody, funcBody, null, true);
      extractedFunctions += `fn ${funcName}(${params.join(',')}) -> ${outputParam ?? 'f32'} {\n${transposedBody}\n}\n\n`;
    }
    if (extract) body = body.replace(this.nestedFunctionRegex(), '');
    return { body, extractedFunctions };
  };

  /**
   * Emit helper `fn`s (from addFunction() and nested definitions) followed by the `main` compute
   * entry point. Records `size` as this.workgroupSize so process() dispatches matching workgroups.
   */
  generateMainFunctionWorkGroup(funcStr, ast, size = 256) {
    this.workgroupSize = size;
    let code = '//Main function call\n//globalId tells us what x,y,z thread we are on\n';
    this.functions?.forEach((f) => {
      const { extractedFunctions } = this.extractAndTransposeInnerFunctions(f.toString(), false);
      if (extractedFunctions) code += extractedFunctions;
    });
    const { body: mainBody, extractedFunctions } = this.extractAndTransposeInnerFunctions(funcStr.match(/{([\s\S]+)}/)[1], true);
    code += extractedFunctions;
    code += `@compute @workgroup_size(${size})\n`;
    code += 'fn main(\n @builtin(global_invocation_id) globalId: vec3<u32>';
    code += '\n) {\n';
    code += this.transposeBody(mainBody, funcStr, ast);
    code += '}\n';
    return code;
  }

  /**
   * Rewrite a JS function body into WGSL with regex passes (loops, declarations, array access,
   * Math.* calls, vec/mat constructors, uniform names), then patch up semicolons.
   * @param {boolean} [returns=false] - Keep `return` statements; otherwise they are commented out.
   * @returns {string} The transpiled body (also stored as this.mainBody).
   */
  transposeBody = (body, funcStr, ast, returns = false) => {
    // Stash comments behind placeholders so the rewrites below leave them untouched.
    const commentPlaceholders = {};
    let placeholderIndex = 0;
    const uncommented = body.replace(/\/\/.*$/gm, (comment) => {
      const placeholder = `__COMMENT_PLACEHOLDER_${placeholderIndex}__`;
      commentPlaceholders[placeholder] = comment;
      placeholderIndex++;
      return placeholder;
    });
    // `for (let i = a; i < b; i++)` becomes a WGSL loop with an unsigned counter.
    let code = uncommented.replace(/for \(let (\w+) = (\w+); \1 < (\w+); \1\+\+\)/g, 'for (var $1 = $2u; $1 < $3; $1 = $1 + 1u)');
    code = code.replace(/const (\w+) = (\w+)\.length;/g, 'let $1 = arrayLength(&$2.values);');
    code = code.replace(/const (\w+) = globalId\.(\w+);/g, 'let $1 = globalId.$2;');
    code = code.replace(/\bconst\b/g, 'let');
    // Declarations initialised with a vec/mat constructor become mutable `var`s.
    code = code.replace(/let (\w+) = (vec\d+|mat\d+)/g, 'var $1 = $2');
    // Array indexing and length go through the binding struct's `values` member.
    code = code.replace(/(\w+)\[([\w\s+\-*\/]+)\]/g, '$1.values[$2]');
    code = code.replace(/(\w+)\.length/g, 'arrayLength(&$1.values)');
    code = replaceJSFunctions(code, replacements);
    // Annotate vecN(...)/matN(...) constructors: f32 if any argument contains '.', else i32.
    code = code.replace(/(vec(\d+)|mat(\d+))\(([^)]+)\)/g, (match, type, vecSize, matSize, args) => {
      const argArray = args.split(',');
      const inferredType = argArray.some((arg) => arg.includes('.')) ? 'f32' : 'i32';
      return `${type}<${inferredType}>(${argArray.join(', ')})`;
    });
    this.params.forEach((param) => {
      if (!param.isUniform) return;
      const regex = new RegExp(`(?<![a-zA-Z0-9])${param.name}(?![a-zA-Z0-9])`, 'g');
      code = code.replace(regex, `uniforms.${param.name}`);
    });
    for (const [placeholder, comment] of Object.entries(commentPlaceholders)) {
      code = code.replace(placeholder, () => comment); // function form keeps `$` in comments literal
    }
    // Heuristic: append `;` to lines that do not end in ; { [ ( , > } (before any trailing comment).
    code = code.replace(/^(.*[^;\s\{\[\(\,\>\}])(\s*\/\/.*)$/gm, '$1;$2');
    code = code.replace(/^(.*[^;\s\{\[\(\,\>\}])(?!\s*\/\/)(?=\s*$)/gm, '$1;');
    // Undo the heuristic where it misfires: after comments and between closing parentheses.
    code = code.replace(/(\/\/[^\n]*);/gm, '$1');
    code = code.replace(/\);\s*(\n\s*)\)/gm, ')$1)');
    if (!returns) code = code.replace(/(return [^;]+;)/g, '//$1');
    this.mainBody = code;
    return code;
  };

  /** Register a helper function for the shader, regenerate it and rebuild the pipeline. */
  addFunction = (func) => {
    if (!this.functions) this.functions = [];
    this.functions.push(func);
    const result = this.convertToWebGPU();
    return this.init(result.shader, result.ast);
  };

  /**
   * Transpile a compute function (or its source text) into WGSL.
   * @returns {{shader: string, ast: object[]}}
   */
  convertToWebGPU(func = this.fstr) {
    const funcStr = typeof func === 'string' ? func : func.toString();
    this.fstr = funcStr;
    const tokens = this.tokenize(funcStr);
    const ast = this.parse(tokens);
    let webGPUCode = this.generateDataStructures(funcStr, ast);
    webGPUCode += '\n' + this.generateMainFunctionWorkGroup(funcStr, ast);
    return { shader: webGPUCode, ast };
  }
}
/**
 * Upper-case the first UTF-16 code unit of `string` (used to derive WGSL struct names).
 * @param {string} string
 * @returns {string} The input with its first character upper-cased; '' stays ''.
 */
function capitalizeFirstLetter(string) {
  const head = string.charAt(0).toUpperCase();
  return `${head}${string.substring(1)}`;
}
/**
 * Replace JavaScript API references (e.g. `Math.cos`) in transpiled code with shader equivalents.
 * Keys are matched literally and only as whole identifiers, so `Math.E` does not match inside
 * `Math.EPSILON` and `Math.abs` does not match inside `myMath.abs`.
 * @param {string} code - Source text to rewrite.
 * @param {Object<string, string>} replacements - Map of JS text to shader text, applied in order.
 * @returns {string} The rewritten code.
 */
function replaceJSFunctions(code, replacements) {
  let result = code;
  for (const [jsFunc, shaderFunc] of Object.entries(replacements)) {
    const escaped = jsFunc.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
    const regex = new RegExp(`(?<![\\w$.])${escaped}(?![\\w$])`, 'g');
    // A replacer function keeps any `$` sequences in shaderFunc literal.
    result = result.replace(regex, () => shaderFunc);
  }
  return result;
}
// JavaScript `Math` members and the WGSL text that replaces them in transpiled function bodies.
// Constants are inlined as numeric literals; each function maps onto the WGSL builtin of the same
// name. Semantics differ in places: WGSL round() rounds halfway cases to even, and WGSL max/min
// take exactly two arguments. WGSL atan2 uses the same (y, x) argument order as Math.atan2.
const replacements = Object.fromEntries([
  ['Math.PI', `${Math.PI}`],
  ['Math.E', `${Math.E}`],
  ...[
    'abs', 'acos', 'asin', 'atan', 'atan2', 'ceil', 'cos', 'exp', 'floor',
    'log', 'max', 'min', 'pow', 'round', 'sin', 'sqrt', 'tan',
  ].map((name) => [`Math.${name}`, name]),
]);
/**
 * Build a WGSL layout table ({ alignment, size } in bytes, keyed by type name) for one scalar width.
 * Scalars align to their own size. vec2 aligns to 2 scalars, vec3 and vec4 to 4; a vecN occupies
 * N scalars. matCxR is C columns of vecR, each column padded to vecR's alignment.
 * @param {string[]} scalarNames - Scalar type names, in table order.
 * @param {string[]} elementNames - Element types used for the vec/mat entries, in table order.
 * @param {number} scalarBytes - Byte size of one scalar (4 or 2).
 * @returns {Object<string, {alignment: number, size: number}>}
 */
function buildWgslLayoutTable(scalarNames, elementNames, scalarBytes) {
  const table = {};
  for (const name of scalarNames) {
    table[name] = { alignment: scalarBytes, size: scalarBytes };
  }
  const vectorAlignment = (n) => (n === 2 ? 2 : 4) * scalarBytes;
  for (let n = 2; n <= 4; n++) {
    for (const element of elementNames) {
      table[`vec${n}<${element}>`] = { alignment: vectorAlignment(n), size: n * scalarBytes };
    }
  }
  for (let rows = 2; rows <= 4; rows++) {
    for (let cols = 2; cols <= 4; cols++) {
      for (const element of elementNames) {
        const columnAlignment = vectorAlignment(rows);
        table[`mat${cols}x${rows}<${element}>`] = { alignment: columnAlignment, size: cols * columnAlignment };
      }
    }
  }
  return table;
}
// 32-bit scalar, vector and matrix layouts.
const wgslTypeSizes32 = buildWgslLayoutTable(['i32', 'u32', 'f32', 'atomic'], ['f32', 'i32', 'u32'], 4);
// 16-bit counterparts.
const wgslTypeSizes16 = buildWgslLayoutTable(['i16', 'u16', 'f16'], ['f16', 'i16', 'u16'], 2);
// Combined lookup used by the runner.
const wgslTypeSizes = Object.assign({}, wgslTypeSizes16, wgslTypeSizes32);
// Example compute function written in the JavaScript-like dialect that WebGPUjs transpiles to WGSL.
// It is never called from JavaScript: only its source text (dft.toString()) is parsed, so vec2,
// mat2x2, f32 and globalId refer to WGSL builtins / the generated `main` parameter, and comments
// inside the body are copied into the generated shader.
// Each invocation k = globalId.x accumulates sum += (x[n]*cos(phase), -x[n]*sin(phase)) with
// phase = 2*PI*k*n/N over all N inputs, and writes the pair to outputData[2k] and outputData[2k+1].
// Parameters: `[]` defaults become storage buffers; outp3..outp6 are dummy inputs exercising
// uniform fields (outp3..outp5) and an array-of-vec2 buffer (outp6).
// `return [inputData, outputData]` marks those two buffers as read_write outputs read back by process().
function dft(
inputData = [],
outputData = [],
//dummy inputs
outp3 = mat2x2(vec2(1.0,1.0),vec2(1.0,1.0)),
outp4 = 4,
outp5 = vec3(1,2,3),
outp6 = [vec2(1.0,1.0)]
) {
function add(a=vec2(0.0,0.0),b=vec2(0.0,0.0)) { //transpiled out of main body
return a + b;
}
const N = inputData.length;
const k = globalId.x;
var sum = vec2(0.0, 0.0);
var sum2 = add(sum,sum);
const b = 3 + outp4;
var M = mat4x4(
vec4(1.0,0.0,0.0,0.0),
vec4(0.0,1.0,0.0,0.0),
vec4(0.0,0.0,1.0,0.0),
vec4(0.0,0.0,0.0,1.0)
); //identity matrix
let D = M + M;
var Z = outp3 * mat2x2(vec2(4.0,-1.0),vec2(3.0,2.0));
var Zz = outp5 + vec3(4,5,6);
for (let n = 0; n < N; n++) {
const phase = 2.0 * Math.PI * f32(k) * f32(n) / f32(N);
sum = sum + vec2(
inputData[n] * Math.cos(phase),
-inputData[n] * Math.sin(phase)
);
}
let v = 2
const outputIndex = k * 2 //use strict
if (outputIndex + 1 < outputData.length) {
outputData[outputIndex] = sum.x;
outputData[outputIndex + 1] = sum.y;
}
return [inputData, outputData]; //returning an array of inputs lets us return several buffer promises
//return outputData;
//return outp4; //we can also return the uniform buffer though it is immutable so it's pointless
}
//explicit return statements will define only that variable as an output (i.e. a mutable read_write buffer)
const parser = new WebGPUjs();
const webGPUCode = parser.convertToWebGPU(dft);
//console.log(webGPUCode);
document.body.style.backgroundColor = 'black';
document.body.style.color = 'white';
// Build the two editor panes, then fill them through `.value`: the JS source and the generated
// WGSL contain `<`, `&` and similar characters that must not go through the HTML parser.
document.body.insertAdjacentHTML('afterbegin', `
<span style="position:absolute; left:0px;">
Before (edit me!):<br>
<textarea id="t2" style="width:50vw; background-color:#303000; color:lightblue; height:100vh;"></textarea>
</span>
<span style="position:absolute; left:50vw;">
After:<br>
<textarea id="t1" style="width:50vw; background-color:#000020; color:lightblue; height:100vh;"></textarea>
</span>
`);
document.getElementById('t2').value = dft.toString();
document.getElementById('t1').value = webGPUCode.shader;

/**
 * Re-transpile the edited source into the right-hand pane. Half-typed edits usually fail to
 * parse, so the error is shown in the pane instead of being thrown on every keystroke.
 */
function parseFunction() {
  const fstr = document.getElementById('t2').value;
  try {
    const webGPUCode = parser.convertToWebGPU(fstr);
    document.getElementById('t1').value = webGPUCode.shader;
  } catch (err) {
    document.getElementById('t1').value = `// Could not transpile: ${err.message}`;
  }
}
document.getElementById('t2').oninput = () => {
  parseFunction();
};

/** Compile dft on the GPU, run it on sample data, then add a helper function and show the new shader. */
async function runDemo() {
  const pipeline = await WebGPUjs.createPipeline(dft);
  // Create some sample input data
  const len = 256;
  const inputData = new Float32Array(len).fill(1.0); // Example data
  const outputData = new Float32Array(len * 2).fill(0);
  // Run the process method to execute the shader
  const result = await pipeline.process(inputData, outputData, [1.0, 2.0, 3.0, 4.0], 1, [1, 2, 3], [1, 2]);
  console.log(result); // Log the output
  if (result[2]?.buffer) { //we returned the uniform buffer for some reason, double check alignments
    console.log(
      new DataView(result[2].buffer).getInt32(16, true) // the int32 is still correctly encoded
    );
  }
  await pipeline.addFunction(function mul(a=vec2(2.0),b=vec2(2.0)) { return a * b; });
  document.getElementById('t1').value = pipeline.shader;
}
runDemo().catch((err) => console.error('WebGPUjs demo failed:', err));
// Hand-written WGSL version of the DFT kernel, kept for comparison with the
// shader that convertToWebGPU(dft) generates. Each invocation handles output
// bin k = globalId.x: it accumulates inputData[n] * (cos(phase), -sin(phase))
// over all n, then writes the two components of the sum to
// outputData[2k] and outputData[2k + 1] (bounds-checked against the array length).
// NOTE(review): nothing in this part of the file reads dftReference; it appears
// to be documentation only — confirm before removing or relying on it.
const dftReference = `
struct InputData {
values : array<f32>
}
struct OutputData {
values: array<f32>
}
@group(0) @binding(0)
var<storage, read> inputData: InputData;
@group(0) @binding(1)
var<storage, read_write> outputData: OutputData;
@compute @workgroup_size(256)
fn main(
@builtin(global_invocation_id) globalId: vec3<u32>
) {
let N = arrayLength(&inputData.values);
let k = globalId.x;
var sum = vec2<f32>(0.0, 0.0);
for (var n = 0u; n < N; n = n + 1u) {
let phase = 2.0 * 3.14159265359 * f32(k) * f32(n) / f32(N);
sum = sum + vec2<f32>(
inputData.values[n] * cos(phase),
-inputData.values[n] * sin(phase)
);
}
let outputIndex = k * 2;
if (outputIndex + 1 < arrayLength(&outputData.values)) {
outputData.values[outputIndex] = sum.x;
outputData.values[outputIndex + 1] = sum.y;
}
}
`
</script>
<script>
</script>
</body>
</html> |
Here's a repo. Compute shaders work great; I'm just getting rendering and pipeline chaining working now (I could never get that to work in GPUjs). It's pretty much a stock renderer; it just requires the special JavaScript-ish shader function input format.
@joshbrew Yes, you have far, far more control over WebGPU than WebGL. WebGPU is fully async, so you can keep feeding it data and execution instructions, keeping it fully occupied at all times.
@jacobbogers I've made a ton of progress on my webgpujs project. There is another project much closer to GPUjs: taichi.js (https://github.com/AmesingFlank/taichi.js). Otherwise, I'm stuck getting textures to render correctly before I can finally polish up my library (I'd pay someone at this point).
Just use normal buffers for compute shaders. Why are you using textures?
That taichi.js stuff looks crazy good.
I have an entire process to transpile compute, vertex, and fragment shaders, plus a WIP method to combine bindings as you chain shaders, so I can create as many shaders as I want that share data structures. That includes textures, especially storage textures, once I solve the basic texture issue I'm having. This isn't really documented yet, and won't be until I fix the bugs and finalize the workflow.
Thanks for brewing this thought. Any update on this?
I've made a lot of headway on my library recently. I haven't really kept tabs on the others, since I do this in fits to get better with WebGPU in general. What's left for me now is getting storage textures and a few extended things to make sense, plus the multi-compute/render chains that combine bind groups automatically (the transpiler portion works, just not the actual buffering part, because it's written too stupidly, lol). Then, of course, there's polish on all the utilities and function/pipeline handles, as none of it looks professional yet. After that I need to fix some of the annoying gotchas from stringifying functions, caused by oversights in my regex black magic, but that will take me some more time. I'm happier than I expected with what is emerging. I'm working on some examples, like noise textures, that will make its utility more apparent. I think this is a truly viable option for a lot of complex tasks, but I'm pushing myself to actually build that stuff first, to round out the framework and make it fully functional and digestible. What's working right now is here: (the link was not preserved in this copy). Oh yeah, the minified library is under 100 kB too, which is pretty much only thanks to template strings.
Okay, fellas, we now have compute support natively on the web, and outputting kernels is possible, e.g. see this tutorial: https://jott.live/markdown/webgpu_safari
Has anyone started this discussion, or where do we want to start it? I mean the discussion of a WebGPU-based spiritual successor to GPUjs. GPUjs's JavaScript-based kernel generation is genius for learning GPU coding and helped me immensely. The performance is fine too, and it's not really that complicated under the hood, since it's just transposing text for you. The WebGPU compute pipeline adds some boilerplate for setting up command/storage buffers and data structures, though from what I've seen that could definitely be macro'd.
Another thing to think about is chaining compute and fragment shaders. Compute can handle rasterization, possibly faster, while the fragment shader simply dumps the resulting image matrix to the screen. E.g. https://github.com/OmarShehata/webgpu-compute-rasterizer/blob/main/how-to-build-a-compute-rasterizer.md
The text was updated successfully, but these errors were encountered: