This is a note on my learning journey with Gradio, a Python library that allows you to quickly create UIs for your machine learning models.

Resources:

Quick start

import gradio as gr

def greet(name, intensity):
    return "Hello, " + name + "!" * int(intensity)

demo = gr.Interface(
    fn=greet,
    inputs=["text", "slider"],
    outputs=["text"],
)

demo.launch()

Understanding the Interface Class

The Interface class has three core arguments:

  • fn: the function to wrap a user interface (UI) around
  • inputs: the Gradio component(s) to use for the input. The number of components should match the number of arguments in your function.
  • outputs: the Gradio component(s) to use for the output. The number of components should match the number of return values from your function.
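To see how the counts line up, here is a small sketch. It assumes a hypothetical function greet_and_count that returns two values, so the Interface needs two output components. Component instances such as gr.Textbox and gr.Slider can replace the string shortcuts when you want labels or ranges. The UI is built inside build_demo (a helper name chosen for this sketch) so the function itself can be tested without starting a server.

```python
def greet_and_count(name: str, intensity: float) -> tuple[str, int]:
    """Two inputs (name, intensity) -> two outputs (greeting, its length)."""
    greeting = "Hello, " + name + "!" * int(intensity)
    return greeting, len(greeting)


def build_demo():
    # Imported here so greet_and_count can be used without Gradio installed
    import gradio as gr

    return gr.Interface(
        fn=greet_and_count,
        # One input component per function argument
        inputs=[
            gr.Textbox(label="name"),
            gr.Slider(1, 10, value=3, step=1, label="intensity"),
        ],
        # One output component per returned value
        outputs=[gr.Textbox(label="greeting"), gr.Number(label="length")],
    )
```

Calling build_demo().launch() starts the app. On its own, greet_and_count("Ann", 2) returns ("Hello, Ann!!", 12).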

The fn argument is very flexible – you can pass any Python function that you want to wrap with a UI. In the example above, we saw a relatively simple function, but the function could be anything from a music generator to a tax calculator to the prediction function of a pretrained machine learning model.

Core Gradio Classes

Gradio has three core classes:

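  • gr.Interface: a high-level class that wraps a single function with input and output components
  • gr.ChatInterface: a high-level class for building chatbot UIs
  • gr.Blocks: a lower-level API for building custom layouts and event flows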
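For comparison with gr.Interface, a minimal gr.ChatInterface sketch might look like this. The chat function receives the latest user message and the conversation history and returns the bot's reply. echo is a hypothetical stand-in for a real model, and build_demo is a helper name chosen for this sketch.

```python
def echo(message: str, history: list) -> str:
    # history holds the earlier turns; this toy bot simply ignores it
    return f"You said: {message}"


def build_demo():
    # Imported here so echo can be tested without Gradio installed
    import gradio as gr

    # ChatInterface supplies the chat window, textbox and submit button
    return gr.ChatInterface(fn=echo)
```

Run it with build_demo().launch().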

Streaming outputs

In some cases, you may want to stream a sequence of outputs rather than show a single output at once. For example, you might have an image generation model and you want to show the image that is generated at each step, leading up to the final image. Or you might have a chatbot which streams its response one token at a time instead of returning it all at once.

In such cases, you can pass a generator function to Gradio instead of a regular function. Each value the generator yields is sent to the output component as soon as it is produced.

def my_generator(x):
    for i in range(x):
        yield i

The demo below simulates an image-generation model that shows a new image at each step:

import gradio as gr
import numpy as np
import time

def fake_diffusion(steps):
    rng = np.random.default_rng()
    for i in range(steps):
        time.sleep(1)
        image = rng.random(size=(600, 600, 3))
        yield image
    image = np.ones((1000,1000,3), np.uint8)
    image[:] = [255, 124, 0]
    yield image


demo = gr.Interface(fake_diffusion,
                    inputs=gr.Slider(1, 10, 3, step=1),
                    outputs="image")

demo.launch()

In the above demo, fake_diffusion is a generator function: it yields a random-noise image at each step, with the input steps determining how many steps to run, and finally yields a solid orange image. The outputs argument is set to "image", the string shortcut for a gr.Image component. The inputs argument is a gr.Slider that takes an integer from 1 to 10 (default 3, step 1).
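The same pattern works for text, as in the chatbot case mentioned above. One detail worth noting: each yielded value replaces what the output component shows, so a text stream should yield the accumulated string rather than single characters. A minimal sketch, where stream_reply is a hypothetical function that streams a canned reply one character at a time:

```python
import time
from typing import Iterator


def stream_reply(prompt: str, delay: float = 0.05) -> Iterator[str]:
    reply = f"You typed: {prompt}"
    partial = ""
    for ch in reply:
        partial += ch
        time.sleep(delay)  # stand-in for the time a model needs per token
        # Yield the text so far; the output Textbox shows the latest yielded value
        yield partial


def build_demo():
    # Imported here so stream_reply can be tested without Gradio installed
    import gradio as gr

    # One input component (prompt); delay keeps its default value
    return gr.Interface(stream_reply, inputs="text", outputs="text")
```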

Getting the current value of a component

One common use case is to access the current value of a component, or data accumulated across earlier calls, inside a function. Here is my workaround, which keeps the data in a gr.State:

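import os

import numpy as np


def display_file(files, datadict):
    # gr.State starts out as None, so fall back to an empty dict on the first call
    datadict = datadict or {}

    if not files:
        return 'No file uploaded', datadict

    # With file_count='multiple' Gradio passes a list; wrap a single file in a list too
    if not isinstance(files, list):
        files = [files]

    content = ""
    for f in files:
        # Depending on the Gradio version, each item is a path string or a file object with .name
        file_path = f if isinstance(f, str) else f.name
        file_name = os.path.basename(file_path)

        # Read the content of the uploaded file
        with open(file_path, 'r') as fh:
            file_content = fh.read()

        # Add a header with the file path and name, then the content
        content += '\n-------------------\n'
        content += 'file_path: ' + file_path + "\n"
        content += 'file_name: ' + file_name + "\n"
        content += file_content

        # Store the content keyed by file name
        datadict[file_name] = file_content

    print(datadict)
    # np.save pickles the dict; reload it with np.load('datadict.npy', allow_pickle=True).item()
    np.save('datadict.npy', datadict)

    return content, datadict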

def test_upload_and_read():
    import gradio as gr

    # Create a Gradio interface
    with gr.Blocks() as demo:
        # Add a file/files upload button
        file_input = gr.File(label="Upload your file", file_count='multiple')
        
        # Add a textbox to display the content of the uploaded file
        file_content = gr.Textbox(label="File Content", lines=20)
        
        # Add a button to trigger the display of the file content
        upload_button = gr.Button("Display File Content")
        
        state = gr.State()

        # Define the action for the button
        upload_button.click(display_file, inputs=[file_input, state], 
                            outputs=[file_content, state])

    # Launch the Gradio interface
    demo.launch()

In the above code, display_file takes the uploaded file(s) and the dictionary datadict, and returns the files' content together with the updated dictionary. Because state appears in both the inputs and the outputs of the click event, Gradio passes the stored dictionary into display_file and saves the returned dictionary back into state, so it persists across clicks within the same browser session.
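The same idea in its smallest form, as a sketch: a hypothetical add_item function receives the current value of a Textbox plus the list stored in a gr.State, and returns the updated list both for display and for the State to keep. build_demo is a helper name chosen for this sketch.

```python
from typing import List, Optional, Tuple


def add_item(item: str, items: Optional[List[str]]) -> Tuple[str, List[str]]:
    # gr.State holds None if it was created without an initial value
    items = (items or []) + [item]  # build a new list instead of mutating in place
    return ", ".join(items), items


def build_demo():
    # Imported here so add_item can be tested without Gradio installed
    import gradio as gr

    with gr.Blocks() as demo:
        item = gr.Textbox(label="Item")
        shown = gr.Textbox(label="All items")
        items = gr.State([])  # per-session storage, starts empty
        add = gr.Button("Add")
        # The Textbox's current value and the stored list go in; the new list comes back out
        add.click(add_item, inputs=[item, items], outputs=[shown, items])
    return demo
```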

Batch functions

Gradio supports batch functions. A batch function receives, for each input component, a list of values (one per request in the batch) and returns, for each output component, a list of results in the same order.

For example, here is a batched function that takes in two lists of inputs (a list of words and a list of ints), and returns a list of trimmed words as output:

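import time

import gradio as gr

def trim_words(words, lens):
    # In batch mode each argument is a list with one entry per queued request
    trimmed_words = []
    time.sleep(5)  # simulate a slow model so that requests pile up in the queue
    for w, l in zip(words, lens):
        trimmed_words.append(w[:int(l)])
    # Return one list per output component
    return [trimmed_words]

demo = gr.Interface(
    fn=trim_words,
    inputs=["textbox", "number"],
    outputs=["textbox"],
    batch=True,
    max_batch_size=16
)

demo.launch()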

With the gr.Blocks class

import gradio as gr

with gr.Blocks() as demo:
    with gr.Row():
        word = gr.Textbox(label="word")
        leng = gr.Number(label="leng")
        output = gr.Textbox(label="Output")
    with gr.Row():
        run = gr.Button()

    event = run.click(trim_words, [word, leng], output, batch=True, max_batch_size=16)

demo.launch()

The advantage of batched functions is that, if you enable queuing, the Gradio server can automatically group incoming requests into batches of up to max_batch_size and process each batch in a single function call, potentially speeding up your demo. Both examples above enable this with batch=True and max_batch_size=16.

In the above example with the gr.Blocks class, the click method binds trim_words to the run button. trim_words receives the values of the word and leng components and returns the trimmed word to the output Textbox; with batch=True, those values arrive as lists covering up to 16 queued requests.
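To make the list-in, list-out contract concrete, here is a simplified simulation of what happens to one batch. It is not Gradio's internal code: it transposes three hypothetical queued requests into one list per input component, calls the function once, and hands result i back to request i. trim_words is redefined without the sleep so the snippet is self-contained.

```python
def trim_words(words, lens):
    # Same logic as the example above, minus the artificial delay
    return [[w[:int(l)] for w, l in zip(words, lens)]]


# Three hypothetical queued requests, each a (word, length) pair
requests = [("hello", 3), ("gradio", 2), ("batch", 10)]

# Transpose: one list per input component
words, lens = (list(column) for column in zip(*requests))

# One call handles the whole batch; unpack the single output component's list
(trimmed,) = trim_words(words, lens)

# Result i belongs to request i
for (word, length), result in zip(requests, trimmed):
    print(f"{word!r}, {length} -> {result!r}")
```

This prints 'hello', 3 -> 'hel', then 'gradio', 2 -> 'gr', then 'batch', 10 -> 'batch'.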