From 421eb9b9c6c9dcc695cb82d13325a62b2fb79b31 Mon Sep 17 00:00:00 2001 From: Ofek Shaked Date: Wed, 25 Dec 2024 17:34:41 +0200 Subject: [PATCH] feat(ksymbols): reimplement ksymbols This implementation stores all symbols, or if a `requiredDataSymbolsOnly` flag is used when creating the symbol table, only non-data symbols are saved (and required data symbols must be registered before updating). This new implementation uses a generic symbol table implementation that is responsible for managing symbol lookups, and can be used by future code for managing executable file symbols. --- pkg/ebpf/hooked_syscall_table.go | 4 +- pkg/ebpf/ksymbols.go | 4 +- pkg/ebpf/probes/trace.go | 20 +- pkg/ebpf/processor_funcs.go | 14 +- pkg/ebpf/tracee.go | 18 +- pkg/events/derive/hooked_seq_ops.go | 5 +- pkg/events/derive/hooked_syscall.go | 4 +- pkg/utils/environment/kernel_symbols.go | 393 +++++++------------ pkg/utils/environment/kernel_symbols_test.go | 228 +++++------ pkg/utils/symbol_table.go | 215 ++++++++++ pkg/utils/symbol_table_test.go | 336 ++++++++++++++++ pkg/utils/utils.go | 20 - tests/e2e-inst-signatures/e2e-set_fs_pwd.go | 11 +- 13 files changed, 865 insertions(+), 407 deletions(-) create mode 100644 pkg/utils/symbol_table.go create mode 100644 pkg/utils/symbol_table_test.go diff --git a/pkg/ebpf/hooked_syscall_table.go b/pkg/ebpf/hooked_syscall_table.go index a7382dd677ad..68eabfb419c5 100644 --- a/pkg/ebpf/hooked_syscall_table.go +++ b/pkg/ebpf/hooked_syscall_table.go @@ -170,7 +170,7 @@ func (t *Tracee) populateExpectedSyscallTableArray(tableMap *bpf.BPFMap) error { return e } } - niSyscallAddress := niSyscallSymbol[0].Address + niSyscallAddress := niSyscallSymbol[0].Address() for i, kernelRestrictionArr := range events.SyscallSymbolNames { syscallName := t.getSyscallNameByKerVer(kernelRestrictionArr) @@ -199,7 +199,7 @@ func (t *Tracee) populateExpectedSyscallTableArray(tableMap *bpf.BPFMap) error { continue } - var expectedAddress = kernelSymbol[0].Address + var 
expectedAddress = kernelSymbol[0].Address() err = tableMap.Update(unsafe.Pointer(&index), unsafe.Pointer(&expectedAddress)) if err != nil { return err diff --git a/pkg/ebpf/ksymbols.go b/pkg/ebpf/ksymbols.go index 64fa239a696a..bc518144e5cd 100644 --- a/pkg/ebpf/ksymbols.go +++ b/pkg/ebpf/ksymbols.go @@ -54,8 +54,8 @@ func (t *Tracee) UpdateKallsyms() error { // ... and update the eBPF map with the symbol address. for _, sym := range symbol { key := make([]byte, maxKsymNameLen) - copy(key, sym.Name) - addr := sym.Address + copy(key, sym.Name()) + addr := sym.Address() // Update the eBPF map with the symbol address. err := bpfKsymsMap.Update( diff --git a/pkg/ebpf/probes/trace.go b/pkg/ebpf/probes/trace.go index ef67ac7351fd..cf36420a3f44 100644 --- a/pkg/ebpf/probes/trace.go +++ b/pkg/ebpf/probes/trace.go @@ -111,7 +111,7 @@ func (p *TraceProbe) attach(module *bpf.Module, args ...interface{}) error { var err error var link *bpf.BPFLink var attachFunc func(uint64) (*bpf.BPFLink, error) - var syms []environment.KernelSymbol + var syms []*environment.KernelSymbol // https://github.com/aquasecurity/tracee/issues/3653#issuecomment-1832642225 // // After commit b022f0c7e404 ('tracing/kprobes: Return EADDRNOTAVAIL @@ -141,10 +141,10 @@ func (p *TraceProbe) attach(module *bpf.Module, args ...interface{}) error { goto rollback } if p.probeType == SyscallEnter { - link, err = prog.AttachKprobe(syms[0].Name) + link, err = prog.AttachKprobe(syms[0].Name()) } if p.probeType == SyscallExit { - link, err = prog.AttachKretprobe(syms[0].Name) + link, err = prog.AttachKretprobe(syms[0].Name()) } if err != nil { goto rollback @@ -155,10 +155,10 @@ func (p *TraceProbe) attach(module *bpf.Module, args ...interface{}) error { symsCompat, _ := ksyms.GetSymbolByName(SyscallPrefixCompat + p.eventName) if len(symsCompat) > 0 { if p.probeType == SyscallEnter { - link, _ = prog.AttachKprobe(symsCompat[0].Name) + link, _ = prog.AttachKprobe(symsCompat[0].Name()) } if p.probeType == SyscallExit 
{ - link, _ = prog.AttachKretprobe(symsCompat[0].Name) + link, _ = prog.AttachKretprobe(symsCompat[0].Name()) } p.bpfLink = append(p.bpfLink, link) } @@ -166,10 +166,10 @@ func (p *TraceProbe) attach(module *bpf.Module, args ...interface{}) error { symsCompat, _ = ksyms.GetSymbolByName(SyscallPrefixCompat2 + p.eventName) if len(symsCompat) > 0 { if p.probeType == SyscallEnter { - link, _ = prog.AttachKprobe(symsCompat[0].Name) + link, _ = prog.AttachKprobe(symsCompat[0].Name()) } if p.probeType == SyscallExit { - link, _ = prog.AttachKretprobe(symsCompat[0].Name) + link, _ = prog.AttachKretprobe(symsCompat[0].Name()) } p.bpfLink = append(p.bpfLink, link) } @@ -188,9 +188,9 @@ func (p *TraceProbe) attach(module *bpf.Module, args ...interface{}) error { case 1: // single address, attach kprobe using symbol name switch p.probeType { case KProbe: - link, err = prog.AttachKprobe(syms[0].Name) + link, err = prog.AttachKprobe(syms[0].Name()) case KretProbe: - link, err = prog.AttachKretprobe(syms[0].Name) + link, err = prog.AttachKretprobe(syms[0].Name()) } if err != nil { goto rollback @@ -204,7 +204,7 @@ func (p *TraceProbe) attach(module *bpf.Module, args ...interface{}) error { attachFunc = prog.AttachKretprobeOnOffset } for _, sym := range syms { - link, err := attachFunc(sym.Address) + link, err := attachFunc(sym.Address()) if err != nil { goto rollback } diff --git a/pkg/ebpf/processor_funcs.go b/pkg/ebpf/processor_funcs.go index 253b16fd5116..d79cfdc52dc5 100644 --- a/pkg/ebpf/processor_funcs.go +++ b/pkg/ebpf/processor_funcs.go @@ -233,7 +233,7 @@ func (t *Tracee) processDoInitModule(event *trace.Event) error { err := capabilities.GetInstance().EBPF( func() error { - err := t.kernelSymbols.Refresh() + err := t.kernelSymbols.Update() if err != nil { return errfmt.WrapError(err) } @@ -281,8 +281,8 @@ func (t *Tracee) processHookedProcFops(event *trace.Event) error { if addr == 0 { // address is in text segment, marked as 0 continue } - hookingFunction := 
utils.ParseSymbol(addr, t.kernelSymbols) - if hookingFunction.Owner == "system" { + hookingFunction := t.kernelSymbols.GetPotentiallyHiddenSymbolByAddr(addr)[0] + if hookingFunction.Owner() == "system" { continue } functionName := "unknown" @@ -292,7 +292,7 @@ func (t *Tracee) processHookedProcFops(event *trace.Event) error { case Iterate: functionName = "iterate" } - hookedFops = append(hookedFops, trace.HookedSymbolData{SymbolName: functionName, ModuleOwner: hookingFunction.Owner}) + hookedFops = append(hookedFops, trace.HookedSymbolData{SymbolName: functionName, ModuleOwner: hookingFunction.Owner()}) } err = events.SetArgValue(event, hookedFopsPointersArgName, hookedFops) if err != nil { @@ -326,7 +326,7 @@ func (t *Tracee) processPrintMemDump(event *trace.Event) error { } addressUint64 := uint64(address) - symbol := utils.ParseSymbol(addressUint64, t.kernelSymbols) + symbol := t.kernelSymbols.GetPotentiallyHiddenSymbolByAddr(addressUint64)[0] var utsName unix.Utsname arch := "" if err := unix.Uname(&utsName); err != nil { @@ -337,11 +337,11 @@ func (t *Tracee) processPrintMemDump(event *trace.Event) error { if err != nil { return err } - err = events.SetArgValue(event, "symbol_name", symbol.Name) + err = events.SetArgValue(event, "symbol_name", symbol.Name()) if err != nil { return err } - err = events.SetArgValue(event, "symbol_owner", symbol.Owner) + err = events.SetArgValue(event, "symbol_owner", symbol.Owner()) if err != nil { return err } diff --git a/pkg/ebpf/tracee.go b/pkg/ebpf/tracee.go index 7167d6b13ef9..97ceb60b3a97 100644 --- a/pkg/ebpf/tracee.go +++ b/pkg/ebpf/tracee.go @@ -362,12 +362,16 @@ func (t *Tracee) Init(ctx gocontext.Context) error { err = capabilities.GetInstance().Specific( func() error { - t.kernelSymbols, err = environment.NewKernelSymbolTable( - environment.WithRequiredSymbols(t.requiredKsyms), - ) + t.kernelSymbols = environment.NewKernelSymbolTable(true, true) + // t.requiredKsyms may contain non-data symbols, but it doesn't 
affect the validity of this call + t.kernelSymbols.AddRequiredDataSymbols(t.requiredKsyms) + err := t.kernelSymbols.Update() + if err != nil { + return err + } // Cleanup memory in list t.requiredKsyms = []string{} - return err + return nil }, cap.SYSLOG, ) @@ -920,7 +924,7 @@ func getUnavailbaleKsymbols(ksymbols []events.KSymbol, kernelSymbols *environmen continue } for _, s := range sym { - if s.Address == 0 { + if s.Address() == 0 { // Same if the symbol is found but its address is 0. unavailableSymbols = append(unavailableSymbols, ksymbol) } @@ -1726,7 +1730,7 @@ func (t *Tracee) triggerSeqOpsIntegrityCheck(event trace.Event) { if err != nil { continue } - seqOpsPointers[i] = seqOpsStruct[0].Address + seqOpsPointers[i] = seqOpsStruct[0].Address() } eventHandle := t.triggerContexts.Store(event) _ = t.triggerSeqOpsIntegrityCheckCall( @@ -1852,7 +1856,7 @@ func (t *Tracee) triggerMemDump(event trace.Event) []error { } } eventHandle := t.triggerContexts.Store(event) - _ = t.triggerMemDumpCall(symbol[0].Address, length, uint64(eventHandle)) + _ = t.triggerMemDumpCall(symbol[0].Address(), length, uint64(eventHandle)) } } } diff --git a/pkg/events/derive/hooked_seq_ops.go b/pkg/events/derive/hooked_seq_ops.go index 08af6ba5fe5f..29e6b3aefa51 100644 --- a/pkg/events/derive/hooked_seq_ops.go +++ b/pkg/events/derive/hooked_seq_ops.go @@ -4,7 +4,6 @@ import ( "github.com/aquasecurity/tracee/pkg/errfmt" "github.com/aquasecurity/tracee/pkg/events" "github.com/aquasecurity/tracee/pkg/events/parse" - "github.com/aquasecurity/tracee/pkg/utils" "github.com/aquasecurity/tracee/pkg/utils/environment" "github.com/aquasecurity/tracee/types/trace" ) @@ -43,11 +42,11 @@ func deriveHookedSeqOpsArgs(kernelSymbols *environment.KernelSymbolTable) derive if addr == 0 { continue } - hookingFunction := utils.ParseSymbol(addr, kernelSymbols) + hookingFunction := kernelSymbols.GetPotentiallyHiddenSymbolByAddr(addr)[0] seqOpsStruct := NetSeqOps[i/4] seqOpsFunc := NetSeqOpsFuncs[i%4] 
hookedSeqOps[seqOpsStruct+"_"+seqOpsFunc] = - trace.HookedSymbolData{SymbolName: hookingFunction.Name, ModuleOwner: hookingFunction.Owner} + trace.HookedSymbolData{SymbolName: hookingFunction.Name(), ModuleOwner: hookingFunction.Owner()} } return []interface{}{hookedSeqOps}, nil } diff --git a/pkg/events/derive/hooked_syscall.go b/pkg/events/derive/hooked_syscall.go index 9160e274f346..c32b1161eaac 100644 --- a/pkg/events/derive/hooked_syscall.go +++ b/pkg/events/derive/hooked_syscall.go @@ -54,8 +54,8 @@ func deriveDetectHookedSyscallArgs(kernelSymbols *environment.KernelSymbolTable) hookedOwner := "" hookedFuncSymbol, err := kernelSymbols.GetSymbolByAddr(address) if err == nil { - hookedFuncName = hookedFuncSymbol[0].Name - hookedOwner = hookedFuncSymbol[0].Owner + hookedFuncName = hookedFuncSymbol[0].Name() + hookedOwner = hookedFuncSymbol[0].Owner() } syscallName := convertToSyscallName(syscallId) diff --git a/pkg/utils/environment/kernel_symbols.go b/pkg/utils/environment/kernel_symbols.go index cb2b119c99f4..4ae470d27355 100644 --- a/pkg/utils/environment/kernel_symbols.go +++ b/pkg/utils/environment/kernel_symbols.go @@ -2,303 +2,223 @@ package environment import ( "bufio" - "fmt" + "io" "os" "strconv" "strings" "sync" -) -const ( - kallsymsPath = "/proc/kallsyms" - chanBuffer = 112800 // TODO: check if we really need this buffer size + "github.com/aquasecurity/tracee/pkg/errfmt" + "github.com/aquasecurity/tracee/pkg/utils" ) -type KernelSymbol struct { - Name string - Type string - Address uint64 - Owner string -} -type nameAndOwner struct { - name string - owner string +// Kernel symbols do not have an associated size, so we define a sensible size +// limit to prevent unrelated symbols from being returned for an address lookup +const maxSymbolSize = 0x100000 + +var ownersMu sync.RWMutex +var symbolOwners = []string{ + "system", + "hidden", } -type addrAndOwner struct { - addr uint64 - owner string +var symbolOwnerToIdx = map[string]uint16{ + "system": 0, + 
"hidden": 1, } -// KernelSymbolTable manages kernel symbols with multiple maps for fast lookup. -type KernelSymbolTable struct { - symbols map[string][]*KernelSymbol - addrs map[uint64][]*KernelSymbol - symByName map[nameAndOwner][]*KernelSymbol - symByAddr map[addrAndOwner][]*KernelSymbol - requiredSyms map[string]struct{} - requiredAddrs map[uint64]struct{} - onlyRequired bool - updateLock sync.Mutex - updateWg sync.WaitGroup +type KernelSymbol struct { + name string + address uint64 + owner uint16 } -func symNotFoundErr(v interface{}) error { - return fmt.Errorf("symbol not found: %v", v) +func (ks KernelSymbol) Name() string { + return ks.name } -// NewKernelSymbolTable initializes a KernelSymbolTable with optional configuration functions. -func NewKernelSymbolTable(opts ...KSymbTableOption) (*KernelSymbolTable, error) { - k := &KernelSymbolTable{} - for _, opt := range opts { - if err := opt(k); err != nil { - return nil, err - } - } - - // Set onlyRequired to true if there are required symbols or addresses - k.onlyRequired = k.requiredAddrs != nil || k.requiredSyms != nil - - // Initialize maps if they are nil - if k.requiredSyms == nil { - k.requiredSyms = make(map[string]struct{}) - } - if k.requiredAddrs == nil { - k.requiredAddrs = make(map[uint64]struct{}) - } +func (ks KernelSymbol) Address() uint64 { + return ks.address +} - return k, k.Refresh() +func (ks KernelSymbol) Contains(address uint64) bool { + return ks.address <= address && ks.address+maxSymbolSize > address } -// KSymbTableOption defines a function signature for configuration options. -type KSymbTableOption func(k *KernelSymbolTable) error +func (ks KernelSymbol) Owner() string { + ownersMu.RLock() + defer ownersMu.RUnlock() -// WithRequiredSymbols sets the required symbols for the KernelSymbolTable. 
-func WithRequiredSymbols(reqSyms []string) KSymbTableOption { - return func(k *KernelSymbolTable) error { - k.requiredSyms = sliceToValidationMap(reqSyms) - return nil - } + return symbolOwners[ks.owner] } -// WithRequiredAddresses sets the required addresses for the KernelSymbolTable. -func WithRequiredAddresses(reqAddrs []uint64) KSymbTableOption { - return func(k *KernelSymbolTable) error { - k.requiredAddrs = sliceToValidationMap(reqAddrs) - return nil - } +type KernelSymbolTable struct { + symbols *utils.SymbolTable[KernelSymbol] + requiredDataSymbolsOnly bool + mu sync.Mutex + requiredDataSymbols map[string]struct{} } -// TextSegmentContains returns true if the given address is in the kernel text segment. -func (k *KernelSymbolTable) TextSegmentContains(addr uint64) (bool, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() - - segStart, segEnd, err := k.getTextSegmentAddresses() - if err != nil { - return false, err +// Creates a new KernelSymbolTable. +// If lazyNameLookup is true, the mapping from name to symbol will be populated +// only when a failed lookup occurs. This reduces memory footprint at the cost +// of the time it takes to lookup a symbol name for the first time. +// If requiredDataSymbolsOnly is true, data symbols will be added only if they +// were explicitly selected using `AddRequiredDataSymbols()`. +func NewKernelSymbolTable(lazyNameLookup bool, requiredDataSymbolsOnly bool) *KernelSymbolTable { + return &KernelSymbolTable{ + symbols: utils.NewSymbolTable[KernelSymbol](lazyNameLookup), + requiredDataSymbolsOnly: requiredDataSymbolsOnly, + requiredDataSymbols: make(map[string]struct{}), } - - return addr >= segStart && addr < segEnd, nil } -// GetSymbolByName returns all the symbols with the given name. 
-func (k *KernelSymbolTable) GetSymbolByName(name string) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() - - if err := k.validateOrAddRequiredSym(name); err != nil { - return nil, err - } - - symbols, exist := k.symbols[name] - if !exist { - return nil, symNotFoundErr(name) +// Add a list of symbol names to the list of required data symbols. +// The next time `Update()` will be called, they will be added to the symbol +// table (if they exist). +// The symbol names don't have to be of data symbols. If other symbol types +// are present they will have no effect on the update logic. +func (kst *KernelSymbolTable) AddRequiredDataSymbols(symbolNames []string) { + if !kst.requiredDataSymbolsOnly { + return } - return copySliceOfPointersToSliceOfStructs(symbols), nil -} - -// GetSymbolByAddr returns all the symbols with the given address. -func (k *KernelSymbolTable) GetSymbolByAddr(addr uint64) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() - - if err := k.validateOrAddRequiredAddr(addr); err != nil { - return nil, err - } + kst.mu.Lock() + defer kst.mu.Unlock() - symbols, exist := k.addrs[addr] - if !exist { - return nil, symNotFoundErr(addr) + for _, name := range symbolNames { + kst.requiredDataSymbols[name] = struct{}{} } - - return copySliceOfPointersToSliceOfStructs(symbols), nil } -// GetSymbolByOwnerAndName returns all the symbols with the given owner and name. -func (k *KernelSymbolTable) GetSymbolByOwnerAndName(owner, name string) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() - - if err := k.validateOrAddRequiredSym(name); err != nil { - return nil, err - } - - symbols, exist := k.symByName[nameAndOwner{name, owner}] - if !exist { - return nil, symNotFoundErr(nameAndOwner{name, owner}) +// Read the contents of /proc/kallsyms and update the symbol table. 
+func (kst *KernelSymbolTable) Update() error { + file, err := os.Open("/proc/kallsyms") + if err != nil { + return errfmt.WrapError(err) } + defer func() { + _ = file.Close() + }() - return copySliceOfPointersToSliceOfStructs(symbols), nil + return kst.UpdateFromReader(file) } -// GetSymbolByOwnerAndAddr returns all the symbols with the given owner and address. -func (k *KernelSymbolTable) GetSymbolByOwnerAndAddr(owner string, addr uint64) ([]KernelSymbol, error) { - k.updateLock.Lock() - defer k.updateLock.Unlock() +// Read the contents of the given buffer and update the symbol table +func (kst *KernelSymbolTable) UpdateFromReader(reader io.Reader) error { + kst.symbols.Clear() + symbols := []*KernelSymbol{} - if err := k.validateOrAddRequiredAddr(addr); err != nil { - return nil, err - } + // Make sure we hold the required privileges by checking if we see actual addresses + seenRealAddress := false - symbols, exist := k.symByAddr[addrAndOwner{addr, owner}] - if !exist { - return nil, symNotFoundErr(addrAndOwner{addr, owner}) - } - - return copySliceOfPointersToSliceOfStructs(symbols), nil -} - -// getTextSegmentAddresses gets the start and end addresses of the kernel text segment. 
-func (k *KernelSymbolTable) getTextSegmentAddresses() (uint64, uint64, error) { - stext, exist1 := k.symByName[nameAndOwner{"_stext", "system"}] - etext, exist2 := k.symByName[nameAndOwner{"_etext", "system"}] + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) < 3 { + continue + } - if !exist1 || !exist2 { - return 0, 0, fmt.Errorf("kernel text segment symbol(s) not found") - } + symbolAddr, err := strconv.ParseUint(fields[0], 16, 64) + if err != nil { + continue + } + if symbolAddr != 0 { + seenRealAddress = true + } - textSegStart := stext[0].Address - textSegEnd := etext[0].Address + symbolType := fields[1] + symbolName := fields[2] - return textSegStart, textSegEnd, nil -} + symbolOwner := "system" + if len(fields) > 3 { + symbolOwner = fields[3] + symbolOwner = strings.TrimPrefix(symbolOwner, "[") + symbolOwner = strings.TrimSuffix(symbolOwner, "]") + } -// validateOrAddRequiredSym checks if the given symbol is in the required list and adds it if not. -func (k *KernelSymbolTable) validateOrAddRequiredSym(sym string) error { - return k.validateOrAddRequired(func() bool { - _, ok := k.requiredSyms[sym] - return ok - }, func() { - k.requiredSyms[sym] = struct{}{} - }) -} + // This is a data symbol, requiredDataSymbolsOnly is true, and this symbol isn't required + if strings.ContainsAny(symbolType, "DdBbRr") && kst.requiredDataSymbolsOnly { + if _, exists := kst.requiredDataSymbols[symbolName]; !exists { + continue + } + } -// validateOrAddRequiredAddr checks if the given address is in the required list and adds it if not. 
-func (k *KernelSymbolTable) validateOrAddRequiredAddr(addr uint64) error { - return k.validateOrAddRequired(func() bool { - _, ok := k.requiredAddrs[addr] - return ok - }, func() { - k.requiredAddrs[addr] = struct{}{} - }) -} + // Get index of symbol owner, or add it if it doesn't exist + ownersMu.RLock() + ownerIdx, found := symbolOwnerToIdx[symbolOwner] + ownersMu.RUnlock() + if !found { + ownersMu.Lock() + symbolOwners = append(symbolOwners, symbolOwner) + ownerIdx = uint16(len(symbolOwners) - 1) + symbolOwnerToIdx[symbolOwner] = ownerIdx + ownersMu.Unlock() + } -// validateOrAddRequired is a common function to check and add required symbols or addresses. -func (k *KernelSymbolTable) validateOrAddRequired(checkRequired func() bool, addRequired func()) error { - if !k.onlyRequired { - return nil + symbols = append(symbols, &KernelSymbol{ + name: symbolName, + address: symbolAddr, + owner: ownerIdx, + }) } - if !checkRequired() { - addRequired() - return k.refresh() + // We didn't hold the required privileges + if len(symbols) > 0 && !seenRealAddress { + return errfmt.Errorf("insufficient privileges when reading from /proc/kallsyms") } + // Update the symbol table + kst.symbols.AddSymbols(symbols) + return nil } -// Refresh is the exported method that acquires the lock and calls the internal refresh method. -func (k *KernelSymbolTable) Refresh() error { - k.updateLock.Lock() - defer k.updateLock.Unlock() - return k.refresh() +// GetSymbolByName returns all the symbols with the given name. +func (kst *KernelSymbolTable) GetSymbolByName(name string) ([]*KernelSymbol, error) { + symbols, err := kst.symbols.LookupByName(name) + return symbols, errfmt.WrapError(err) } -// refresh refreshes the KernelSymbolTable, reading the symbols from /proc/kallsyms. -func (k *KernelSymbolTable) refresh() error { - // Re-initialize the maps to include all new symbols. 
- k.symbols = make(map[string][]*KernelSymbol) - k.addrs = make(map[uint64][]*KernelSymbol) - k.symByName = make(map[nameAndOwner][]*KernelSymbol) - k.symByAddr = make(map[addrAndOwner][]*KernelSymbol) - - // Open the kallsyms file. - file, err := os.Open(kallsymsPath) +// GetSymbolByOwnerAndName returns all the symbols with the given owner and name. +func (kst *KernelSymbolTable) GetSymbolByOwnerAndName(owner, name string) ([]*KernelSymbol, error) { + symbols, err := kst.symbols.LookupByName(name) if err != nil { - return err + return nil, errfmt.WrapError(err) } - defer func() { - _ = file.Close() - }() - // Read the kallsyms file line by line and process each line. - scanner := bufio.NewScanner(file) - for scanner.Scan() { - fields := strings.Fields(scanner.Text()) - if len(fields) < 3 { - continue - } - sym := parseKallsymsLine(fields) - if sym == nil { - continue + // Find symbols that have the requested owner + filteredSymbols := []*KernelSymbol{} + for _, symbol := range symbols { + if symbolOwners[symbol.owner] == owner { + filteredSymbols = append(filteredSymbols, symbol) } - - if k.onlyRequired { - _, symRequired := k.requiredSyms[sym.Name] - _, addrRequired := k.requiredAddrs[sym.Address] - if !symRequired && !addrRequired { - continue - } - } - - k.symbols[sym.Name] = append(k.symbols[sym.Name], sym) - k.addrs[sym.Address] = append(k.addrs[sym.Address], sym) - k.symByName[nameAndOwner{sym.Name, sym.Owner}] = append(k.symByName[nameAndOwner{sym.Name, sym.Owner}], sym) - k.symByAddr[addrAndOwner{sym.Address, sym.Owner}] = append(k.symByAddr[addrAndOwner{sym.Address, sym.Owner}], sym) } - err = scanner.Err() - return err + return filteredSymbols, nil } -// parseKallsymsLine parses a line from /proc/kallsyms and returns a KernelSymbol. -func parseKallsymsLine(line []string) *KernelSymbol { - if len(line) < 3 { - return nil - } +// GetSymbolByAddr returns all the symbols with the given address. 
+func (kst *KernelSymbolTable) GetSymbolByAddr(addr uint64) ([]*KernelSymbol, error) { + symbols, err := kst.symbols.LookupByAddressExact(addr) + return symbols, errfmt.WrapError(err) +} - symbolAddr, err := strconv.ParseUint(line[0], 16, 64) +// GetPotentiallyHiddenSymbolByAddr returns all the symbols with the given address, +// or if none are found, a fake symbol with the "hidden" owner. +func (kst *KernelSymbolTable) GetPotentiallyHiddenSymbolByAddr(addr uint64) []*KernelSymbol { + symbols, err := kst.symbols.LookupByAddressExact(addr) if err != nil { - return nil + return []*KernelSymbol{{ + address: addr, + owner: symbolOwnerToIdx["hidden"], + }} } - symbolType := line[1] - symbolName := line[2] - - symbolOwner := "system" - if len(line) > 3 { - line[3] = strings.TrimPrefix(line[3], "[") - line[3] = strings.TrimSuffix(line[3], "]") - symbolOwner = line[3] - } + return symbols +} - return &KernelSymbol{ - Name: symbolName, - Type: symbolType, - Address: symbolAddr, - Owner: symbolOwner, - } +func (kst *KernelSymbolTable) ForEachSymbol(callback func(*KernelSymbol)) { + kst.symbols.ForEachSymbol(callback) } // copySliceOfPointersToSliceOfStructs converts a slice of pointers to a slice of structs. @@ -309,12 +229,3 @@ func copySliceOfPointersToSliceOfStructs(s []*KernelSymbol) []KernelSymbol { } return ret } - -// sliceToValidationMap converts a slice to a map for validation purposes. -func sliceToValidationMap[T comparable](items []T) map[T]struct{} { - res := make(map[T]struct{}) - for _, item := range items { - res[item] = struct{}{} - } - return res -} diff --git a/pkg/utils/environment/kernel_symbols_test.go b/pkg/utils/environment/kernel_symbols_test.go index 834b58026a1f..d1c2f217b58d 100644 --- a/pkg/utils/environment/kernel_symbols_test.go +++ b/pkg/utils/environment/kernel_symbols_test.go @@ -2,175 +2,179 @@ package environment import ( "reflect" + "strings" "testing" ) -// TestParseLine tests the parseKallsymsLine function. 
-func TestParseKallsymsLine(t *testing.T) { - testCases := []struct { - line []string - expected *KernelSymbol - }{ - {[]string{"00000000", "t", "my_symbol", "[my_owner]"}, &KernelSymbol{Name: "my_symbol", Type: "t", Address: 0, Owner: "my_owner"}}, - {[]string{"00000001", "T", "another_symbol"}, &KernelSymbol{Name: "another_symbol", Type: "T", Address: 1, Owner: "system"}}, - {[]string{"invalid_address", "T", "invalid_symbol"}, nil}, - {[]string{"00000002", "T"}, nil}, - } +type symbolInfo struct { + name string + address uint64 + owner string +} - for _, tc := range testCases { - result := parseKallsymsLine(tc.line) - if !reflect.DeepEqual(result, tc.expected) { - t.Errorf("parseKallsymsLine(%v) = %v; want %v", tc.line, result, tc.expected) - } +func symbolToSymbolInfo(symbol *KernelSymbol) *symbolInfo { + if symbol == nil { + return nil + } + return &symbolInfo{ + name: symbol.Name(), + address: symbol.Address(), + owner: symbol.Owner(), } } // TestNewKernelSymbolTable tests the NewKernelSymbolTable function. func TestNewKernelSymbolTable(t *testing.T) { - kst, err := NewKernelSymbolTable() - if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) - } + kst := NewKernelSymbolTable(true, true) if kst == nil { t.Fatalf("NewKernelSymbolTable() returned nil") } - // Check if the onlyRequired flag is set correctly - if kst.onlyRequired { - t.Errorf("onlyRequired flag should be false by default") - } - // Check if maps are initialized - if kst.symbols == nil || kst.addrs == nil || kst.symByName == nil || kst.symByAddr == nil { - t.Errorf("KernelSymbolTable maps are not initialized correctly") + if kst.symbols == nil || kst.requiredDataSymbols == nil { + t.Errorf("KernelSymbolTable is not initialized correctly") } } -// TestGetSymbolByName tests the GetSymbolByName function. 
-func TestGetSymbolByName(t *testing.T) { - kst, err := NewKernelSymbolTable() - if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) - } +func getTheOnlySymbol(t *testing.T, kst *KernelSymbolTable) *KernelSymbol { + i := 0 + var foundSymbol *KernelSymbol + kst.ForEachSymbol(func(symbol *KernelSymbol) { + i++ + foundSymbol = symbol + }) + if i > 1 { + t.Errorf("multiple symbols found") + } + return foundSymbol +} - kst.symbols["test_symbol"] = []*KernelSymbol{ - {Name: "test_symbol", Type: "t", Address: 0, Owner: "test_owner"}, +// TestUpdate tests the kallsyms parsing logic. +func TestUpdate(t *testing.T) { + testCases := []struct { + buf string + expected *symbolInfo + }{ + {"00000001 t my_symbol [my_owner]", &symbolInfo{name: "my_symbol", address: 1, owner: "my_owner"}}, + {"00000002 T another_symbol", &symbolInfo{name: "another_symbol", address: 2, owner: "system"}}, + {"invalid_address T invalid_symbol", nil}, + {"00000003 T", nil}, } - symbols, err := kst.GetSymbolByName("test_symbol") - if err != nil { - t.Fatalf("GetSymbolByName() failed: %v", err) + for _, tc := range testCases { + kst := NewKernelSymbolTable(false, false) + err := kst.UpdateFromReader(strings.NewReader(tc.buf)) + if err != nil { + t.Fatalf("Update() failed: %v", err) + } + symbol := getTheOnlySymbol(t, kst) + result := symbolToSymbolInfo(symbol) + if !reflect.DeepEqual(result, tc.expected) { + t.Errorf("Update(%v) = %v; want %v", tc.buf, result, tc.expected) + } } +} - if len(symbols) != 1 { - t.Errorf("Expected 1 symbol, got %d", len(symbols)) - } +// TestDuplicateUpdate tests 2 consecutive updates +func TestDuplicateUpdate(t *testing.T) { + buf1 := "00000001 t my_symbol" + buf2 := "00000002 t my_symbol" + expected := &symbolInfo{name: "my_symbol", address: 2, owner: "system"} - expectedSymbol := KernelSymbol{Name: "test_symbol", Type: "t", Address: 0, Owner: "test_owner"} - if !reflect.DeepEqual(symbols[0], expectedSymbol) { - t.Errorf("GetSymbolByName() = %v; want %v", 
symbols[0], expectedSymbol) - } -} + kst := NewKernelSymbolTable(false, false) + kst.UpdateFromReader(strings.NewReader(buf1)) + kst.UpdateFromReader(strings.NewReader(buf2)) -// TestGetSymbolByAddr tests the GetSymbolByAddr function. -func TestGetSymbolByAddr(t *testing.T) { - kst, err := NewKernelSymbolTable() - if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) + symbol := getTheOnlySymbol(t, kst) + result := symbolToSymbolInfo(symbol) + if !reflect.DeepEqual(result, expected) { + t.Errorf("Update(%v) = %v; want %v", buf2, result, expected) } +} - kst.addrs[0x1234] = []*KernelSymbol{ - {Name: "test_symbol", Type: "t", Address: 0x1234, Owner: "test_owner"}, - } +// TestGetSymbolByName tests the GetSymbolByName function. +func TestGetSymbolByName(t *testing.T) { + buf := "00000001 t test_symbol test_owner" + kst := NewKernelSymbolTable(false, false) + kst.UpdateFromReader(strings.NewReader(buf)) - symbols, err := kst.GetSymbolByAddr(0x1234) + symbols, err := kst.GetSymbolByName("test_symbol") if err != nil { - t.Fatalf("GetSymbolByAddr() failed: %v", err) + t.Fatalf("GetSymbolByName() failed: %v", err) } if len(symbols) != 1 { t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - expectedSymbol := KernelSymbol{Name: "test_symbol", Type: "t", Address: 0x1234, Owner: "test_owner"} - if !reflect.DeepEqual(symbols[0], expectedSymbol) { - t.Errorf("GetSymbolByAddr() = %v; want %v", symbols[0], expectedSymbol) + expected := &symbolInfo{name: "test_symbol", address: 1, owner: "test_owner"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByName() = %v; want %v", result, expected) } } -// TestRefresh tests the Refresh function. 
-func TestRefresh(t *testing.T) { - // Creating a mock KernelSymbolTable with required symbols to test Refresh - kst, err := NewKernelSymbolTable(WithRequiredSymbols([]string{"_stext", "_etext"})) +// TestGetSymbolByOwnerAndName tests the GetSymbolByOwnerAndName function. +func TestGetSymbolByOwnerAndName(t *testing.T) { + buf := `00000001 t test_symbol test_owner1 +00000002 t test_symbol test_owner2` + kst := NewKernelSymbolTable(false, false) + kst.UpdateFromReader(strings.NewReader(buf)) + + symbols, err := kst.GetSymbolByOwnerAndName("test_owner1", "test_symbol") if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) + t.Fatalf("GetSymbolByName() failed: %v", err) } - // Simulate the presence of these symbols - kst.symbols["_stext"] = []*KernelSymbol{{Name: "_stext", Type: "T", Address: 0x1000, Owner: "system"}} - kst.symbols["_etext"] = []*KernelSymbol{{Name: "_etext", Type: "T", Address: 0x2000, Owner: "system"}} - - // Call Refresh to update the symbol table - if err := kst.Refresh(); err != nil { - t.Fatalf("Refresh() failed: %v", err) + if len(symbols) != 1 { + t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - // Check if symbols were added correctly - symbolsToTest := []string{"_stext", "_etext"} - for _, symbol := range symbolsToTest { - if syms, err := kst.GetSymbolByName(symbol); err != nil || len(syms) == 0 { - t.Errorf("Expected to find symbol %s, but it was not found", symbol) - } + expected := &symbolInfo{name: "test_symbol", address: 1, owner: "test_owner1"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByOwnerAndName() = %v; want %v", result, expected) } } -// TestTextSegmentContains tests the TextSegmentContains function. -func TestTextSegmentContains(t *testing.T) { - // Creating a mock KernelSymbolTable with text segment addresses - kst, err := NewKernelSymbolTable() +// TestGetSymbolByAddr tests the GetSymbolByAddr function. 
+func TestGetSymbolByAddr(t *testing.T) { + buf := "00001234 t test_symbol test_owner" + kst := NewKernelSymbolTable(false, false) + kst.UpdateFromReader(strings.NewReader(buf)) + + symbols, err := kst.GetSymbolByAddr(0x1234) if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) + t.Fatalf("GetSymbolByAddr() failed: %v", err) } - kst.symByName[nameAndOwner{"_stext", "system"}] = []*KernelSymbol{{Name: "_stext", Type: "T", Address: 0x1000, Owner: "system"}} - kst.symByName[nameAndOwner{"_etext", "system"}] = []*KernelSymbol{{Name: "_etext", Type: "T", Address: 0x2000, Owner: "system"}} - - tests := []struct { - addr uint64 - expected bool - }{ - {0x1000, true}, - {0x1500, true}, - {0x2000, false}, - {0x0999, false}, + if len(symbols) != 1 { + t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - for _, tt := range tests { - result, err := kst.TextSegmentContains(tt.addr) - if err != nil { - t.Errorf("TextSegmentContains(%v) failed: %v", tt.addr, err) - } - if result != tt.expected { - t.Errorf("TextSegmentContains(%v) = %v; want %v", tt.addr, result, tt.expected) - } + expected := &symbolInfo{name: "test_symbol", address: 0x1234, owner: "test_owner"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByAddr() = %v; want %v", result, expected) } } -// Helper function to test required symbols or addresses. -func TestValidateOrAddRequired(t *testing.T) { - kst, err := NewKernelSymbolTable(WithRequiredSymbols([]string{"test_symbol"})) - if err != nil { - t.Fatalf("NewKernelSymbolTable() failed: %v", err) - } +// TestGetPotentiallyHiddenSymbolByAddr tests the GetPotentiallyHiddenSymbolByAddr function. 
+func TestGetPotentiallyHiddenSymbolByAddr(t *testing.T) { + buf := "00000001 t test_symbol test_owner" + kst := NewKernelSymbolTable(false, false) + kst.UpdateFromReader(strings.NewReader(buf)) - kst.requiredSyms["test_symbol"] = struct{}{} + symbols := kst.GetPotentiallyHiddenSymbolByAddr(2) - if err := kst.validateOrAddRequiredSym("test_symbol"); err != nil { - t.Errorf("validateOrAddRequiredSym() failed: %v", err) + if len(symbols) != 1 { + t.Errorf("Expected 1 symbol, got %d", len(symbols)) } - if err := kst.validateOrAddRequiredAddr(0x1234); err != nil { - t.Errorf("validateOrAddRequiredAddr() failed: %v", err) + expected := &symbolInfo{name: "", address: 2, owner: "hidden"} + result := symbolToSymbolInfo(symbols[0]) + if !reflect.DeepEqual(result, expected) { + t.Errorf("GetSymbolByAddr() = %v; want %v", result, expected) } } diff --git a/pkg/utils/symbol_table.go b/pkg/utils/symbol_table.go new file mode 100644 index 000000000000..8632c53dab60 --- /dev/null +++ b/pkg/utils/symbol_table.go @@ -0,0 +1,215 @@ +package utils + +import ( + "errors" + "sort" + "sync" +) + +// The Symbol interface defines what is needed from a symbol implementation in +// order to facilitate the lookup functionalities provided by SymbolTable. +// Implementations of Symbol can hold various types of information relevant to +// the type of symbol they represent. +type Symbol interface { + // Name returns the symbol's name + Name() string + // Address returns the base address of the symbol + Address() uint64 + // Contains returns whether a given address belongs to the symbol's + // address range, which is defined by the symbol's implementation + Contains(address uint64) bool +} + +// SymbolTable is used to hold information about symbols (mapping from symbolic +// names used in code to their address) in a certain executable. +// It can be used to hold symbols from an ELF binary, or symbols of the entire +// kernel and its modules. 
+// It provides functions to lookup symbols by address and name. +type SymbolTable[T Symbol] struct { + mu sync.RWMutex + // All symbols sorted by their address in descending order, + // for quick binary searches by address. + sortedSymbols []*T + // If lazyNameLookup is true, the symbolsByName map + // will be populated only when a failed lookup occurs. + symbolsByName map[string][]*T + lazyNameLookup bool +} + +var ErrSymbolNotFound = errors.New("symbol not found") + +// Creates a new SymbolTable. If lazyNameLookup is true, the mapping from +// name to symbol will be populated only when a failed lookup occurs. +// This reduces memory footprint at the cost of the time it takes to lookup +// a symbol name for the first time. +func NewSymbolTable[T Symbol](lazyNameLookup bool) *SymbolTable[T] { + return &SymbolTable[T]{ + sortedSymbols: make([]*T, 0), + symbolsByName: make(map[string][]*T), + lazyNameLookup: lazyNameLookup, + } +} + +// Adds a slice of symbols to the symbol table. +func (st *SymbolTable[T]) AddSymbols(symbols []*T) { + st.mu.Lock() + defer st.mu.Unlock() + + // Add the new symbols to the sorted slice (which now becomes unsorted). + // Allocate the slice with the needed capacity to avoid overallocation. + oldSymbols := st.sortedSymbols + newLen := len(oldSymbols) + len(symbols) + st.sortedSymbols = make([]*T, 0, newLen) + st.sortedSymbols = append(st.sortedSymbols, oldSymbols...) + st.sortedSymbols = append(st.sortedSymbols, symbols...) 
+ + // If lazyNameLookup is false, we update the name to symbol mapping for + // each new symbol + if !st.lazyNameLookup { + for _, symbol := range symbols { + name := (*symbol).Name() + if symbols, found := st.symbolsByName[name]; found { + st.symbolsByName[name] = append(symbols, symbol) + } else { + st.symbolsByName[name] = []*T{symbol} + } + } + } + + // Sort the symbols slice by address in descending order + sort.Slice(st.sortedSymbols, + func(i, j int) bool { + return (*st.sortedSymbols[i]).Address() > (*st.sortedSymbols[j]).Address() + }) +} + +// Lookup a symbol in the table by its name. +// Because there may be multiple symbols with the same name, a slice of all +// matching symbols is returned. +func (st *SymbolTable[T]) LookupByName(name string) ([]*T, error) { + st.mu.RLock() + // We call RUnlock manually and not using defer because we may need to upgrade to a write lock later + + // Lookup the name in the name to symbol mapping + if symbols, found := st.symbolsByName[name]; found { + st.mu.RUnlock() + return symbols, nil + } + + // Lazy name lookup is disabled, the lookup failed + if !st.lazyNameLookup { + st.mu.RUnlock() + return nil, ErrSymbolNotFound + } + + // Lazy name lookup is enabled, perform a linear search to find the requested name + symbols := []*T{} + for _, symbol := range st.sortedSymbols { + if (*symbol).Name() == name { + symbols = append(symbols, symbol) + } + } + + if len(symbols) > 0 { + // We found symbols with this name, update the mapping + st.mu.RUnlock() + st.mu.Lock() + defer st.mu.Unlock() + st.symbolsByName[name] = symbols + return symbols, nil + } + + st.mu.RUnlock() + return nil, ErrSymbolNotFound +} + +// Lookup a symbol in the table by its exact address. +// Because there may be multiple symbols at the same address, a slice of all +// matching symbols is returned. 
+func (st *SymbolTable[T]) LookupByAddressExact(address uint64) ([]*T, error) { + st.mu.RLock() + defer st.mu.RUnlock() + + // Find the first symbol at an address smaller than or equal to the requested address + idx := sort.Search(len(st.sortedSymbols), + func(i int) bool { + return address >= (*st.sortedSymbols[i]).Address() + }) + + // Not found or not exact match + if idx == len(st.sortedSymbols) || (*st.sortedSymbols[idx]).Address() != address { + return nil, ErrSymbolNotFound + } + + // The search result is the first symbol with the requested address, + // find any additional symbols with the same address. + syms := []*T{st.sortedSymbols[idx]} + for i := idx + 1; i < len(st.sortedSymbols); i++ { + if (*st.sortedSymbols[i]).Address() != address { + break + } + syms = append(syms, st.sortedSymbols[i]) + } + + return syms, nil +} + +// Find the symbol which contains the given address. +// If multiple symbols at different addresses contain the requested address, +// the symbol with the highest address will be returned. +// If multiple symbols at the same address contain the requested address, +// one of them will be returned, but there is no guarantee which one. +// This function assumes that symbols don't overlap in a way that a symbol with +// a smaller address contains the requested address while a symbol with a larger +// address (but still smaller that requested) doesn't contain it. +// For example, the following situation is assumed to be impossible: +// +// Requested Address +// | +// | +// +---------------+--+ +// |Symbol 1 | | +// +---------------+--+ +// +--------+ | +// |Symbol 2| | +// +--------+ v +// <----------------------> +// +// Smaller Larger +// Address Address +// +// If the above situation happens, no symbol will be returned. 
+func (st *SymbolTable[T]) LookupByAddressContains(address uint64) (*T, error) { + st.mu.RLock() + defer st.mu.RUnlock() + + // Find the first symbol at an address smaller than or equal to the requested address + idx := sort.Search(len(st.sortedSymbols), + func(i int) bool { + return address >= (*st.sortedSymbols[i]).Address() + }) + + // Not found or the symbol doesn't contain this address + if idx == len(st.sortedSymbols) || !(*st.sortedSymbols[idx]).Contains(address) { + return nil, ErrSymbolNotFound + } + + return st.sortedSymbols[idx], nil +} + +func (st *SymbolTable[T]) ForEachSymbol(callback func(symbol *T)) { + st.mu.RLock() + defer st.mu.RUnlock() + + for i := range len(st.sortedSymbols) { + callback(st.sortedSymbols[i]) + } +} + +func (st *SymbolTable[T]) Clear() { + st.mu.Lock() + defer st.mu.Unlock() + + st.sortedSymbols = make([]*T, 0) + clear(st.symbolsByName) +} diff --git a/pkg/utils/symbol_table_test.go b/pkg/utils/symbol_table_test.go new file mode 100644 index 000000000000..c3af183f320b --- /dev/null +++ b/pkg/utils/symbol_table_test.go @@ -0,0 +1,336 @@ +package utils + +import ( + "reflect" + "testing" +) + +type testSymbol struct { + name string + addr uint64 + size uint64 +} + +func (s testSymbol) Name() string { + return s.name +} + +func (s testSymbol) Address() uint64 { + return s.addr +} + +func (s testSymbol) Contains(address uint64) bool { + return s.addr <= address && s.addr+s.size > address +} + +// TestNewSymbolTable tests the NewSymbolTable function. 
+func TestNewSymbolTable(t *testing.T) { + st := NewSymbolTable[testSymbol](true) + if st == nil { + t.Fatalf("NewSymbolTable() returned nil") + } + + if !st.lazyNameLookup { + t.Errorf("lazyNameLookup was not set to true") + } + + if st.sortedSymbols == nil || st.symbolsByName == nil { + t.Errorf("data structures are nil") + } +} + +// TestAddSymbols tests the AddSymbols function +func TestAddSymbols(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + expectedOrder []int + }{ + {[]*testSymbol{ + {name: "symbol1", addr: 1, size: 1}, + {name: "symbol2", addr: 1, size: 1}, + }, []int{0, 1}}, + {[]*testSymbol{ + {name: "symbol1", addr: 2, size: 1}, + {name: "symbol2", addr: 1, size: 1}, + }, []int{0, 1}}, + {[]*testSymbol{ + {name: "symbol1", addr: 1, size: 1}, + {name: "symbol2", addr: 2, size: 1}, + }, []int{1, 0}}, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + + if len(st.sortedSymbols) != len(tc.symbols) { + t.Errorf("len(st.sortedSymbol) = %d, want %d", len(st.sortedSymbols), len(tc.symbols)) + continue + } + + for i := range st.sortedSymbols { + if !reflect.DeepEqual(*st.sortedSymbols[i], *tc.symbols[tc.expectedOrder[i]]) { + t.Errorf("AddSymbols(%v) = symbol %d: %v; want %v", tc.symbols, i, st.sortedSymbols[i], tc.symbols[tc.expectedOrder[i]]) + } + } + } +} + +// TestLookupByName tests the LookupByName function +func TestLookupByName(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lookupName string + expectLookupError bool + expected []testSymbol + }{ + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}}, + "symbol1", + false, + []testSymbol{{name: "symbol1", addr: 1, size: 1}}, + }, + { + []*testSymbol{}, + "symbol2", + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol3", addr: 1, size: 1}}, + "symbol4", + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol5", addr: 1, size: 1}, {name: "symbol6", addr: 2, size: 2}}, + "symbol6", + 
false, + []testSymbol{{name: "symbol6", addr: 2, size: 2}}, + }, + { + []*testSymbol{{name: "symbol7", addr: 1, size: 1}, {name: "symbol7", addr: 2, size: 2}}, + "symbol7", + false, + []testSymbol{{name: "symbol7", addr: 1, size: 1}, {name: "symbol7", addr: 2, size: 2}}, + }, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + result, err := st.LookupByName(tc.lookupName) + if !tc.expectLookupError && err != nil { + t.Errorf("LookupByName(%s) failed: %v", tc.lookupName, err) + continue + } else if tc.expectLookupError { + if err == nil { + t.Errorf("LookupByName(%s) expected to fail but didn't", tc.lookupName) + } + continue + } + if !reflect.DeepEqual(copySliceOfPointersToSliceOfStructs(result), tc.expected) { + t.Errorf("LookupByName(%s) = %v, expected %v", tc.lookupName, copySliceOfPointersToSliceOfStructs(result), tc.expected) + } + } +} + +// TestLazyNameLookup tests the lazy name lookup functionality +func TestLazyNameLookup(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lazyNameLookup bool + lookups []string + expectedMappings []int + }{ + { + []*testSymbol{{name: "symbol", addr: 1, size: 1}}, + false, + []string{}, + []int{0}, + }, + { + []*testSymbol{{name: "symbol", addr: 1, size: 1}}, + true, + []string{}, + []int{}, + }, + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}, {name: "symbol2", addr: 2, size: 1}}, + true, + []string{"symbol1", "symbol2"}, + []int{0, 1}, + }, + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}, {name: "symbol2", addr: 2, size: 1}}, + true, + []string{"symbol2"}, + []int{1}, + }, + } + +testLoop: + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](tc.lazyNameLookup) + st.AddSymbols(tc.symbols) + if tc.lazyNameLookup { + if len(st.symbolsByName) != 0 { + t.Errorf("len(st.symbolsByName) = %d, expected 0", len(st.symbolsByName)) + continue + } + } else { + if len(st.symbolsByName) != len(tc.symbols) { + 
t.Errorf("len(st.symbolsByName) = %d, expected %d", len(st.symbolsByName), len(tc.symbols)) + continue + } + } + for _, lookupName := range tc.lookups { + _, err := st.LookupByName(lookupName) + if err != nil { + t.Errorf("LookupByName(%s) failed: %v", lookupName, err) + continue testLoop + } + } + for i := range tc.expectedMappings { + if !reflect.DeepEqual(*(st.symbolsByName[tc.symbols[tc.expectedMappings[i]].name][0]), *tc.symbols[tc.expectedMappings[i]]) { + t.Errorf("st.symbolsByName[\"%s\"] = %v, expected %v", tc.symbols[tc.expectedMappings[i]].name, *(st.symbolsByName[tc.symbols[tc.expectedMappings[i]].name][0]), *tc.symbols[tc.expectedMappings[i]]) + continue + } + } + } +} + +// TestLookupByAddressExact tests the LookupByAddressExact function +func TestLookupByAddressExaxt(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lookupAddr uint64 + expectLookupError bool + expected []testSymbol + }{ + { + []*testSymbol{{name: "symbol1", addr: 1, size: 1}}, + 1, + false, + []testSymbol{{name: "symbol1", addr: 1, size: 1}}, + }, + { + []*testSymbol{}, + 2, + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol3", addr: 3, size: 1}}, + 4, + true, + []testSymbol{}, + }, + { + []*testSymbol{{name: "symbol5", addr: 5, size: 1}, {name: "symbol6", addr: 6, size: 2}}, + 6, + false, + []testSymbol{{name: "symbol6", addr: 6, size: 2}}, + }, + { + []*testSymbol{{name: "symbol7", addr: 7, size: 1}, {name: "symbol8", addr: 7, size: 2}}, + 7, + false, + []testSymbol{{name: "symbol7", addr: 7, size: 1}, {name: "symbol8", addr: 7, size: 2}}, + }, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + result, err := st.LookupByAddressExact(tc.lookupAddr) + if !tc.expectLookupError && err != nil { + t.Errorf("LookupByAddressExact(%d) failed: %v", tc.lookupAddr, err) + continue + } else if tc.expectLookupError && err == nil { + t.Errorf("LookupByAddressExact(%d) expected to fail but didn't", 
tc.lookupAddr) + continue + } + if !reflect.DeepEqual(copySliceOfPointersToSliceOfStructs(result), tc.expected) { + t.Errorf("LookupByAddressExact(%d) = %v, expected %v", tc.lookupAddr, copySliceOfPointersToSliceOfStructs(result), tc.expected) + } + } +} + +// TestLookupByAddressContains tests the LookupByAddressContains function +func TestLookupByAddressContains(t *testing.T) { + testCases := []struct { + symbols []*testSymbol + lookupAddr uint64 + expected *testSymbol + }{ + { + []*testSymbol{}, + 1, + nil, + }, + { + []*testSymbol{{name: "symbol1", addr: 2, size: 2}}, + 2, + &testSymbol{name: "symbol1", addr: 2, size: 2}, + }, + { + []*testSymbol{{name: "symbol2", addr: 3, size: 2}}, + 4, + &testSymbol{name: "symbol2", addr: 3, size: 2}, + }, + { + []*testSymbol{{name: "symbol3", addr: 4, size: 2}}, + 6, + nil, + }, + { + []*testSymbol{{name: "symbol4", addr: 10, size: 2}}, + 8, + nil, + }, + { + []*testSymbol{{name: "symbol5", addr: 11, size: 2}}, + 14, + nil, + }, + { + []*testSymbol{{name: "symbol6", addr: 15, size: 5}, {name: "symbol7", addr: 17, size: 3}}, + 18, + &testSymbol{name: "symbol7", addr: 17, size: 3}, + }, + { // this is a special case assumed to be impossible in practice, see the docstring of LookupByAddressContains() + []*testSymbol{{name: "symbol8", addr: 20, size: 5}, {name: "symbol9", addr: 21, size: 2}}, + 23, + nil, + }, + } + + for _, tc := range testCases { + st := NewSymbolTable[testSymbol](false) + st.AddSymbols(tc.symbols) + result, err := st.LookupByAddressContains(tc.lookupAddr) + if tc.expected != nil && err != nil { + t.Errorf("LookupByAddressContains(%d) failed: %v", tc.lookupAddr, err) + continue + } + if tc.expected == nil { + if err == nil { + t.Errorf("LookupByAddressContains(%d) expected to fail, but returned %v", tc.lookupAddr, *result) + } + continue + } + if !reflect.DeepEqual(*result, *tc.expected) { + t.Errorf("LookupByAddressContains(%d) = %v, expected %v", tc.lookupAddr, *result, *tc.expected) + } + } +} + +// 
copySliceOfPointersToSliceOfStructs converts a slice of pointers to a slice of structs. +func copySliceOfPointersToSliceOfStructs(s []*testSymbol) []testSymbol { + ret := make([]testSymbol, len(s)) + for i, v := range s { + ret[i] = *v + } + return ret +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 3a88eb74232b..e72529dcb34f 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -5,10 +5,7 @@ import ( "io" "math/rand" "reflect" - "strings" "time" - - "github.com/aquasecurity/tracee/pkg/utils/environment" ) // Cloner is a generic interface for objects that can clone themselves. @@ -25,23 +22,6 @@ type Iterator[T any] interface { Next() T } -func ParseSymbol(address uint64, table *environment.KernelSymbolTable) environment.KernelSymbol { - var hookingFunction environment.KernelSymbol - - symbols, err := table.GetSymbolByAddr(address) - if err != nil { - hookingFunction = environment.KernelSymbol{} - hookingFunction.Owner = "hidden" - } else { - hookingFunction = symbols[0] - } - - hookingFunction.Owner = strings.TrimPrefix(hookingFunction.Owner, "[") - hookingFunction.Owner = strings.TrimSuffix(hookingFunction.Owner, "]") - - return hookingFunction -} - func HasBit(n uint64, offset uint) bool { return (n & (1 << offset)) > 0 } diff --git a/tests/e2e-inst-signatures/e2e-set_fs_pwd.go b/tests/e2e-inst-signatures/e2e-set_fs_pwd.go index f9cda190b64b..e3124d9c5938 100644 --- a/tests/e2e-inst-signatures/e2e-set_fs_pwd.go +++ b/tests/e2e-inst-signatures/e2e-set_fs_pwd.go @@ -4,6 +4,9 @@ import ( "fmt" "strings" + "kernel.org/pub/linux/libs/security/libcap/cap" + + "github.com/aquasecurity/tracee/pkg/capabilities" "github.com/aquasecurity/tracee/pkg/utils/environment" "github.com/aquasecurity/tracee/signatures/helpers" "github.com/aquasecurity/tracee/types/detect" @@ -21,7 +24,13 @@ func (sig *e2eSetFsPwd) Init(ctx detect.SignatureContext) error { // Find if this system has the bpf_probe_read_user_str helper. 
// If it doesn't we won't expect the unresolved path to contain anything - ksyms, err := environment.NewKernelSymbolTable() + ksyms := environment.NewKernelSymbolTable(false, false) + err := capabilities.GetInstance().Specific( + func() error { + return ksyms.Update() + }, + cap.SYSLOG, + ) if err != nil { return err }