Skip to content

Commit

Permalink
Hints multiarray (#1)
Browse files Browse the repository at this point in the history
* Adding support to expressions hints multiarray

* Minor fixes

* Fixing hints strings

* Adding logger improvements

* Importing logger.hpp in utils

* Moving some logs to debug mode

* Fix typo

* Fixing calculate_xdivxsub method

* Modify Makefile for bctree

* align console output to pil2-proofman output

* align console output to pil2-proofman output

* Fixing joinrecursive2

---------

Co-authored-by: Xavier Pinsach <[email protected]>
  • Loading branch information
RogerTaule and xavi-pinsach authored Oct 2, 2024
1 parent 863eca5 commit 6d03bca
Show file tree
Hide file tree
Showing 25 changed files with 405 additions and 277 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ build
runtime/*
!runtime/README.md
MyLogFile.log
lib/libstarks.a
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ INC_DIRS := $(shell find $(SRC_DIRS) -type d) $(sort $(dir))
INC_FLAGS := $(addprefix -I,$(INC_DIRS))


SRCS_STARKS_LIB := $(shell find ./src/api/starks_api.* ./src/goldilocks/src ./src/config ./src/starkpil ./src/rapidsnark/binfile_utils.* ./src/ffiasm ./src/utils -name *.cpp -or -name *.c -or -name *.asm -or -name *.cc)
SRCS_STARKS_LIB := $(shell find ./src/api/starks_api.* ./src/goldilocks/src ./src/config ./src/starkpil ./src/poseidon_opt ./src/rapidsnark/binfile_utils.* ./src/rapidsnark/logger.* ./src/ffiasm ./src/utils -name *.cpp -or -name *.c -or -name *.asm -or -name *.cc)
OBJS_STARKS_LIB := $(SRCS_STARKS_LIB:%=$(BUILD_DIR)/%.o)
DEPS_STARKS_LIB := $(OBJS_STARKS_LIB:.o=.d)

SRCS_BCT := $(shell find ./src/bctree/build_const_tree.cpp ./src/bctree/main.cpp ./src/goldilocks/src ./src/starkpil/merkleTree/merkleTreeBN128.cpp ./src/starkpil/merkleTree/merkleTreeGL.cpp ./src/poseidon_opt/poseidon_opt.cpp ./src/ffiasm ./src/utils/* -name *.cpp -or -name *.c -or -name *.asm -or -name *.cc)
SRCS_BCT := $(shell find ./src/bctree/build_const_tree.cpp ./src/bctree/main.cpp ./src/goldilocks/src ./src/starkpil/merkleTree/merkleTreeBN128.cpp ./src/starkpil/merkleTree/merkleTreeGL.cpp ./src/rapidsnark/logger.* ./src/poseidon_opt/poseidon_opt.cpp ./src/ffiasm ./src/utils/* -name *.cpp -or -name *.c -or -name *.asm -or -name *.cc)
OBJS_BCT := $(SRCS_BCT:%=$(BUILD_DIR)/%.o)
DEPS_BCT := $(OBJS_BCT:.o=.d)

Expand Down
5 changes: 3 additions & 2 deletions lib/include/starks_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@
void *starks_new(void *pSetupCtx);
void starks_free(void *pStarks);

void extend_and_merkelize(void *pStarks, uint64_t step, void *buffer, void *proof, void *pBuffHelper);
void treesGL_get_root(void *pStarks, uint64_t index, void *root);

void *calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub);
void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub);
void *get_fri_pol(void *pSetupCtx, void *buffer);

void calculate_fri_polynomial(void *pStarks, void* buffer, void* public_inputs, void* challenges, void* subproofValues, void* evals, void *xDivXSub);
Expand Down Expand Up @@ -102,4 +101,6 @@
void *join_zkin_recursive2(char* globalInfoFile, void* pPublics, void* pChallenges, void *zkin1, void *zkin2, void *starkInfoRecursive2);
void *join_zkin_final(void* pPublics, void* pChallenges, char* globalInfoFile, void **zkinRecursive2, void **starkInfoRecursive2);

// Util calls
void setLogLevel(uint64_t level);
#endif
Binary file modified lib/libstarks.a
Binary file not shown.
39 changes: 31 additions & 8 deletions src/api/starks_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
#include "hints.hpp"
#include "global_constraints.hpp"
#include "gen_recursive_proof.hpp"
#include "logger.hpp"
#include <filesystem>

#include <nlohmann/json.hpp>
using json = nlohmann::json;
using ordered_json = nlohmann::ordered_json;

using namespace CPlusPlusLogging;

void save_challenges(void *pChallenges, char* globalInfoFile, char *fileDir) {

json globalInfo;
Expand Down Expand Up @@ -210,8 +213,8 @@ void expressions_bin_free(void *pExpressionsBin)
// ========================================================================================
void *get_hint_field(void *pSetupCtx, void* buffer, void* public_inputs, void* challenges, void* subproofValues, void* evals, uint64_t hintId, char *hintFieldName, bool dest, bool inverse, bool printExpression)
{
HintFieldInfo hintFieldInfo = getHintField(*(SetupCtx *)pSetupCtx, (Goldilocks::Element *)buffer, (Goldilocks::Element *)public_inputs, (Goldilocks::Element *)challenges, (Goldilocks::Element *)subproofValues, (Goldilocks::Element *)evals, hintId, string(hintFieldName), dest, inverse, printExpression);
return new HintFieldInfo(hintFieldInfo);
HintFieldValues hintFieldValues = getHintField(*(SetupCtx *)pSetupCtx, (Goldilocks::Element *)buffer, (Goldilocks::Element *)public_inputs, (Goldilocks::Element *)challenges, (Goldilocks::Element *)subproofValues, (Goldilocks::Element *)evals, hintId, string(hintFieldName), dest, inverse, printExpression);
return new HintFieldValues(hintFieldValues);
}

uint64_t set_hint_field(void *pSetupCtx, void* buffer, void* subproofValues, void *values, uint64_t hintId, char * hintFieldName)
Expand All @@ -233,12 +236,6 @@ void starks_free(void *pStarks)
delete starks;
}

void extend_and_merkelize(void *pStarks, uint64_t step, void *buffer, void *pProof, void *pBuffHelper)
{
auto starks = (Starks<Goldilocks::Element> *)pStarks;
starks->ffi_extend_and_merkelize(step, (Goldilocks::Element *)buffer, (FRIProof<Goldilocks::Element> *)pProof, (Goldilocks::Element *)pBuffHelper);
}

void treesGL_get_root(void *pStarks, uint64_t index, void *dst)
{
Starks<Goldilocks::Element> *starks = (Starks<Goldilocks::Element> *)pStarks;
Expand Down Expand Up @@ -468,4 +465,30 @@ void *join_zkin_final(void* pPublics, void* pChallenges, char* globalInfoFile, v
ordered_json zkinFinal = joinzkinfinal(globalInfo, publics, challenges, zkinRecursive2, starkInfoRecursive2);

return (void *) new nlohmann::ordered_json(zkinFinal);
}


void setLogLevel(uint64_t level) {
LogLevel new_level;
switch(level) {
case 0:
new_level = DISABLE_LOG;
break;
case 1:
case 2:
case 3:
new_level = LOG_LEVEL_INFO;
break;
case 4:
new_level = LOG_LEVEL_DEBUG;
break;
case 5:
new_level = LOG_LEVEL_TRACE;
break;
default:
cerr << "Invalid log level: " << level << endl;
return;
}

Logger::getInstance(LOG_TYPE::CONSOLE)->updateLogLevel((LOG_LEVEL)new_level);
}
5 changes: 3 additions & 2 deletions src/api/starks_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@
void *starks_new(void *pSetupCtx);
void starks_free(void *pStarks);

void extend_and_merkelize(void *pStarks, uint64_t step, void *buffer, void *proof, void *pBuffHelper);
void treesGL_get_root(void *pStarks, uint64_t index, void *root);

void *calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub);
void calculate_xdivxsub(void *pStarks, void* xiChallenge, void *xDivXSub);
void *get_fri_pol(void *pSetupCtx, void *buffer);

void calculate_fri_polynomial(void *pStarks, void* buffer, void* public_inputs, void* challenges, void* subproofValues, void* evals, void *xDivXSub);
Expand Down Expand Up @@ -102,4 +101,6 @@
void *join_zkin_recursive2(char* globalInfoFile, void* pPublics, void* pChallenges, void *zkin1, void *zkin2, void *starkInfoRecursive2);
void *join_zkin_final(void* pPublics, void* pChallenges, char* globalInfoFile, void **zkinRecursive2, void **starkInfoRecursive2);

// Util calls
void setLogLevel(uint64_t level);
#endif
27 changes: 15 additions & 12 deletions src/rapidsnark/logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ Logger* Logger::m_Instance = 0;
// Log file name. File name should be change from here only
const string logFileName = "MyLogFile.log";

Logger::Logger()
Logger::Logger(int log_type)
{
m_File.open(logFileName.c_str(), ios::out|ios::app);
m_LogLevel = LOG_LEVEL_TRACE;
m_LogType = FILE_LOG;
if (log_type == FILE_LOG)
{
m_File.open(logFileName.c_str(), ios::out | ios::app);
}

m_LogLevel = LOG_LEVEL_TRACE;
m_LogType = (LOG_TYPE)log_type;

// Initialize mutex
#ifdef WIN32
Expand Down Expand Up @@ -84,11 +88,11 @@ Logger::~Logger()
#endif
}

Logger* Logger::getInstance() throw ()
Logger *Logger::getInstance(int log_type) throw()
{
if (m_Instance == 0)
if (m_Instance == 0)
{
m_Instance = new Logger();
m_Instance = new Logger(log_type);
}
return m_Instance;
}
Expand Down Expand Up @@ -120,7 +124,7 @@ void Logger::logIntoFile(std::string& data)

void Logger::logOnConsole(std::string& data)
{
cout << getCurrentTime() << " " << data << endl;
cout << data << endl;
}

string Logger::getCurrentTime()
Expand Down Expand Up @@ -255,7 +259,7 @@ void Logger::buffer(std::ostringstream& stream) throw()
void Logger::info(const char* text) throw()
{
string data;
data.append("[INFO]: ");
data.append("[INFO] PilStark: ");
data.append(text);

if((m_LogType == FILE_LOG) && (m_LogLevel >= LOG_LEVEL_INFO))
Expand Down Expand Up @@ -283,7 +287,7 @@ void Logger::info(std::ostringstream& stream) throw()
void Logger::trace(const char* text) throw()
{
string data;
data.append("[TRACE]: ");
data.append("[TRACE] PilStark: ");
data.append(text);

if((m_LogType == FILE_LOG) && (m_LogLevel >= LOG_LEVEL_TRACE))
Expand Down Expand Up @@ -311,7 +315,7 @@ void Logger::trace(std::ostringstream& stream) throw()
void Logger::debug(const char* text) throw()
{
string data;
data.append("[DEBUG]: ");
data.append("[DEBUG] PilStark: ");
data.append(text);

if((m_LogType == FILE_LOG) && (m_LogLevel >= LOG_LEVEL_DEBUG))
Expand Down Expand Up @@ -368,4 +372,3 @@ void Logger::enableFileLogging()
{
m_LogType = FILE_LOG ;
}

8 changes: 4 additions & 4 deletions src/rapidsnark/logger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ namespace CPlusPlusLogging
DISABLE_LOG = 1,
LOG_LEVEL_INFO = 2,
LOG_LEVEL_BUFFER = 3,
LOG_LEVEL_TRACE = 4,
LOG_LEVEL_DEBUG = 5,
LOG_LEVEL_DEBUG = 4,
LOG_LEVEL_TRACE = 5,
ENABLE_LOG = 6,
}LogLevel;

Expand All @@ -77,7 +77,7 @@ namespace CPlusPlusLogging
class Logger
{
public:
static Logger* getInstance() throw ();
static Logger* getInstance(int log_type = LOG_TYPE::FILE_LOG) throw ();

// Interface for Error Log
void error(const char* text) throw();
Expand Down Expand Up @@ -128,7 +128,7 @@ namespace CPlusPlusLogging
void enableFileLogging();

protected:
Logger();
Logger(int log_type);
~Logger();

// Wrapper function for lock/unlock
Expand Down
4 changes: 2 additions & 2 deletions src/starkpil/const_pols.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class ConstPols
uint64_t constTreeSizeBytes = getConstTreeSize();

pConstTreeAddress = (Goldilocks::Element *)loadFileParallel(constTreeFile, constTreeSizeBytes);
zklog.info("Starks::Starks() successfully copied " + to_string(constTreeSizeBytes) + " bytes from constant file " + constTreeFile);
zklog.debug(" Starks::Starks() successfully copied " + to_string(constTreeSizeBytes) + " bytes from constant file " + constTreeFile);

pConstPolsAddressExtended = &pConstTreeAddress[2];
TimerStopAndLog(LOAD_CONST_TREE_TO_MEMORY);
Expand All @@ -114,7 +114,7 @@ class ConstPols
uint64_t constPolsSize = starkInfo.nConstants * sizeof(Goldilocks::Element) * N;

pConstPolsAddress = (Goldilocks::Element *)loadFileParallel(constPolsFile, constPolsSize);
zklog.info("Starks::Starks() successfully copied " + to_string(constPolsSize) + " bytes from constant file " + constPolsFile);
zklog.debug(" Starks::Starks() successfully copied " + to_string(constPolsSize) + " bytes from constant file " + constPolsFile);

TimerStopAndLog(LOAD_CONST_POLS_TO_MEMORY);
}
Expand Down
34 changes: 23 additions & 11 deletions src/starkpil/expressions_bin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,19 +303,31 @@ void ExpressionsBin::loadExpressionsBin(BinFileUtils::BinFile *expressionsBin) {
for(uint64_t f = 0; f < nFields; f++) {
HintField hintField;
std::string name = expressionsBin->readString();
std::string operand = expressionsBin->readString();
hintField.name = name;
hintField.operand = string2opType(operand);
if(hintField.operand == opType::number) {
hintField.value = expressionsBin->readU64LE();
} else if(hintField.operand == opType::string_) {
hintField.stringValue = expressionsBin->readString();
} else {
hintField.id = expressionsBin->readU32LE();
}
if(hintField.operand == opType::tmp) {
hintField.dim = expressionsBin->readU32LE();

uint64_t nValues = expressionsBin->readU32LE();
for(uint64_t v = 0; v < nValues; v++) {
HintFieldValue hintFieldValue;
std::string operand = expressionsBin->readString();
hintFieldValue.operand = string2opType(operand);
if(hintFieldValue.operand == opType::number) {
hintFieldValue.value = expressionsBin->readU64LE();
} else if(hintFieldValue.operand == opType::string_) {
hintFieldValue.stringValue = expressionsBin->readString();
} else {
hintFieldValue.id = expressionsBin->readU32LE();
}
if(hintFieldValue.operand == opType::tmp) {
hintFieldValue.dim = expressionsBin->readU32LE();
}
uint64_t nPos = expressionsBin->readU32LE();
for(uint64_t p = 0; p < nPos; ++p) {
uint32_t pos = expressionsBin->readU32LE();
hintFieldValue.pos.push_back(pos);
}
hintField.values.push_back(hintFieldValue);
}

hint.fields.push_back(hintField);
}

Expand Down
9 changes: 7 additions & 2 deletions src/starkpil/expressions_bin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@ const int BINARY_EXPRESSIONS_SECTION = 3;
const int BINARY_CONSTRAINTS_SECTION = 4;
const int BINARY_HINTS_SECTION = 5;

struct HintField {
string name;
struct HintFieldValue {
opType operand;
uint64_t id;
uint64_t dim;
uint64_t value;
string stringValue;
std::vector<uint64_t> pos;
};

struct HintField {
string name;
std::vector<HintFieldValue> values;
};


Expand Down
22 changes: 12 additions & 10 deletions src/starkpil/gen_recursive_proof.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi
treesGL[setupCtx.starkInfo.nStages + 1]->getRoot(verkey);
starks.addTranscript(transcript, &verkey[0], nFieldElements);

if(!setupCtx.starkInfo.starkStruct.hashCommits) {
starks.addTranscriptGL(transcript, &publicInputs[0], setupCtx.starkInfo.nPublics);
} else {
ElementType hash[nFieldElements];
starks.calculateHash(hash, &publicInputs[0], setupCtx.starkInfo.nPublics);
starks.addTranscript(transcript, hash, nFieldElements);
if(setupCtx.starkInfo.nPublics > 0) {
if(!setupCtx.starkInfo.starkStruct.hashCommits) {
starks.addTranscriptGL(transcript, &publicInputs[0], setupCtx.starkInfo.nPublics);
} else {
ElementType hash[nFieldElements];
starks.calculateHash(hash, &publicInputs[0], setupCtx.starkInfo.nPublics);
starks.addTranscript(transcript, hash, nFieldElements);
}
}

TimerStopAndLog(STARK_STEP_0);
Expand Down Expand Up @@ -114,8 +116,8 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi
return hintField.name == "reference";
});

expressionsCtx.calculateExpression(params, den, denField->id, true);
expressionsCtx.calculateExpression(params, num, numField->id);
expressionsCtx.calculateExpression(params, den, denField->values[0].id, true);
expressionsCtx.calculateExpression(params, num, numField->values[0].id);


Goldilocks3::copy((Goldilocks3::Element *)&gprod[0], &Goldilocks3::one());
Expand All @@ -126,7 +128,7 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi
}

Polinomial gprodTransposedPol;
setupCtx.starkInfo.getPolynomial(gprodTransposedPol, pAddress, true, gprodField->id, false);
setupCtx.starkInfo.getPolynomial(gprodTransposedPol, pAddress, true, gprodField->values[0].id, false);
#pragma omp parallel for
for(uint64_t j = 0; j < N; ++j) {
std::memcpy(gprodTransposedPol[j], &gprod[j*FIELD_EXTENSION], FIELD_EXTENSION * sizeof(Goldilocks::Element));
Expand All @@ -136,7 +138,7 @@ void *genRecursiveProof(SetupCtx& setupCtx, Goldilocks::Element *pAddress, Goldi
delete den;
delete gprod;

commitsCalculated[gprodField->id] = true;
commitsCalculated[gprodField->values[0].id] = true;

for(uint64_t i = 0; i < setupCtx.starkInfo.cmPolsMap.size(); i++) {
if(setupCtx.starkInfo.cmPolsMap[i].stage == 2 && !setupCtx.starkInfo.cmPolsMap[i].imPol && !commitsCalculated[i]) {
Expand Down
Loading

0 comments on commit 6d03bca

Please sign in to comment.