Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Neo Core Exception] Create WalletException and use it to replace all exceptions in wallet #3434

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions src/Neo/Wallets/AssetDescriptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,29 @@ public class AssetDescriptor
/// <param name="asset_id">The id of the asset.</param>
public AssetDescriptor(DataCache snapshot, ProtocolSettings settings, UInt160 asset_id)
{
var contract = NativeContract.ContractManagement.GetContract(snapshot, asset_id);
if (contract is null) throw new ArgumentException(null, nameof(asset_id));
try
{
var contract = NativeContract.ContractManagement.GetContract(snapshot, asset_id);
if (contract is null) throw new WalletException(WalletErrorType.ContractNotFound, nameof(asset_id));

byte[] script;
using (ScriptBuilder sb = new())
byte[] script;
using (ScriptBuilder sb = new())
{
sb.EmitDynamicCall(asset_id, "decimals", CallFlags.ReadOnly);
sb.EmitDynamicCall(asset_id, "symbol", CallFlags.ReadOnly);
script = sb.ToArray();
}
using ApplicationEngine engine = ApplicationEngine.Run(script, snapshot, settings: settings, gas: 0_30000000L);
if (engine.State != VMState.HALT) throw new WalletException(WalletErrorType.ExecutionFault, nameof(asset_id));
AssetId = asset_id;
AssetName = contract.Manifest.Name;
Symbol = engine.ResultStack.Pop().GetString();
Decimals = (byte)engine.ResultStack.Pop().GetInteger();
}
catch (Exception ex) when (ex is not WalletException)
{
sb.EmitDynamicCall(asset_id, "decimals", CallFlags.ReadOnly);
sb.EmitDynamicCall(asset_id, "symbol", CallFlags.ReadOnly);
script = sb.ToArray();
throw WalletException.FromException(ex);
}
Comment on lines +72 to 75
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WalletException constructor is able to create WalletException from WalletException, so we can safely remove when (ex is not WalletException).

using ApplicationEngine engine = ApplicationEngine.Run(script, snapshot, settings: settings, gas: 0_30000000L);
if (engine.State != VMState.HALT) throw new ArgumentException(null, nameof(asset_id));
AssetId = asset_id;
AssetName = contract.Manifest.Name;
Symbol = engine.ResultStack.Pop().GetString();
Decimals = (byte)engine.ResultStack.Pop().GetInteger();
}

public override string ToString()
Expand Down
157 changes: 92 additions & 65 deletions src/Neo/Wallets/Helper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// Redistribution and use in source and binary forms with or without
// modifications are permitted.

using Akka.Routing;
shargon marked this conversation as resolved.
Show resolved Hide resolved
using Neo.Cryptography;
using Neo.IO;
using Neo.Network.P2P;
Expand Down Expand Up @@ -36,7 +37,14 @@ public static class Helper
/// <returns>The signature for the <see cref="IVerifiable"/>.</returns>
public static byte[] Sign(this IVerifiable verifiable, KeyPair key, uint network)
{
return Crypto.Sign(verifiable.GetSignData(network), key.PrivateKey);
try
{
return Crypto.Sign(verifiable.GetSignData(network), key.PrivateKey);
}
catch (Exception ex)
{
throw WalletException.FromException(ex);
}
}

/// <summary>
Expand All @@ -61,17 +69,24 @@ public static string ToAddress(this UInt160 scriptHash, byte version)
/// <returns>The converted script hash.</returns>
public static UInt160 ToScriptHash(this string address, byte version)
{
byte[] data = address.Base58CheckDecode();
if (data.Length != 21)
throw new FormatException();
if (data[0] != version)
throw new FormatException();
return new UInt160(data.AsSpan(1));
try
{
byte[] data = address.Base58CheckDecode();
if (data.Length != 21)
throw new WalletException(WalletErrorType.FormatError, "Invalid address format: incorrect length");
if (data[0] != version)
throw new WalletException(WalletErrorType.FormatError, "Invalid address version");
return new UInt160(data.AsSpan(1));
}
catch (Exception e) when (e is not WalletException)
{
throw new WalletException(WalletErrorType.FormatError, "Invalid address format");
}
}

