Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataWorker] Add prototypes for data workers. #95

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions tensorflow/core/distributed_runtime/data_worker_controller.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_CONTROLLER_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_CONTROLLER_H_

#include <unordered_map>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/data_worker_graph_partition.h"

namespace tensorflow {

// Maintains and dispatches subgraphs to data workers.
//
// Each *training* worker task contributes one partitioned data-processing
// graph; data workers register with the controller and receive a copy of a
// recorded graph (NOTE(review): presumably the least-loaded one, via
// GetGraphForNewDataWorker — confirm in the .cc). All shared state is
// guarded by mu_.
class DataWorkerController {
 private:
  // A data-processing graph partitioned from each *training*
  // worker task to be dispatched to data workers.
  struct TaskDataWorkerGraph {
    // Name of the training worker task this graph was partitioned from.
    string task_name;
    // (data worker name, data worker host:port) pairs that have been
    // assigned to serve this graph so far.
    std::vector<std::pair<string /* dw name */, string /* dw host_port */>>
        registered_data_workers;
    // The partitioned data-processing graph.
    std::shared_ptr<GraphDef> g;
    // Names of the DataWorkerSend ops that should be run
    // on the data worker clients.
    std::vector<string> node_names;
    // Names of the tensors to be sent from data workers.
    std::vector<string> tensor_names;

    // Number of data workers this graph has been dispatched to.
    // static_cast makes the size_t -> int narrowing explicit.
    int num_registered() const {
      return static_cast<int>(registered_data_workers.size());
    }
    // Records that data worker `name`, reachable at `host_port`, now
    // serves this graph.
    void RegisterDataWorker(const string& name, const string& host_port) {
      registered_data_workers.emplace_back(name, host_port);
    }

    TaskDataWorkerGraph(const string& task_name,
                        std::shared_ptr<GraphDef> g,
                        const std::vector<string>& node_names,
                        const std::vector<string>& tensor_names)
        // Move the shared_ptr taken by value: avoids an extra refcount bump.
        : task_name(task_name),
          g(std::move(g)),
          node_names(node_names),
          tensor_names(tensor_names) {}
    ~TaskDataWorkerGraph() {}
  };

  // Serializes access to graphs_ and next_node_id_.
  mutex mu_;
  // One entry per training worker task that has been partitioned.
  std::vector<TaskDataWorkerGraph> graphs_ GUARDED_BY(mu_);
  // Partitioning options forwarded to the data-worker graph partitioner.
  // NOTE(review): exact semantics are defined by
  // data_worker_graph_partition.h — confirm there.
  bool use_default_split_points_ = true;
  bool extend_default_split_ = false;
  bool fuse_recv_ = false;
  // Used for sequencing DataWorkerSend/Recv nodes.
  int64 next_node_id_ GUARDED_BY(mu_) = 0;

  // Returns the graph that has been allocated to the least number of data
  // workers.
  TaskDataWorkerGraph& GetGraphForNewDataWorker();
  // Resets the device names to the target data worker.
  void ResetDeviceNamesForGraph(GraphDef* const g, const string& dw_name);
  void ResetDeviceNameForNode(NodeDef* node, const string& dw_name);

 public:
  DataWorkerController() {}
  DataWorkerController(bool use_default_split_points,
                       bool extend_default_split, bool fuse_recv);
  ~DataWorkerController() {}

  // Partitions a data-processing subgraph out of `g` according to `popts`
  // and records it for later dispatch to data workers.
  Status Partition(Graph* g, PartitionForDataWorkerOptions& popts);

  // Registers a new data worker and fills the out-params:
  //   dst_graph            - the dispatched graph, with device names
  //                          rewritten for data worker `name`;
  //   training_worker_name - name of the training task owning the graph;
  //   node_names           - DataWorkerSend ops to run on the dw client.
  Status RegisterDataWorker(GraphDef* dst_graph,
                            const string& name,
                            const string& host_port,
                            string& training_worker_name,
                            std::vector<string>& node_names);

  // Returns the names of the tensors to be sent from data workers for
  // `task_name`. Pointer is owned by the controller.
  const std::vector<string>* GetTensorNames(const string& task_name);
};

} // namespace tensorflow


#endif
160 changes: 160 additions & 0 deletions tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_H_

#include <string>
#include <vector>
#include <unordered_map>

#include "tensorflow/core/distributed_runtime/data_worker_rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/data_worker_rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"

