Skip to content

Commit

Permalink
Base class made to enable re-use of sharding functions. Also added ab…
Browse files Browse the repository at this point in the history
…ility to auth with AD rather than just SAS Tokens
  • Loading branch information
Justin Jones committed Jul 10, 2023
1 parent 692b69c commit f6f0e7a
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 77 deletions.
79 changes: 79 additions & 0 deletions OrleansShardedStorageProvider/AzureShardedGrainBase.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
using Orleans.Runtime;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace OrleansShardedStorageProvider
{
public class AzureShardedGrainBase
{
protected readonly string _serviceId;
protected readonly AzureShardedStorageOptions _options;

public AzureShardedGrainBase(string serviceId, AzureShardedStorageOptions options)
{
_serviceId = serviceId;
_options = options;
}

protected int GetShardNumberFromKey(string pk)
{
var hash = GetStableHashCode(pk);
var storageNum = Math.Abs(hash % this._options.ConnectionStrings.Count());

return storageNum;
}

/// <summary>
/// Take from https://stackoverflow.com/a/36845864/852806
/// </summary>
/// <param name="str"></param>
/// <returns></returns>
protected int GetStableHashCode(string str)
{
unchecked
{
int hash1 = 5381;
int hash2 = hash1;

for (int i = 0; i < str.Length && str[i] != '\0'; i += 2)
{
hash1 = ((hash1 << 5) + hash1) ^ str[i];
if (i == str.Length - 1 || str[i + 1] == '\0')
break;
hash2 = ((hash2 << 5) + hash2) ^ str[i + 1];
}

return hash1 + (hash2 * 1566083941);
}
}


private const string KeyStringSeparator = "__";

protected string GetKeyString(GrainId grainId)
{
var key = $"{this._serviceId}{KeyStringSeparator}{grainId.ToString()}";

return SanitizeTableProperty(key);
}

protected string SanitizeTableProperty(string key)
{
// Remove any characters that can't be used in Azure PartitionKey or RowKey values
// http://www.jamestharpe.com/web-development/azure-table-service-character-combinations-disallowed-in-partitionkey-rowkey/
key = key
.Replace('/', '_') // Forward slash
.Replace('\\', '_') // Backslash
.Replace('#', '_') // Pound sign
.Replace('?', '_'); // Question mark

if (key.Length >= 1024)
throw new ArgumentException(string.Format("Key length {0} is too long to be an Azure table key. Key={1}", key.Length, key));

return key;
}
}
}
86 changes: 9 additions & 77 deletions OrleansShardedStorageProvider/AzureShardedGrainStorage.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Azure;
using Azure.Data.Tables;
using Azure.Identity;
using Azure.Storage.Blobs;
using Azure.Storage.Blobs.Models;
using Microsoft.Extensions.DependencyInjection;
Expand All @@ -19,24 +20,21 @@ namespace OrleansShardedStorageProvider
/// Origin: https://github.com/JsAndDotNet/OrleansShardedStorage
/// Similar to Oreleans:src\Azure\Orleans.Persistence.AzureStorage\Providers\Storage\AzureTableStorage.cs
/// </remarks>
public class AzureShardedGrainStorage : IGrainStorage, ILifecycleParticipant<ISiloLifecycle>
public class AzureShardedGrainStorage : AzureShardedGrainBase, IGrainStorage, ILifecycleParticipant<ISiloLifecycle>
{
private readonly string _serviceId;
private readonly string _name;
private readonly ILogger _logger;
private readonly AzureShardedStorageOptions _options;
private List<TableClient> _tableClients = new List<TableClient>();
private List<BlobContainerClient> _blobClients = new List<BlobContainerClient>();
private StorageType _storageType = StorageType.TableStorage;


public AzureShardedGrainStorage(string name, AzureShardedStorageOptions options, IOptions<ClusterOptions> clusterOptions, ILoggerFactory loggerFactory)
: base(clusterOptions.Value.ServiceId, options)
{
this._name = name;
var loggerName = $"{typeof(AzureShardedGrainStorage).FullName}.{name}";
this._logger = loggerFactory.CreateLogger(loggerName);
this._options = options;
this._serviceId = clusterOptions.Value.ServiceId;
}

public void Participate(ISiloLifecycle lifecycle)
Expand Down Expand Up @@ -66,9 +64,9 @@ public async Task Init(CancellationToken ct)
{
_storageType = StorageType.TableStorage;

var shareClient = new TableServiceClient(
storage.BaseTableUri,
new AzureSasCredential(storage.SasToken));
var shareClient = String.IsNullOrEmpty(storage.SasToken) ?
new TableServiceClient(storage.BaseTableUri, new DefaultAzureCredential()) :
new TableServiceClient(storage.BaseTableUri, new AzureSasCredential(storage.SasToken));

var table = await shareClient.CreateTableIfNotExistsAsync(storage.TableOrContainerName);

Expand All @@ -83,7 +81,9 @@ public async Task Init(CancellationToken ct)
{
_storageType = StorageType.BlobStorage;

BlobServiceClient blobServiceClient = new BlobServiceClient(storage.BaseBlobUri, storage.SasCredential);
BlobServiceClient blobServiceClient = (null == storage.SasCredential) ?
new BlobServiceClient(storage.BaseBlobUri, new DefaultAzureCredential()) :
new BlobServiceClient(storage.BaseBlobUri, storage.SasCredential);

var containerClient = blobServiceClient.GetBlobContainerClient(storage.TableOrContainerName);
await containerClient.CreateIfNotExistsAsync();
Expand Down Expand Up @@ -401,74 +401,6 @@ public Task Close(CancellationToken ct)
return Task.CompletedTask;
}





#region "Utils"

private int GetShardNumberFromKey(string pk)
{
var hash = GetStableHashCode(pk);
var storageNum = Math.Abs(hash % this._options.ConnectionStrings.Count());

return storageNum;
}

/// <summary>
/// Take from https://stackoverflow.com/a/36845864/852806
/// </summary>
/// <param name="str"></param>
/// <returns></returns>
public int GetStableHashCode(string str)
{
unchecked
{
int hash1 = 5381;
int hash2 = hash1;

for (int i = 0; i < str.Length && str[i] != '\0'; i += 2)
{
hash1 = ((hash1 << 5) + hash1) ^ str[i];
if (i == str.Length - 1 || str[i + 1] == '\0')
break;
hash2 = ((hash2 << 5) + hash2) ^ str[i + 1];
}

return hash1 + (hash2 * 1566083941);
}
}


private const string KeyStringSeparator = "__";

private string GetKeyString(GrainId grainId)
{
var key = $"{this._serviceId}{KeyStringSeparator}{grainId.ToString()}";

return SanitizeTableProperty(key);
}

public string SanitizeTableProperty(string key)
{
// Remove any characters that can't be used in Azure PartitionKey or RowKey values
// http://www.jamestharpe.com/web-development/azure-table-service-character-combinations-disallowed-in-partitionkey-rowkey/
key = key
.Replace('/', '_') // Forward slash
.Replace('\\', '_') // Backslash
.Replace('#', '_') // Pound sign
.Replace('?', '_'); // Question mark

if (key.Length >= 1024)
throw new ArgumentException(string.Format("Key length {0} is too long to be an Azure table key. Key={1}", key.Length, key));

return key;
}




#endregion
}

public static class AzureShardedGrainStorageFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

<ItemGroup>
<PackageReference Include="Azure.Data.Tables" Version="12.6.0" />
<PackageReference Include="Azure.Identity" Version="1.9.0" />
<PackageReference Include="Azure.Storage.Blobs" Version="12.10.0" />
<PackageReference Include="Azure.Storage.Files.Shares" Version="12.10.0" />
<PackageReference Include="Microsoft.Orleans.Sdk" Version="7.0.0" />
Expand Down

0 comments on commit f6f0e7a

Please sign in to comment.