internal static byte[] XOR(byte[] x, byte[] y)
{
if (x.Length != y.Length) throw new ArgumentException();
if (x.Length != y.Length) throw new WalletException(WalletErrorType.InvalidOperation, "Arrays must have the same length");
byte[] r = new byte[x.Length];
for (int i = 0; i < r.Length; i++)
r[i] = (byte)(x[i] ^ y[i]);
Expand All @@ -90,78 +105,90 @@ internal static byte[] XOR(byte[] x, byte[] y)
/// <returns>The network fee of the transaction.</returns>
public static long CalculateNetworkFee(this Transaction tx, DataCache snapshot, ProtocolSettings settings, Func<UInt160, byte[]> accountScript, long maxExecutionCost = ApplicationEngine.TestModeGas)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will make conflicts into my never ending pull request #3385 :'(

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may merge #3385 as it is now and move non-standard verification scripts support to another issue. I think it's acceptable way since #3385 contains everything that's required except non-standard verification scripts handling.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to #3539.

{
UInt160[] hashes = tx.GetScriptHashesForVerifying(snapshot);

// base size for transaction: includes const_header + signers + attributes + script + hashes
int size = Transaction.HeaderSize + tx.Signers.GetVarSize() + tx.Attributes.GetVarSize() + tx.Script.GetVarSize() + IO.Helper.GetVarSize(hashes.Length), index = -1;
uint exec_fee_factor = NativeContract.Policy.GetExecFeeFactor(snapshot);
long networkFee = 0;
foreach (UInt160 hash in hashes)
try
{
index++;
byte[] witnessScript = accountScript(hash);
byte[] invocationScript = null;
UInt160[] hashes = tx.GetScriptHashesForVerifying(snapshot);

if (tx.Witnesses != null && witnessScript is null)
// base size for transaction: includes const_header + signers + attributes + script + hashes
int size = Transaction.HeaderSize + tx.Signers.GetVarSize() + tx.Attributes.GetVarSize() + tx.Script.GetVarSize() + IO.Helper.GetVarSize(hashes.Length), index = -1;
uint exec_fee_factor = NativeContract.Policy.GetExecFeeFactor(snapshot);
long networkFee = 0;
foreach (UInt160 hash in hashes)
{
// Try to find the script in the witnesses
Witness witness = tx.Witnesses[index];
witnessScript = witness?.VerificationScript.ToArray();
index++;
byte[] witnessScript = accountScript(hash);
byte[] invocationScript = null;

if (witnessScript is null || witnessScript.Length == 0)
if (tx.Witnesses != null && witnessScript is null)
{
// Then it's a contract-based witness, so try to get the corresponding invocation script for it
invocationScript = witness?.InvocationScript.ToArray();
}
}
// Try to find the script in the witnesses
Witness witness = tx.Witnesses[index];
witnessScript = witness?.VerificationScript.ToArray();

if (witnessScript is null || witnessScript.Length == 0)
{
var contract = NativeContract.ContractManagement.GetContract(snapshot, hash);
if (contract is null)
throw new ArgumentException($"The smart contract or address {hash} is not found");
var md = contract.Manifest.Abi.GetMethod(ContractBasicMethod.Verify, ContractBasicMethod.VerifyPCount);
if (md is null)
throw new ArgumentException($"The smart contract {contract.Hash} haven't got verify method");
if (md.ReturnType != ContractParameterType.Boolean)
throw new ArgumentException("The verify method doesn't return boolean value.");
if (md.Parameters.Length > 0 && invocationScript is null)
throw new ArgumentException("The verify method requires parameters that need to be passed via the witness' invocation script.");
if (witnessScript is null || witnessScript.Length == 0)
{
// Then it's a contract-based witness, so try to get the corresponding invocation script for it
invocationScript = witness?.InvocationScript.ToArray();
}
}
if (witnessScript is null || witnessScript.Length == 0)
{
var contract = NativeContract.ContractManagement.GetContract(snapshot, hash);
if (contract is null)
throw new WalletException(WalletErrorType.ContractNotFound, $"The smart contract or address {hash} is not found");
var md = contract.Manifest.Abi.GetMethod(ContractBasicMethod.Verify, ContractBasicMethod.VerifyPCount);
if (md is null)
throw new WalletException(WalletErrorType.ContractError, $"The smart contract {contract.Hash} hasn't got a verify method");
if (md.ReturnType != ContractParameterType.Boolean)
throw new WalletException(WalletErrorType.ContractError, "The verify method doesn't return boolean value");
if (md.Parameters.Length > 0 && invocationScript is null)
throw new WalletException(WalletErrorType.ContractError, "The verify method requires parameters that need to be passed via the witness' invocation script");

// Empty verification and non-empty invocation scripts
var invSize = invocationScript?.GetVarSize() ?? Array.Empty<byte>().GetVarSize();
size += Array.Empty<byte>().GetVarSize() + invSize;
// Empty verification and non-empty invocation scripts
var invSize = invocationScript?.GetVarSize() ?? Array.Empty<byte>().GetVarSize();
size += Array.Empty<byte>().GetVarSize() + invSize;

// Check verify cost
using ApplicationEngine engine = ApplicationEngine.Create(TriggerType.Verification, tx, snapshot.CloneCache(), settings: settings, gas: maxExecutionCost);
engine.LoadContract(contract, md, CallFlags.ReadOnly);
if (invocationScript != null) engine.LoadScript(invocationScript, configureState: p => p.CallFlags = CallFlags.None);
if (engine.Execute() == VMState.FAULT) throw new ArgumentException($"Smart contract {contract.Hash} verification fault.");
if (!engine.ResultStack.Pop().GetBoolean()) throw new ArgumentException($"Smart contract {contract.Hash} returns false.");
// Check verify cost
using ApplicationEngine engine = ApplicationEngine.Create(TriggerType.Verification, tx, snapshot.CloneCache(), settings: settings, gas: maxExecutionCost);
engine.LoadContract(contract, md, CallFlags.ReadOnly);
if (invocationScript != null) engine.LoadScript(invocationScript, configureState: p => p.CallFlags = CallFlags.None);
if (engine.Execute() == VMState.FAULT) throw new WalletException(WalletErrorType.ExecutionFault, $"Smart contract {contract.Hash} verification fault");
if (!engine.ResultStack.Pop().GetBoolean()) throw new WalletException(WalletErrorType.VerificationFailed, $"Smart contract {contract.Hash} returns false");

maxExecutionCost -= engine.FeeConsumed;
if (maxExecutionCost <= 0) throw new InvalidOperationException("Insufficient GAS.");
networkFee += engine.FeeConsumed;
}
else if (IsSignatureContract(witnessScript))
{
size += 67 + witnessScript.GetVarSize();
networkFee += exec_fee_factor * SignatureContractCost();
maxExecutionCost -= engine.FeeConsumed;
if (maxExecutionCost <= 0) throw new WalletException(WalletErrorType.InsufficientFunds, "Insufficient GAS");
networkFee += engine.FeeConsumed;
}
else if (IsSignatureContract(witnessScript))
{
size += 67 + witnessScript.GetVarSize();
networkFee += exec_fee_factor * SignatureContractCost();
}
else if (IsMultiSigContract(witnessScript, out int m, out int n))
{
int size_inv = 66 * m;
size += IO.Helper.GetVarSize(size_inv) + size_inv + witnessScript.GetVarSize();
networkFee += exec_fee_factor * MultiSignatureContractCost(m, n);
}
// We can support more contract types in the future.
}
else if (IsMultiSigContract(witnessScript, out int m, out int n))
networkFee += size * NativeContract.Policy.GetFeePerByte(snapshot);
foreach (TransactionAttribute attr in tx.Attributes)
{
int size_inv = 66 * m;
size += IO.Helper.GetVarSize(size_inv) + size_inv + witnessScript.GetVarSize();
networkFee += exec_fee_factor * MultiSignatureContractCost(m, n);
networkFee += attr.CalculateNetworkFee(snapshot, tx);
}
// We can support more contract types in the future.
return networkFee;
}
networkFee += size * NativeContract.Policy.GetFeePerByte(snapshot);
foreach (TransactionAttribute attr in tx.Attributes)
catch (Exception ex) when (ex is not WalletException)
{
networkFee += attr.CalculateNetworkFee(snapshot, tx);
throw new WalletException(WalletErrorType.UnknownError, ex.Message, ex);
}
return networkFee;
}

