How to generate a git commit message using AI?

27 August 2023 ・ Reading time: 4 minutes

Why not use an existing solution? They all use ChatGPT, but I’m out of credits ;) And of course, I want to learn something!

How to generate a git commit message?

Git allows you to create hooks. Let’s use a global one: global hooks work without modifying every git repo.

Create a directory for hooks:

$ mkdir -p ~/.config/git/hooks/

Let git know where the hooks are:

$ git config --global core.hooksPath ~/.config/git/hooks/

Long story short, prepare-commit-msg is the hook we need. The path of the commit message file is passed as the first parameter. Create a simple script at ~/.config/git/hooks/prepare-commit-msg:

#!/bin/sh

echo "Fancy commit message" > $1

Make it executable:

$ chmod +x ~/.config/git/hooks/prepare-commit-msg

Does it work? Let’s commit something … Yep, there’s our message in the commit message.

Now, let’s generate something real:

Generating commit message

Let’s build something that works offline. AI? Yes, let’s use AI!

We need a model, right?

Let’s look at Hugging Face!

There it is: https://huggingface.co/mamiksik/T5-commit-message-generation, but there’s no documentation :( If you look deeper, though, you’ll find https://huggingface.co/spaces/mamiksik/commit-message-generator

We can use this https://huggingface.co/spaces/mamiksik/commit-message-generator/blob/main/app.py with a few modifications.

Since a hook can be any executable script, let’s use Python.

Let’s take a look at what’s in there:

import re

import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer


tokenizer = RobertaTokenizer.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
model = T5ForConditionalGeneration.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
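
# parse_files turns a unified diff into the token format the model expects:
# <add>/<del> mark added and removed lines, <ide> marks unchanged context,
# and <path> tags the file name.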

def parse_files(patch):
    accumulator = []
    lines = patch.splitlines()

    filename_before = None
    for line in lines:
        if line.startswith("index") or line.startswith("diff"):
            continue
        if line.startswith("---"):
            filename_before = line.split(" ", 1)[1][1:]
            continue

        if line.startswith("+++"):
            filename_after = line.split(" ", 1)[1][1:]

            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue

        line = re.sub("@@[^@@]*@@", "", line)
        if len(line) == 0:
            continue

        if line[0] == "+":
            line = line.replace("+", "<add>", 1)
        elif line[0] == "-":
            line = line.replace("-", "<del>", 1)
        else:
            line = f"<ide>{line}"

        accumulator.append(line)

    return '\n'.join(accumulator)


def predict(patch, max_length, min_length, num_beams, prediction_count):
    input_text = parse_files(patch)
    with torch.no_grad():
        token_count = tokenizer(input_text, return_tensors="pt").input_ids.shape[1]

        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids

        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
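    # gr.Label expects {label: confidence}; the model gives no scores, so every prediction gets 0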
    return token_count, input_text, {k: 0 for k in result}


iface = gr.Interface(fn=predict, inputs=[
    gr.Textbox(label="Patch (as generated by git diff)"),
    gr.Slider(1, 128, value=40, label="Max message length"),
    gr.Slider(1, 128, value=5, label="Min message length"),
    gr.Slider(1, 10, value=7, label="Number of beams"),
    gr.Slider(1, 15, value=5, label="Number of predictions"),
], outputs=[
    gr.Textbox(label="Token count"),
    gr.Textbox(label="Parsed patch"),
    gr.Label(label="Predictions")
], examples=[
["""
diff --git a/.github/workflows/pylint.yml b/.github/workflows/codestyle_checks.yml
similarity index 86%
rename from .github/workflows/pylint.yml
rename to .github/workflows/codestyle_checks.yml
index a5d5c4d9..8cbf9713 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/codestyle_checks.yml
@@ -20,3 +20,6 @@ jobs:
     - name: Analysing the code with pylint
       run: |
         pylint --rcfile=.pylintrc webapp core
+    - name: Analysing the code with flake8
+      run: |
+        flake8
""", 40, 5, 7, 5]
]
)

if __name__ == "__main__":
    iface.launch()
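
Before wiring this up, it helps to see what parse_files produces. Tracing the code above by hand on a tiny made-up diff:

demo_patch = """\
diff --git a/foo.py b/foo.py
index 1111111..2222222 100644
--- a/foo.py
+++ b/foo.py
@@ -1,2 +1,2 @@
 import os
-print("hi")
+print("hello")
"""

print(parse_files(demo_patch))
# <ide><path>/foo.py   (note: [1:] strips only the leading "a", keeping the slash)
# <ide> import os
# <del>print("hi")
# <add>print("hello")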

Everything we need is here! We need to:

  • get the commit message file to update
  • fetch the staged git diff
  • use the script above to make predictions
  • prepend the generated message to the commit message file

The file we need to update is passed as the first parameter, so:

import sys

sys.argv[1]

Heh, that was easy.
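
By the way, prepare-commit-msg receives more than one argument: argv[2], when present, names the source of the message, and it’s "message" when the user already supplied one with -m or -F. If you’d rather not overwrite a hand-written message, a guard like this works (my own addition, not used in the final script):

import sys

if len(sys.argv) > 2 and sys.argv[2] == "message":
    sys.exit(0)  # the user already wrote a message; don't touch it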

Fetch the git diff

import subprocess

subprocess.run(['git', 'diff', '--cached'], capture_output=True).stdout.decode('utf-8')

Easy peasy!
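
If nothing is staged (or git fails for some reason), the diff is empty and the model has nothing to work with. A small guard, as a sketch (staged_diff is a hypothetical helper, not part of the final script):

import subprocess
import sys

def staged_diff() -> str:
    # Return the staged diff, or exit quietly when there's nothing to do.
    result = subprocess.run(['git', 'diff', '--cached'], capture_output=True)
    diff = result.stdout.decode('utf-8')
    if result.returncode != 0 or not diff.strip():
        sys.exit(0)  # leave the commit message alone
    return diff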

Use the current script to make predictions

One tweak first: predict in app.py returns a tuple for the Gradio UI, so our version will return just the first decoded message (result[0]); you’ll see that in the final script below.

max_message = 40
min_message = 5
num_beams = 10
num_predictions = 1

msg = predict(diff, max_message, min_message, num_beams, num_predictions)

Prepend our message to the commit message file

with open(sys.argv[1], 'r+') as f:
    content = f.read()
    f.seek(0)
    f.write(msg + '\n' + content)

It’s just like that. With a few small cleanups, this is our final script.

#!/usr/bin/env python
print("Generating commit message", end="", flush=True)

import sys
import re
import subprocess
import torch
from transformers import T5ForConditionalGeneration, RobertaTokenizer

def parse_files(patch):
    accumulator = []
    lines = patch.splitlines()

    filename_before = None
    for line in lines:
        print(".", end="", flush=True)
        if line.startswith("index") or line.startswith("diff"):
            continue
        if line.startswith("---"):
            filename_before = line.split(" ", 1)[1][1:]
            continue

        if line.startswith("+++"):
            filename_after = line.split(" ", 1)[1][1:]

            if filename_before == filename_after:
                accumulator.append(f"<ide><path>{filename_before}")
            else:
                accumulator.append(f"<add><path>{filename_after}")
                accumulator.append(f"<del><path>{filename_before}")
            continue

        line = re.sub("@@[^@@]*@@", "", line)
        if len(line) == 0:
            continue

        if line[0] == "+":
            line = line.replace("+", "<add>", 1)
        elif line[0] == "-":
            line = line.replace("-", "<del>", 1)
        else:
            line = f"<ide>{line}"

        accumulator.append(line)

    return '\n'.join(accumulator)

def predict(patch, max_length, min_length, num_beams, prediction_count):
    print(".", end="", flush=True)
    input_text = parse_files(patch)
    
    tokenizer = RobertaTokenizer.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
    print(".", end="", flush=True)
    model = T5ForConditionalGeneration.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01", low_cpu_mem_usage=True)
    print(".", end="", flush=True)

    with torch.no_grad():
        input_ids = tokenizer(
            input_text,
            truncation=True,
            padding=True,
            return_tensors="pt",
        ).input_ids
        print(".", end="", flush=True)
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            num_return_sequences=prediction_count,
        )
        print(".", end="", flush=True)

    result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return result[0]

if __name__ == "__main__":
    diff = subprocess.run(['git', 'diff', '--cached'], capture_output=True).stdout.decode('utf-8')

    max_message = 40
    min_message = 5
    num_beams = 10
    num_predictions = 1

    msg = predict(diff, max_message, min_message, num_beams, num_predictions)

    with open(sys.argv[1], 'r+') as f:
        content = f.read()
        f.seek(0)
        f.write(msg + '\n' + content)

    print("Done!\n")

It’s fast on CPU, but loading the model takes most of the time. Anyway, about 3 seconds is OK. That’s all. It works. At least for me.
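
Most of that time on the very first run is downloading the weights; transformers caches them locally afterwards. A one-time warm-up (a sketch reusing the exact calls from the script) saves your first commit from stalling on the network:

from transformers import T5ForConditionalGeneration, RobertaTokenizer

# Pre-download and cache the model and tokenizer.
RobertaTokenizer.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")
T5ForConditionalGeneration.from_pretrained("mamiksik/CommitPredictorT5PL", revision="fb08d01")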

One more thing! I just figured out that you can call the API directly from the command line (fish shell):

curl -s -X POST https://mamiksik-commit-message-generator.hf.space/run/predict \
  -H 'Content-Type: application/json' \
  -d '{ "data": ["(git diff)", 40, 5, 7, 1]}' \
  | jq ".data[2].label"