Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorganize SDK config and service instances logic #54

Merged
merged 5 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ grpcio-health-checking==1.59.0
jsonschema==4.0.0
eth-account==0.9.0
snet.cli==2.1.3
snet.contracts==0.1.1
snet.contracts==0.1.1
urllib3>=2.2.2
169 changes: 67 additions & 102 deletions snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from pathlib import Path
import sys
from typing import Any, NewType
import re
import copy

import google.protobuf.internal.api_implementation
from snet.sdk.metadata_provider.ipfs_metadata_provider import IPFSMetadataProvider
from snet.sdk.payment_strategies.default_payment_strategy import DefaultPaymentStrategy
from snet.cli.commands.sdk_command import SDKCommand
from snet.cli.commands.commands import BlockchainCommand
from snet.cli.config import Config
from snet.cli.utils.utils import bytes32_to_str, type_converter

google.protobuf.internal.api_implementation.Type = lambda: 'python'

Expand Down Expand Up @@ -41,7 +43,7 @@


class Arguments:
def __init__(self, org_id, service_id):
def __init__(self, org_id=None, service_id=None):
self.org_id = org_id
self.service_id = service_id
self.language = "python"
Expand All @@ -51,56 +53,46 @@ def __init__(self, org_id, service_id):
class SnetSDK:
"""Base Snet SDK"""

def __init__(self, config, metadata_provider=None):
self._config = config
def __init__(self, sdk_config, metadata_provider=None):
self._sdk_config = sdk_config
self._metadata_provider = metadata_provider

# Instantiate Ethereum client
eth_rpc_endpoint = self._config.get("eth_rpc_endpoint", "https://mainnet.infura.io/v3/e7732e1f679e461b9bb4da5653ac3fc2")
eth_rpc_request_kwargs = self._config.get("eth_rpc_request_kwargs")
eth_rpc_endpoint = self._sdk_config.get("eth_rpc_endpoint",
"https://mainnet.infura.io/v3/e7732e1f679e461b9bb4da5653ac3fc2")
eth_rpc_request_kwargs = self._sdk_config.get("eth_rpc_request_kwargs")

provider = web3.HTTPProvider(endpoint_uri=eth_rpc_endpoint, request_kwargs=eth_rpc_request_kwargs)

self.web3 = web3.Web3(provider)

# Get MPE contract address from config if specified; mostly for local testing
_mpe_contract_address = self._config.get("mpe_contract_address", None)
_mpe_contract_address = self._sdk_config.get("mpe_contract_address", None)
if _mpe_contract_address is None:
self.mpe_contract = MPEContract(self.web3)
else:
self.mpe_contract = MPEContract(self.web3, _mpe_contract_address)

# Instantiate IPFS client
ipfs_endpoint = self._config.get("default_ipfs_endpoint", "/dns/ipfs.singularitynet.io/tcp/80/")
ipfs_endpoint = self._sdk_config.get("default_ipfs_endpoint", "/dns/ipfs.singularitynet.io/tcp/80/")
self.ipfs_client = ipfshttpclient.connect(ipfs_endpoint)

# Get Registry contract address from config if specified; mostly for local testing
_registry_contract_address = self._config.get("registry_contract_address", None)
_registry_contract_address = self._sdk_config.get("registry_contract_address", None)
if _registry_contract_address is None:
self.registry_contract = get_contract_object(self.web3, "Registry")
else:
self.registry_contract = get_contract_object(self.web3, "Registry", _registry_contract_address)

self.account = Account(self.web3, config, self.mpe_contract)
self.account = Account(self.web3, sdk_config, self.mpe_contract)

global_config = Config(sdk_config=config)
global_config = Config(sdk_config=sdk_config)
self.setup_config(global_config)
sdk = SDKCommand(global_config, args=Arguments(config['org_id'], config['service_id']))
force_update = config.get('force_update', False)