internal static void ThrowIfNull(object argument, string paramName)
{
if (argument == null)
throw new WalletException(WalletErrorType.ArgumentNull, $"Argument {paramName} cannot be null.");
}
}
}
66 changes: 44 additions & 22 deletions src/Neo/Wallets/KeyPair.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,23 @@ public class KeyPair : IEquatable<KeyPair>
public KeyPair(byte[] privateKey)
{
if (privateKey.Length != 32 && privateKey.Length != 96 && privateKey.Length != 104)
throw new ArgumentException(null, nameof(privateKey));
throw new WalletException(WalletErrorType.InvalidPrivateKey, nameof(privateKey));
PrivateKey = privateKey[^32..];
if (privateKey.Length == 32)
{
PublicKey = Cryptography.ECC.ECCurve.Secp256r1.G * privateKey;
}
else
{
PublicKey = Cryptography.ECC.ECPoint.FromBytes(privateKey, Cryptography.ECC.ECCurve.Secp256r1);
try
{
PublicKey = Cryptography.ECC.ECPoint.FromBytes(privateKey, Cryptography.ECC.ECCurve.Secp256r1);
}
catch
{
throw new WalletException(WalletErrorType.InvalidPrivateKey, nameof(privateKey));
}

}
}

Expand Down Expand Up @@ -119,30 +127,44 @@ public string Export(string passphrase, byte version, int N = 16384, int r = 8,
/// <returns>The private key in NEP-2 format.</returns>
public string Export(byte[] passphrase, byte version, int N = 16384, int r = 8, int p = 8)
{
UInt160 script_hash = Contract.CreateSignatureRedeemScript(PublicKey).ToScriptHash();
string address = script_hash.ToAddress(version);
byte[] addresshash = Encoding.ASCII.GetBytes(address).Sha256().Sha256()[..4];
byte[] derivedkey = SCrypt.Generate(passphrase, addresshash, N, r, p, 64);
byte[] derivedhalf1 = derivedkey[..32];
byte[] derivedhalf2 = derivedkey[32..];
byte[] encryptedkey = Encrypt(XOR(PrivateKey, derivedhalf1), derivedhalf2);
Span<byte> buffer = stackalloc byte[39];
buffer[0] = 0x01;
buffer[1] = 0x42;
buffer[2] = 0xe0;
addresshash.CopyTo(buffer[3..]);
encryptedkey.CopyTo(buffer[7..]);
return Base58.Base58CheckEncode(buffer);
try
{
UInt160 script_hash = Contract.CreateSignatureRedeemScript(PublicKey).ToScriptHash();
string address = script_hash.ToAddress(version);
byte[] addresshash = Encoding.ASCII.GetBytes(address).Sha256().Sha256()[..4];
byte[] derivedkey = SCrypt.Generate(passphrase, addresshash, N, r, p, 64);
byte[] derivedhalf1 = derivedkey[..32];
byte[] derivedhalf2 = derivedkey[32..];
byte[] encryptedkey = Encrypt(XOR(PrivateKey, derivedhalf1), derivedhalf2);
Span<byte> buffer = stackalloc byte[39];
buffer[0] = 0x01;
buffer[1] = 0x42;
buffer[2] = 0xe0;
addresshash.CopyTo(buffer[3..]);
encryptedkey.CopyTo(buffer[7..]);
return Base58.Base58CheckEncode(buffer);
}
catch (Exception e)
{
throw WalletException.FromException(e);
}
}

private static byte[] Encrypt(byte[] data, byte[] key)
{
using Aes aes = Aes.Create();
aes.Key = key;
aes.Mode = CipherMode.ECB;
aes.Padding = PaddingMode.None;
using ICryptoTransform encryptor = aes.CreateEncryptor();
return encryptor.TransformFinalBlock(data, 0, data.Length);
try
{
using Aes aes = Aes.Create();
aes.Key = key;
aes.Mode = CipherMode.ECB;
aes.Padding = PaddingMode.None;
using ICryptoTransform encryptor = aes.CreateEncryptor();
return encryptor.TransformFinalBlock(data, 0, data.Length);
}
catch (Exception e)
{
throw WalletException.FromException(e);
}
}

public override int GetHashCode()
Expand Down
Loading