namespace tensorflow {
class GenericDataWorkerRendezvous;
class DataWorkerRecvTensorThread;

class DataWorkerRendezvousMgr: public DataWorkerRendezvousMgrInterface{
public:
struct DataWorkerRendezvousMgrOptions{
int queue_size = 100;
int num_recv_threads = 1;
int num_send_threads = 4;
string protocol = "grpc";
bool fuse_recv = false;
};

explicit DataWorkerRendezvousMgr(const DataWorkerRendezvousMgrOptions& options);
~DataWorkerRendezvousMgr();

void RecvLocalAsync(const DataWorkerRendezvous::ParsedKey& key,
DataWorkerRendezvous::DoneCallback done) override;

void FuseRecvLocalAsync(const std::vector<DataWorkerRendezvous::ParsedKey>& keys,
DataWorkerRendezvous::FuseDoneCallback done) override;

void RegisterDataWorker(const string& task_name, const string& host_port) override;

void SetTensorNames(const std::vector<string>& tensor_names) override;

DataWorkerRendezvous* Find();

private:
mutex mu_;
const int queue_size_;
const int num_recv_threads_;
const int num_send_threads_;
const string protocol_;
const bool fuse_recv_;
GenericDataWorkerRendezvous* rdwr_ GUARDED_BY(mu_);
GenericDataWorkerRendezvous* FindOrCreate();
};

// GenericDataWorkerRendezvous supports both grpc and grpc++ as the underlying
// communication protocol. It also supports transferring data from the local
// data worker directly.
//
// Produced tensors are buffered as Items in table_ until consumers fetch
// them (see the comment on table_ below).
class GenericDataWorkerRendezvous: public DataWorkerRendezvous {
public:
GenericDataWorkerRendezvous(const int& queue_size,
const int& num_recv_threads,
const int& num_send_threads,
const string& protocol,
const bool& fuse_recv);
~GenericDataWorkerRendezvous();

// Associates this rendezvous with `session`.
// NOTE(review): `session` is stored as a raw pointer (session_), so it
// presumably must outlive this rendezvous — confirm against the caller.
Status Initialize(WorkerSession* session) override;
// Aborts the rendezvous with `status`.
void StartAbort(const Status& status) override;
// Records the names of the tensors to be sent from data workers.
void SetTensorNames(const std::vector<string>& tensor_names) override;
// Stores the allocator attributes and target device for the recv node
// identified by `key` (kept in recv_nodes_attrs_, keyed by KeyHash).
Status SetRecvAttrs(const DataWorkerRendezvous::ParsedKey& key,
const AllocatorAttributes& alloc_attrs,
const string& device) override;
// Asynchronously sends `val` under `key`; `done` is invoked when the
// send completes.
void DataWorkerSendAsync(const DataWorkerRendezvous::ParsedKey& key,
const Tensor& val,
const DataWorkerRendezvous::Args& send_args,
DataWorkerRendezvous::DoneCallback done) override;
// Send path for the local data worker (the "direct transfer" case from
// the class comment); items land in local_tmp_ — see below.
Status LocalDataWorkerSend(const DataWorkerRendezvous::ParsedKey& key,
const string& tensor_name,
const Tensor& val,
const DataWorkerRendezvous::Args& send_args) override;
// Receives the locally-produced tensor for `key` and invokes `done`.
void RecvLocalAsync(const DataWorkerRendezvous::ParsedKey& key, DataWorkerRendezvous::DoneCallback done) override;
// Fused variant of RecvLocalAsync: one callback for all of `keys`.
void FuseRecvLocalAsync(const std::vector<DataWorkerRendezvous::ParsedKey>& keys,
DataWorkerRendezvous::FuseDoneCallback done) override;
// Asynchronously receives the tensor for `key` from a data worker.
void DataWorkerRecvAsync(const DataWorkerRendezvous::ParsedKey& key,
const DataWorkerRendezvous::Args& recv_args,
DataWorkerRendezvous::DoneCallback done) override;
// Fused variant: receives the configured set of tensors in one call.
void DataWorkerFuseRecvAsync(const DataWorkerRendezvous::Args& recv_args,
DataWorkerRendezvous::FuseDoneCallback done) override;
// Registers data worker `task_name` at `host_port` with this rendezvous.
void RegisterDataWorker(const string& task_name, const string& host_port);

private:
// Shared implementation behind the public recv entry points.
void RecvAsync(const DataWorkerRendezvous::ParsedKey& key,
const DataWorkerRendezvous::Args& recv_args,
DataWorkerRendezvous::DoneCallback done);
// Enqueues produced items for consumers (non-fused / fused paths).
void EnqueueRecvItems(std::vector<Item*>& items);
void EnqueueFuseRecvItem(FuseItem* item);
// Completion path when sender and receiver live in the same worker;
// presumably copies `in` to `out` honoring recv allocator attributes —
// TODO(review): confirm in the .cc.
void SameWorkerRecvDone(const DataWorkerRendezvous::ParsedKey& parsed,
const DataWorkerRendezvous::Args& send_args,
const DataWorkerRendezvous::Args& recv_args,
const Tensor& in, Tensor* out, StatusCallback done);

// Hashes a full rendezvous key; used as the index into Table.
static uint64 KeyHash(const StringPiece& k) {
return Hash64(k.data(), k.size());
}

const string protocol_;
const bool fuse_recv_;
// Guards recv_nodes_attrs_.
mutex attrs_mu_;
// Keyed by KeyHash(key): attributes recorded via SetRecvAttrs.
std::unordered_map<uint64, std::pair<AllocatorAttributes, Device*>> recv_nodes_attrs_ GUARDED_BY(attrs_mu_);
const int num_recv_threads_;
std::vector<std::unique_ptr<DataWorkerRecvTensorThread>> recv_threads_;
std::unique_ptr<thread::ThreadPool> send_threads_;

typedef std::deque<Item*> ItemQueue;
typedef std::deque<FuseItem*> FuseItemQueue;
typedef gtl::FlatMap<uint64, ItemQueue> Table;

// NOTE(review): std::mutex here, while attrs_mu_ above uses the
// annotated tensorflow::mutex; GUARDED_BY static analysis only
// understands the annotated type — consider unifying.
std::mutex mu_;
std::mutex local_tmp_mu_;
std::condition_variable cv_;
Status status_ GUARDED_BY(mu_);
WorkerSession* session_ GUARDED_BY(mu_);

// Table is used for both data workers and training workers for storing the items to enable async execution:
// Data workers put the produced tensors in the Table and wait for the training workers
// to fetch them. Training workers put the fetched tensors in their local Table.
Table table_ GUARDED_BY(mu_);
FuseItemQueue fuse_queue_ GUARDED_BY(mu_);
// Staging area for LocalDataWorkerSend items.
std::vector<Item*> local_tmp_ GUARDED_BY(local_tmp_mu_);
std::vector<string> tensor_names_;
const int queue_size_;

friend class DataWorkerRecvTensorThread;
friend class GrpcDataWorkerRecvTensorThread;
friend class StarDataWorkerRecvTensorThread;
TF_DISALLOW_COPY_AND_ASSIGN(GenericDataWorkerRendezvous);
};


} // end namespace tensorflow

