Creating a high-level framework built on top of PyTorch

Quan Hua
·May 21, 2019

Hi everyone,

I am working on pytorch-dlvn - a high-level framework built on top of PyTorch.

The main focus of this framework is a Runner class that can take a model, an optimizer, a criterion and a dataset and train the model with a few lines of code.

Here is an example:

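import torch.nn as nn
import torch.optim as optim
# models, Datasets, get_dataloaders_from_torchvision and Runner come from
# pytorch-dlvn; see the repository for their import paths.

# Model, optimizer and loss function
model = models.LeNet5()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# MNIST dataloaders built from torchvision
dataloaders = get_dataloaders_from_torchvision(
    Datasets.MNIST,
    root='../data')

# The Runner ties everything together and runs the training loop
runner = Runner(model=model,
                optimizer=optimizer,
                criterion=criterion,
                dataloaders=dataloaders)
runner.fit(max_epochs=5)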

A runner has a State dictionary (runner.state) that stores everything important, such as the optimizer, the loss, the number of epochs, and the inputs and outputs of the current batch.

The core of this framework is the callback system, which enables reusable code. Each callback receives the State as input, so it can monitor and manipulate the training process.

At the center of this callback system is an event publisher/subscriber mechanism.

Publishing and subscribing are controlled by an EventDispatcher, which is implemented using the Singleton pattern.
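To make the idea concrete, here is a minimal sketch of a singleton dispatcher with publish/subscribe. It is not the pytorch-dlvn implementation: the subscribe and publish method names, the string event names, and the plain dict used as the State are all assumptions for illustration.

```python
from collections import defaultdict
from typing import Any, Callable, Dict


class EventDispatcher:
    """Hypothetical sketch: every call to EventDispatcher() returns one shared object."""

    _instance = None

    def __new__(cls):
        # Create the instance only once; later calls reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._subscribers = defaultdict(list)
        return cls._instance

    def subscribe(self, event: str, handler: Callable[[Dict[str, Any]], None]) -> None:
        # Remember the handler under the event name.
        self._subscribers[event].append(handler)

    def publish(self, event: str, state: Dict[str, Any]) -> None:
        # Call every handler registered for this event, in subscription order.
        for handler in self._subscribers[event]:
            handler(state)


seen = []
# Two separate "constructions" still talk to the same dispatcher.
EventDispatcher().subscribe("train_epoch_begin", lambda state: seen.append(state["epoch"]))
EventDispatcher().publish("train_epoch_begin", {"epoch": 0})
```

Because the dispatcher is shared, code that subscribes and code that publishes do not need to pass a dispatcher object around.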

By design, a Callback subscribes itself to the EventDispatcher automatically in its __init__ method. Creating a TensorboardCallback is therefore enough to enable logging to TensorBoard:

tb = TensorboardCallback()
...
runner.fit()
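A rough sketch of how self-subscription in __init__ could work is shown below. The REGISTRY dict stands in for the shared dispatcher, and the rule "every method starting with on_ is a handler" is an assumption made for this example, not necessarily how pytorch-dlvn discovers handlers.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, List

# Stand-in for the shared dispatcher: event name -> list of handlers (hypothetical).
REGISTRY: DefaultDict[str, List[Callable]] = defaultdict(list)


class Callback:
    def __init__(self) -> None:
        # Register every method named on_<event> under <event>,
        # so creating the object is all the user has to do.
        for name in dir(self):
            if name.startswith("on_"):
                REGISTRY[name[3:]].append(getattr(self, name))


class PrintLossCallback(Callback):
    """Hypothetical callback that records a log line at the end of each epoch."""

    def on_train_epoch_end(self, state: Dict[str, Any]) -> None:
        state["log"].append(f"epoch {state['epoch']}: loss={state['loss']:.2f}")


def publish(event: str, state: Dict[str, Any]) -> None:
    # Call every handler registered for the event.
    for handler in REGISTRY[event]:
        handler(state)


PrintLossCallback()  # no explicit subscribe call needed
state = {"epoch": 1, "loss": 0.25, "log": []}
publish("train_epoch_end", state)
```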

Another way to create a callback is to register an arbitrary function with the @on decorator:

@on(Events.TRAIN_EPOCH_BEGIN, deps=['TensorboardCallback.on_train_epoch_begin'], param=100)
def sample_callback(state: State, param: int):
    print('Sample callback', state, param)

As you can see, a callback can depend on other callbacks through the deps parameter. Currently, dependencies are specified by the __qualname__ of the method.
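One way a decorator like @on could record a function, its dependencies and its extra parameters is sketched here. This is not the library's code: the HANDLERS registry and the string event name are assumptions, and only functools.partial and __qualname__ are standard Python.

```python
import functools
from collections import defaultdict

# Hypothetical registry: event name -> list of (qualname, deps, handler).
HANDLERS = defaultdict(list)


def on(event, deps=None, **params):
    """Register the decorated function for `event`, binding extra keyword params."""
    def decorator(func):
        # Bind the extra parameters now, so the dispatcher only passes the state.
        bound = functools.partial(func, **params)
        # __qualname__ is "sample_callback" for a module-level function and
        # "ClassName.method_name" for a method, which is what deps refers to.
        HANDLERS[event].append((func.__qualname__, list(deps or []), bound))
        return func
    return decorator


@on("train_epoch_begin", deps=["TensorboardCallback.on_train_epoch_begin"], param=100)
def sample_callback(state, param):
    state["seen"] = param


name, deps, handler = HANDLERS["train_epoch_begin"][0]
state = {}
handler(state)  # the dispatcher would call this with the runner's State
```

Returning the original function from the decorator keeps sample_callback callable on its own, which is convenient for testing.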

The callback dependencies are modeled as a directed acyclic graph (DAG) and ordered with a topological sort.

In the example above, TensorboardCallback.on_train_epoch_begin is called before sample_callback.

As a result, a callback can save its result in the State, and the callbacks that depend on it can read that result from the State.
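The ordering described above can be illustrated with the standard library's graphlib module (Python 3.9+). This is only a sketch of the technique, not the framework's own sorting code; the two callback bodies and the tb_step key are made up for the example.

```python
from graphlib import CycleError, TopologicalSorter


def tensorboard_on_train_epoch_begin(state):
    # Runs first and leaves a value in the state for its dependents.
    state["tb_step"] = state["epoch"] * 100


def sample_callback(state):
    # Runs second and reads what the first callback stored.
    state["message"] = f"logging at step {state['tb_step']}"


callbacks = {
    "TensorboardCallback.on_train_epoch_begin": tensorboard_on_train_epoch_begin,
    "sample_callback": sample_callback,
}

# Each callback name maps to the names it depends on.
deps = {
    "TensorboardCallback.on_train_epoch_begin": [],
    "sample_callback": ["TensorboardCallback.on_train_epoch_begin"],
}

# static_order() yields each node only after all of its dependencies.
order = list(TopologicalSorter(deps).static_order())

state = {"epoch": 2}
for name in order:
    callbacks[name](state)

# A circular dependency is not a DAG, so sorting it raises CycleError.
try:
    list(TopologicalSorter({"a": ["b"], "b": ["a"]}).static_order())
    cycle_detected = False
except CycleError:
    cycle_detected = True
```

Rejecting cycles up front is a useful side effect of this approach: two callbacks that depend on each other fail at sort time rather than running in an arbitrary order.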

You can inspect the sorted callback order with the summary method of the EventDispatcher:

dispatcher.summary()

Please give me some feedback so I can improve this framework.

Thanks.

GitHub link: https://github.com/quanhua92/pytorch-dlvn