if force_update:
sdk.generate_client_library()
else:
path_to_pb_files = self.get_path_to_pb_files(config['org_id'], config['service_id'])
pb_2_file_name = find_file_by_keyword(path_to_pb_files, keyword="pb2.py")
pb_2_grpc_file_name = find_file_by_keyword(path_to_pb_files, keyword="pb2_grpc.py")
if not pb_2_file_name or not pb_2_grpc_file_name:
sdk.generate_client_library()

def setup_config(self, config: Config) -> None:
out_f = sys.stdout
network = self._config.get("network", None)
identity_name = self._config.get("identity_name", None)
network = self._sdk_config.get("network", None)
identity_name = self._sdk_config.get("identity_name", None)
# Checking for an empty network
if network and config["session"]["network"] != network:
config.set_session_network(network, out_f)
Expand All @@ -118,21 +110,39 @@ def set_session_identity(self, identity_name: str, config: Config, out_f):
elif config["session"]["identity"] != identity_name:
config.set_session_identity(identity_name, out_f)

def create_service_client(self, payment_channel_management_strategy=None,
options=None, concurrent_calls=1):
org_id = self._config.get("org_id")
service_id = self._config.get("service_id")
group_name = self._config.get("group_name", "default_group")
def create_service_client(self, org_id: str, service_id: str, group_name=None,
payment_channel_management_strategy=None,
free_call_auth_token_bin=None,
free_call_token_expiry_block=None,
options=None,
concurrent_calls=1):

# Create and instance of the Config object, so we can create an instance of SDKCommand
sdk_config_object = Config(sdk_config=self._sdk_config)
sdk = SDKCommand(sdk_config_object, args=Arguments(org_id, service_id))

# Download the proto file and generate stubs if needed
force_update = self._sdk_config.get('force_update', False)
if force_update:
sdk.generate_client_library()
else:
path_to_pb_files = self.get_path_to_pb_files(org_id, service_id)
pb_2_file_name = find_file_by_keyword(path_to_pb_files, keyword="pb2.py")
pb_2_grpc_file_name = find_file_by_keyword(path_to_pb_files, keyword="pb2_grpc.py")
if not pb_2_file_name or not pb_2_grpc_file_name:
sdk.generate_client_library()

if payment_channel_management_strategy is None:
payment_channel_management_strategy = DefaultPaymentStrategy(concurrent_calls)

if options is None:
options = dict()

options['free_call_auth_token-bin'] = bytes.fromhex(self._config.get("free_call_auth_token-bin", ""))
options['free-call-token-expiry-block'] = self._config.get("free-call-token-expiry-block", 0)
options['email'] = self._config.get("email", "")
options['concurrency'] = self._config.get("concurrency", True)
options['free_call_auth_token-bin'] = bytes.fromhex(free_call_auth_token_bin) if\
free_call_token_expiry_block else ""
options['free-call-token-expiry-block'] = free_call_token_expiry_block if\
free_call_token_expiry_block else 0
options['email'] = self._sdk_config.get("email", "")
options['concurrency'] = self._sdk_config.get("concurrency", True)