#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_INTERFACE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_INTERFACE_H_

#include <string>
#include <vector>

#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/data_worker_rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Interface for managing the rendezvous through which data workers
// exchange tensors with training workers. Implemented by
// DataWorkerRendezvousMgr.
class DataWorkerRendezvousMgrInterface {
public:
DataWorkerRendezvousMgrInterface() {}
virtual ~DataWorkerRendezvousMgrInterface() {}

// Returns the DataWorkerRendezvous managed by this interface.
virtual DataWorkerRendezvous* Find() = 0;

// Receives the locally-produced tensor for `key` and invokes `done`.
virtual void RecvLocalAsync(const DataWorkerRendezvous::ParsedKey& key,
DataWorkerRendezvous::DoneCallback done) = 0;

// Fused variant of RecvLocalAsync: one callback for all of `keys`.
virtual void FuseRecvLocalAsync(const std::vector<DataWorkerRendezvous::ParsedKey>& keys,
DataWorkerRendezvous::FuseDoneCallback done) = 0;

// Registers data worker `task_name` reachable at `host_port`.
virtual void RegisterDataWorker(const string& task_name, const string& host_port) = 0;

// Records the names of the tensors to be sent from data workers.
virtual void SetTensorNames(const std::vector<string>& tensor_names) = 0;
};

} // end namespace tensorflow

#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DATA_WORKER_RENDEZVOUS_MGR_INTERFACE_H_
Loading