Part II
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Azure.Storage.Blobs;
using Azure.Storage.Blobs.Models;
using Azure.Storage.Blobs.Specialized;
using Microsoft.IO;
namespace Rebex.SimpleAzureBlobProvider.Tools
{
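// Write-only, forward-only stream that buffers incoming data into fixed-size blocks
// (BlobWriteStreamSettings.MinBatchSizeInBytes), stages each full block to Azure Blob Storage
// on a background task, and commits the resulting block list when the stream is disposed.
// Blobs smaller than a single block are uploaded in one call via BlobClient.Upload.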
public class BlobStream : Stream
{
private enum StreamWriteMode
{
Overwrite,
Update
}
private const int FIRST_BLOCK_INDEX = 0;
private readonly BlobClient _blobClient;
private readonly BlockBlobClient _blobBlockClient;
private ChannelReader<StreamBlockIdPair> _channelReader;
private Channel<StreamBlockIdPair> _channelStream;
private ChannelWriter<StreamBlockIdPair> _channelWriter;
private readonly RecyclableMemoryStreamManager _memoryStreamManager;
private CancellationTokenSource _cts;
private int _position;
private Task _uploadToAzureTask;
private readonly BlobWriteStreamSettings _settings;
private long _length;
private MemoryStream _lastWaitingToUploadStream;
private StreamWriteMode _mode;
private List<Action> _waitingActionsBeforeWrite;
public BlobStream(BlobClient blobClient, BlockBlobClient blobBlockClient, RecyclableMemoryStreamManager memoryStreamManager, BlobWriteStreamSettings settings = null)
{
_settings = settings ?? BlobWriteStreamSettings.Default;
_blobClient = blobClient ?? throw new ArgumentNullException(nameof(blobClient));
_blobBlockClient = blobBlockClient ?? throw new ArgumentNullException(nameof(blobBlockClient));
_memoryStreamManager = memoryStreamManager ?? throw new ArgumentNullException(nameof(memoryStreamManager));
CreateChannel();
_position = 0;
_length = 0;
_uploadToAzureTask = null;
_lastWaitingToUploadStream = null;
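// Determine the write mode from the blob's committed block list (the blob is expected to
// exist at this point): Overwrite is used when the first committed block already matches
// this stream's block-id scheme and block size, Update otherwise (including when the blob
// has no committed blocks yet).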
var committedBlockList = _blobBlockClient.GetBlockList(BlockListTypes.Committed).Value.CommittedBlocks;
_mode = !committedBlockList.Any() || (committedBlockList.First().Size != _settings.MinBatchSizeInBytes || committedBlockList.First().Name != createBlockId(FIRST_BLOCK_INDEX))
? StreamWriteMode.Update
: StreamWriteMode.Overwrite;
_waitingActionsBeforeWrite = new List<Action>();
}
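// Creates the single-producer/single-consumer channel that hands buffered blocks from
// Write (producer) to the background upload task (consumer), together with a fresh
// CancellationTokenSource for that task.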
private void CreateChannel()
{
_cts = new CancellationTokenSource();
_channelStream = Channel.CreateUnbounded<StreamBlockIdPair>(new UnboundedChannelOptions
{
SingleReader = true,
SingleWriter = true
});
_channelWriter = _channelStream.Writer;
_channelReader = _channelStream.Reader;
}
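// The stream is write-only and forward-only; reading and seeking are not supported.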
public override bool CanRead => false;
public override bool CanSeek => false;
public override bool CanWrite => true;
public override long Length => _length;
public override long Position
{
get => _position;
set => throw new NotSupportedException();
}
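// Flush is a no-op: data is staged only when a full block has been buffered, and the
// final (partial) block is sent when the stream is disposed.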
public override void Flush()
{
}
public override int Read(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
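// Appends data to the current in-memory block. Whenever the block reaches
// MinBatchSizeInBytes, it is queued to the upload channel (starting the background
// upload task on first use) and a new block is started.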
public override void Write(byte[] buffer, int offset, int count)
{
if (buffer == null)
{
throw new ArgumentNullException(nameof(buffer));
}
if (offset < 0)
{
throw new ArgumentOutOfRangeException(nameof(offset));
}
if (count < 0)
{
throw new ArgumentOutOfRangeException(nameof(count));
}
if (count == 0)
{
return;
}
if (offset + count > buffer.Length)
{
throw new ArgumentException("sum + offset is greater than buffer length.");
}
if (_uploadToAzureTask != null && _uploadToAzureTask.IsCompleted)
{
throw new InvalidOperationException("Unexpected error. Upload task finished prematurely.", _uploadToAzureTask.Exception);
}
var toProcessLength = count;
var currentOffset = offset;
while (toProcessLength > 0)
{
if (_lastWaitingToUploadStream == null)
{
_lastWaitingToUploadStream = _memoryStreamManager.GetStream();
}
var toWriteBytes = (int) Math.Min(_settings.MinBatchSizeInBytes - _lastWaitingToUploadStream.Length, toProcessLength);
_lastWaitingToUploadStream.Write(buffer, currentOffset, toWriteBytes);
toProcessLength -= toWriteBytes;
currentOffset += toWriteBytes;
if (_lastWaitingToUploadStream.Length == _settings.MinBatchSizeInBytes)
{
if (_uploadToAzureTask == null)
{
CreateChannel();
_uploadToAzureTask = Task.Run(() => UploadData(_cts.Token));
}
_lastWaitingToUploadStream.Position = 0;
var blockEndPosition = _position + (count - toProcessLength);
Debug.Assert(blockEndPosition % _settings.MinBatchSizeInBytes == 0);
// Zero-based index of the block that has just been filled.
var blockIndex = blockEndPosition / _settings.MinBatchSizeInBytes - 1;
var streamBlockIdPair = new StreamBlockIdPair(_lastWaitingToUploadStream, createBlockId(blockIndex));
var writeResult = _channelWriter.TryWrite(streamBlockIdPair);
_lastWaitingToUploadStream = null;
Debug.Assert(writeResult);
}
}
_position += count;
if (_position > _length)
{
_length = _position;
}
}
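// Flushes any remaining data on dispose. If no block was ever completed, the whole
// content fits into a single upload via BlobClient.Upload; otherwise the last partial
// block is queued, the channel is completed and the upload task is awaited so that it
// can commit the block list.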
protected override void Dispose(bool disposing)
{
if (disposing)
{
try
{
if (_uploadToAzureTask == null)
{
if (_lastWaitingToUploadStream != null && _lastWaitingToUploadStream.Length > 0)
{
_lastWaitingToUploadStream.Position = 0;
_blobClient.Upload(_lastWaitingToUploadStream, overwrite: true);
_lastWaitingToUploadStream = null;
}
}
else
{
if (_lastWaitingToUploadStream != null && _lastWaitingToUploadStream.Length > 0)
{
_lastWaitingToUploadStream.Position = 0;
// The trailing partial block starts at the last full block boundary below _position.
var lastBlockIndex = _position / _settings.MinBatchSizeInBytes;
var writeResult = _channelWriter.TryWrite(new StreamBlockIdPair(_lastWaitingToUploadStream, createBlockId(lastBlockIndex)));
_lastWaitingToUploadStream = null;
Debug.Assert(writeResult);
}
_channelWriter.TryComplete();
_uploadToAzureTask.Wait();
}
}
catch (Exception e)
{
Console.WriteLine(e);
}
finally
{
_lastWaitingToUploadStream?.Dispose();
}
}
base.Dispose(disposing);
}
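// Background consumer: reads buffered blocks from the channel, stages them with
// StageBlockAsync (at most MaxUploadTasks in flight at a time), and commits the collected
// block list once the channel has been completed. Streams are disposed after their block
// has been staged.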
private async Task UploadData(CancellationToken ct)
{
var blockIds = new List<string>();
var tasksStreamDictionary = new Dictionary<Task, Stream>();
var uploadCts = new CancellationTokenSource();
var combinedCts = CancellationTokenSource.CreateLinkedTokenSource(ct, uploadCts.Token);
var combinedCtsToken = combinedCts.Token;
try
{
while (await _channelReader.WaitToReadAsync(combinedCtsToken).ConfigureAwait(false))
{
while (_channelReader.TryRead(out var streamBlockIdPair))
{
// Block ids are assigned by Write based on the block's position in the blob.
blockIds.Add(streamBlockIdPair.BlockId);
// Propagate the first failure from an already started staging task.
if (tasksStreamDictionary.Keys.Any(task => task.IsFaulted || task.IsCanceled))
{
var firstFailedTask = tasksStreamDictionary.Keys.First(task => task.IsFaulted || task.IsCanceled);
await firstFailedTask.ConfigureAwait(false);
}
// Throttle the number of concurrent StageBlockAsync calls.
if (tasksStreamDictionary.Count >= _settings.MaxUploadTasks)
{
var completedTask = await Task.WhenAny(tasksStreamDictionary.Keys).ConfigureAwait(false);
tasksStreamDictionary[completedTask].Dispose();
tasksStreamDictionary.Remove(completedTask);
await completedTask.ConfigureAwait(false);
}
combinedCtsToken.ThrowIfCancellationRequested();
tasksStreamDictionary.Add(_blobBlockClient.StageBlockAsync(streamBlockIdPair.BlockId, streamBlockIdPair.Stream, cancellationToken: combinedCtsToken), streamBlockIdPair.Stream);
}
}
}
await Task.WhenAll(tasksStreamDictionary.Keys).ConfigureAwait(false);
combinedCtsToken.ThrowIfCancellationRequested();
if (!_channelReader.Completion.IsFaulted && blockIds.Count > 0)
{
await _blobBlockClient.CommitBlockListAsync(blockIds, cancellationToken: combinedCtsToken)
.ConfigureAwait(false);
}
}
finally
{
uploadCts.Cancel();
try
{
await Task.WhenAll(tasksStreamDictionary.Keys).ConfigureAwait(false);
}
catch(Exception exception)
{
Debug.WriteLine(exception);
}
foreach (var stream in tasksStreamDictionary.Values)
{
stream.Dispose();
}
tasksStreamDictionary.Clear();
}
}
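// Azure block ids must be Base64 strings of equal length within a blob; here the id is
// the Base64 encoding of the zero-based block index as a 4-byte integer.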
private static string createBlockId(int blockCounter)
{
var blockId = BitConverter.GetBytes(blockCounter);
var blockIdString = Convert.ToBase64String(blockId, 0, blockId.Length);
return blockIdString;
}
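// Couples a buffered block's data with the block id it should be staged under.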
private readonly struct StreamBlockIdPair
{
public StreamBlockIdPair(Stream stream, string blockId)
{
Stream = stream;
BlockId = blockId;
}
public Stream Stream
{
get;
}
public string BlockId
{
get;
}
}
}
}
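For completeness, a minimal usage sketch (hypothetical names: it assumes an existing BlobContainerClient called container, a shared RecyclableMemoryStreamManager called memoryStreamManager, and a readable sourceStream with the data to upload; default BlobWriteStreamSettings are used):
// container, memoryStreamManager and sourceStream are assumed to exist in the calling code.
var blobClient = container.GetBlobClient("uploads/report.bin");
var blockBlobClient = container.GetBlockBlobClient("uploads/report.bin");
using (var blobStream = new BlobStream(blobClient, blockBlobClient, memoryStreamManager))
{
sourceStream.CopyTo(blobStream);
}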