if self._metadata_provider is None:
self._metadata_provider = IPFSMetadataProvider(self.ipfs_client, self.registry_contract)
Expand All @@ -145,8 +155,8 @@ def create_service_client(self, payment_channel_management_strategy=None,

pb2_module = self.get_module_by_keyword(org_id, service_id, keyword="pb2.py")

service_client = ServiceClient(org_id, service_id, service_metadata, group, service_stub, strategy, options,
self.mpe_contract, self.account, self.web3, pb2_module)
service_client = ServiceClient(org_id, service_id, service_metadata, group, service_stub, strategy,
options, self.mpe_contract, self.account, self.web3, pb2_module)
return service_client

def get_service_stub(self, org_id: str, service_id: str) -> ServiceStub:
Expand Down Expand Up @@ -206,69 +216,24 @@ def _get_service_group_details(self, service_metadata, group_name):

return self._get_group_by_group_name(service_metadata, group_name)

def get_services_and_messages_info(self):
# Get proto file filepath and open it
path_to_pb_files = self.get_path_to_pb_files(self._config['org_id'], self._config['service_id'])
proto_file_name = find_file_by_keyword(path_to_pb_files, keyword=".proto")
proto_filepath = os.path.join(path_to_pb_files, proto_file_name)
with open(proto_filepath, 'r') as file:
proto_content = file.read()
# Regular expression patterns to match services, methods, messages, and fields
service_pattern = re.compile(r'service\s+(\w+)\s*{')
rpc_pattern = re.compile(r'rpc\s+(\w+)\s*\((\w+)\)\s+returns\s+\((\w+)\)')
message_pattern = re.compile(r'message\s+(\w+)\s*{')
field_pattern = re.compile(r'\s*(\w+)\s+(\w+)\s*=\s*\d+\s*;')

services = {}
messages = {}
current_service = None
current_message = None

for line in proto_content.splitlines():
# Match a service definition
service_match = service_pattern.search(line)
if service_match:
current_service = service_match.group(1)
services[current_service] = []
continue

# Match an RPC method inside a service
if current_service:
rpc_match = rpc_pattern.search(line)
if rpc_match:
method_name = rpc_match.group(1)
input_type = rpc_match.group(2)
output_type = rpc_match.group(3)
services[current_service].append((method_name, input_type, output_type))

# Match a message definition
message_match = message_pattern.search(line)
if message_match:
current_message = message_match.group(1)
messages[current_message] = []
continue

# Match a field inside a message
if current_message:
field_match = field_pattern.search(line)
if field_match:
field_type = field_match.group(1)
field_name = field_match.group(2)
messages[current_message].append((field_type, field_name))

return services, messages

def print_services_info(self):
services, messages = self.get_services_and_messages_info()

# Print the parsed services and their methods
for service, methods in services.items():
print(f"Service: {service}")
for method_name, input_type, output_type in methods:
print(f" Method: {method_name}, Input: {input_type}, Output: {output_type}")

# Print the parsed messages and their fields
for message, fields in messages.items():
print(f"Message: {message}")
for field_type, field_name in fields:
print(f" Field: {field_type} {field_name}")
def get_organization_list(self) -> list:
global_config = Config(sdk_config=self._sdk_config)
blockchain_command = BlockchainCommand(config=global_config, args=Arguments())
org_list = blockchain_command.call_contract_command(
"Registry", "listOrganizations", [])
organization_list = []
for idx, org_id in enumerate(org_list):
organization_list.append(bytes32_to_str(org_id))
return organization_list

def get_services_list(self, org_id: str) -> list:
global_config = Config(sdk_config=self._sdk_config)
blockchain_command = BlockchainCommand(config=global_config, args=Arguments())
(found, org_service_list) = blockchain_command.call_contract_command("Registry",
"listServicesForOrganization",
[type_converter("bytes32")(org_id)])
if not found:
raise Exception(f"Organization with id={org_id} doesn't exist!")
org_service_list = list(map(bytes32_to_str, org_service_list))
return org_service_list

81 changes: 79 additions & 2 deletions snet/sdk/service_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import base64
import collections
import importlib
import re
import os
from pathlib import Path

import grpc
import web3
from eth_account.messages import defunct_hash_message
from rfc3986 import urlparse
from snet.cli.resources.root_certificate import certificate
from snet.cli.utils.utils import RESOURCES_PATH, add_to_path
from snet.cli.utils.utils import RESOURCES_PATH, add_to_path, find_file_by_keyword

import snet.sdk.generic_client_interceptor as generic_client_interceptor
from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider
Expand Down Expand Up @@ -188,4 +191,78 @@ def get_concurrency_token_and_channel(self):

def set_concurrency_token_and_channel(self, token, channel):
self.payment_strategy.concurrency_token = token
self.payment_strategy.channel = channel
self.payment_strategy.channel = channel

def get_path_to_pb_files(self, org_id: str, service_id: str) -> str:
client_libraries_base_dir_path = Path("~").expanduser().joinpath(".snet")
path_to_pb_files = f"{client_libraries_base_dir_path}/{org_id}/{service_id}/python/"
return path_to_pb_files

def get_services_and_messages_info(self):
# Get proto file filepath and open it
path_to_pb_files = self.get_path_to_pb_files(self.org_id, self.service_id)
proto_file_name = find_file_by_keyword(path_to_pb_files, keyword=".proto")
proto_filepath = os.path.join(path_to_pb_files, proto_file_name)
with open(proto_filepath, 'r') as file:
proto_content = file.read()
# Regular expression patterns to match services, methods, messages, and fields
service_pattern = re.compile(r'service\s+(\w+)\s*{')
rpc_pattern = re.compile(r'rpc\s+(\w+)\s*\((\w+)\)\s+returns\s+\((\w+)\)')
message_pattern = re.compile(r'message\s+(\w+)\s*{')
field_pattern = re.compile(r'\s*(\w+)\s+(\w+)\s*=\s*\d+\s*;')

services = {}
messages = {}
current_service = None
current_message = None

for line in proto_content.splitlines():
# Match a service definition
service_match = service_pattern.search(line)
if service_match:
current_service = service_match.group(1)
services[current_service] = []
continue

# Match an RPC method inside a service
if current_service:
rpc_match = rpc_pattern.search(line)
if rpc_match:
method_name = rpc_match.group(1)
input_type = rpc_match.group(2)
output_type = rpc_match.group(3)
services[current_service].append((method_name, input_type, output_type))

# Match a message definition
message_match = message_pattern.search(line)
if message_match:
current_message = message_match.group(1)
messages[current_message] = []
continue

# Match a field inside a message
if current_message:
field_match = field_pattern.search(line)
if field_match:
field_type = field_match.group(1)
field_name = field_match.group(2)
messages[current_message].append((field_type, field_name))

return services, messages

def get_services_and_messages_info_as_pretty_string(self):
services, messages = self.get_services_and_messages_info()

string_output = ""
# Prettify the parsed services and their methods
for service, methods in services.items():
string_output += f"Service: {service}\n"
for method_name, input_type, output_type in methods:
string_output += f" Method: {method_name}, Input: {input_type}, Output: {output_type}\n"

# Prettify the messages and their fields
for message, fields in messages.items():
string_output += f"Message: {message}\n"
for field_type, field_name in fields:
string_output += f" Field: {field_type} {field_name}\n"
return string_output
11 changes: 5 additions & 6 deletions testcases/functional_tests/test_sdk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def setUp(self):
self.service_client, self.path_to_pb_files = get_test_service_data()
channel = self.service_client.deposit_and_open_channel(123456, 33333)

def test__call_to_service(self):
def test_call_to_service(self):
result = self.service_client.call_rpc("mul", "Numbers", a=20, b=3)
self.assertEqual(60.0, result.value)

Expand All @@ -27,18 +27,17 @@ def get_test_service_data():
"private_key": os.environ['SNET_TEST_WALLET_PRIVATE_KEY'],
"eth_rpc_endpoint": f"https://sepolia.infura.io/v3/{os.environ['SNET_TEST_INFURA_KEY']}",
"concurrency": False,
"org_id": "26072b8b6a0e448180f8c0e702ab6d2f",
"service_id": "Exampleservice",
"group_name": "default_group",
"identity_name": "test",
"identity_type": "key",
"network": "sepolia",
"force_update": False
}

snet_sdk = sdk.SnetSDK(config)
service_client = snet_sdk.create_service_client()
path_to_pb_files = snet_sdk.get_path_to_pb_files(config['org_id'], config['service_id'])
service_client = snet_sdk.create_service_client(org_id="26072b8b6a0e448180f8c0e702ab6d2f",
service_id="Exampleservice", group_name="default_group")
path_to_pb_files = snet_sdk.get_path_to_pb_files(org_id="26072b8b6a0e448180f8c0e702ab6d2f",
service_id="Exampleservice")
return service_client, path_to_pb_files


Expand Down
Loading