Creating a high-level framework built on top of PyTorch
Hi everyone,
I am working on `pytorch-dlvn`, a high-level framework built on top of PyTorch.
The main focus of this framework is a `Runner` class that takes a model, an optimizer, a criterion, and a dataset, and trains the model in a few lines of code.
Here is an example:
```python
model = models.LeNet5()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()
dataloaders = get_dataloaders_from_torchvision(Datasets.MNIST,
                                               root='../data')
runner = Runner(model=model,
                optimizer=optimizer,
                criterion=criterion,
                dataloaders=dataloaders)
runner.fit(max_epochs=5)
```
A runner holds a `State` dictionary (`runner.state`) that stores everything important, such as the optimizer, the loss, the number of epochs, and the inputs and outputs of each batch.
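To make this concrete, a `State` like this could be a plain `dict` with attribute access. This is just an illustrative sketch, not the actual `pytorch-dlvn` implementation:

```python
class State(dict):
    """Dict-like container for everything the Runner tracks.

    Keys are also readable/writable as attributes, so a callback
    can write state.loss instead of state['loss'].
    """
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name, value):
        self[name] = value

# A callback could read or record training info through it:
state = State(epoch=0, max_epochs=5)
state.loss = 0.42  # stored as state['loss']
```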
The core of this framework is the callback system, which enables reusable code. Each callback receives the `State` as input, so it can monitor and manipulate the training process.
At the center of this callback system are the event Publisher and Subscriber, which are controlled by an `EventDispatcher` implemented with the Singleton pattern.
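For readers unfamiliar with the pattern, here is a minimal sketch of a Singleton publish/subscribe dispatcher. The method names `subscribe` and `publish` are my illustration, not necessarily the framework's real API:

```python
from collections import defaultdict

class EventDispatcher:
    """Singleton publish/subscribe hub: callbacks subscribe to
    events, and the Runner publishes events as training runs."""
    _instance = None

    def __new__(cls):
        # Every EventDispatcher() call returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._subscribers = defaultdict(list)
        return cls._instance

    def subscribe(self, event, handler):
        self._subscribers[event].append(handler)

    def publish(self, event, state):
        # Hand the shared state to every handler of this event.
        for handler in self._subscribers[event]:
            handler(state)
```

Because `__new__` always returns the same object, a callback constructed anywhere in the program talks to the same dispatcher.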
By design, a `Callback` subscribes itself to the `EventDispatcher` automatically in its `__init__` method. Therefore, we can create a `TensorboardCallback` that enables logging to TensorBoard like the following:
```python
tb = TensorboardCallback()
...
runner.fit()
```
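The auto-subscription idea can be sketched as follows: a base class scans itself for `on_*` methods in `__init__` and registers them. The module-level registry here is a stand-in for the real `EventDispatcher`, and all names are illustrative:

```python
from collections import defaultdict

# Stand-in for the dispatcher's internal event -> handlers mapping.
_subscribers = defaultdict(list)

class Callback:
    """Base class sketch: __init__ registers every on_* method,
    so merely constructing a callback subscribes it."""
    def __init__(self):
        for name in dir(self):
            if name.startswith('on_'):
                # on_train_epoch_begin -> TRAIN_EPOCH_BEGIN
                event = name[3:].upper()
                _subscribers[event].append(getattr(self, name))

class TensorboardCallback(Callback):
    def on_train_epoch_begin(self, state):
        print('logging epoch', state.get('epoch'))

tb = TensorboardCallback()  # now subscribed; no extra wiring needed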
Another way to create a callback is to use the `@on` decorator to register an arbitrary function:
```python
@on(Events.TRAIN_EPOCH_BEGIN,
    deps=['TensorboardCallback.on_train_epoch_begin'],
    param=100)
def sample_callback(state: State, param: int):
    print('Sample callback', state, param)
```
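One possible sketch of such an `@on` decorator, assuming it simply records the function, its event, its dependencies, and any bound parameters in a registry. The names below are illustrative, not the framework's actual internals:

```python
import functools

REGISTRY = []  # hypothetical registry the dispatcher would read

def on(event, deps=(), **bound_params):
    """Record a handler together with its event, dependency list
    (by __qualname__), and extra parameters to bind at call time."""
    def decorator(func):
        @functools.wraps(func)
        def handler(state):
            # The dispatcher only passes state; extra params are bound here.
            return func(state, **bound_params)
        REGISTRY.append({'event': event,
                         'deps': list(deps),
                         'name': func.__qualname__,
                         'handler': handler})
        return handler
    return decorator

@on('TRAIN_EPOCH_BEGIN',
    deps=['TensorboardCallback.on_train_epoch_begin'],
    param=100)
def sample_callback(state, param):
    return ('sample', state, param)
```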
As you can see, a callback can depend on other callbacks via the `deps` parameter. Currently, we use the `__qualname__` of a method to specify a dependency.
The callback dependencies are modeled as a directed acyclic graph (DAG) and ordered with a topological sort.
In the example above, the `TensorboardCallback` will be called before `sample_callback`.
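The ordering itself is standard: for example, Kahn's algorithm produces a topological order of the dependency DAG. This is a generic sketch with illustrative names, not the framework's code:

```python
from collections import deque

def topo_sort(deps):
    """deps maps a callback name to the names it depends on.
    Returns an order where every dependency runs before its
    dependents; raises ValueError if the graph has a cycle."""
    # In-degree = number of unmet dependencies per callback.
    indegree = {node: len(parents) for node, parents in deps.items()}
    dependents = {node: [] for node in deps}
    for node, parents in deps.items():
        for parent in parents:
            dependents[parent].append(node)
    queue = deque(n for n, d in indegree.items() if d == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in dependents[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    if len(order) != len(deps):
        raise ValueError('cycle detected in callback dependencies')
    return order

graph = {
    'TensorboardCallback.on_train_epoch_begin': [],
    'sample_callback': ['TensorboardCallback.on_train_epoch_begin'],
}
order = topo_sort(graph)  # dependency comes first in the result
```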
As a result, a callback can save its results into the `State`, and its dependents can read them back from the `State`.
You can inspect the sorted callback order with the `summary` method of the `EventDispatcher`:

```python
dispatcher.summary()
```
Please give me some feedback so I can improve this framework.
Thanks.
GitHub link: https://github.com/quanhua92/pytorch-dlvn