Code search using embeddings

OpenAI LogoOpenAI LogoOpenAI Logo
Boris Power, Logan Kilpatrick, Eli Salamie
Mar 10, 2022
Open in Github

This notebook shows how Ada embeddings can be used to implement semantic code search. For this demonstration, we use our own openai-python code repository. We implement a simple version of file parsing and extracting of functions from python files, which can be embedded, indexed, and queried.

Helper Functions

We first setup some simple parsing functions that allow us to extract important information from our codebase.

import pandas as pd
from pathlib import Path

DEF_PREFIXES = ['def ', 'async def ']
NEWLINE = '\n'

def get_function_name(code):
    """
    Extract function name from a line beginning with 'def' or 'async def'.
    """
    for prefix in DEF_PREFIXES:
        if code.startswith(prefix):
            return code[len(prefix): code.index('(')]


def get_until_no_space(all_lines, i):
    """
    Get all lines until a line outside the function definition is found.
    """
    ret = [all_lines[i]]
    for j in range(i + 1, len(all_lines)):
        if len(all_lines[j]) == 0 or all_lines[j][0] in [' ', '\t', ')']:
            ret.append(all_lines[j])
        else:
            break
    return NEWLINE.join(ret)


def get_functions(filepath):
    """
    Get all functions in a Python file.
    """
    with open(filepath, 'r') as file:
        all_lines = file.read().replace('\r', NEWLINE).split(NEWLINE)
        for i, l in enumerate(all_lines):
            for prefix in DEF_PREFIXES:
                if l.startswith(prefix):
                    code = get_until_no_space(all_lines, i)
                    function_name = get_function_name(code)
                    yield {
                        'code': code,
                        'function_name': function_name,
                        'filepath': filepath,
                    }
                    break


def extract_functions_from_repo(code_root):
    """
    Extract all .py functions from the repository.
    """
    code_files = list(code_root.glob('**/*.py'))

    num_files = len(code_files)
    print(f'Total number of .py files: {num_files}')

    if num_files == 0:
        print('Verify openai-python repo exists and code_root is set correctly.')
        return None

    all_funcs = [
        func
        for code_file in code_files
        for func in get_functions(str(code_file))
    ]

    num_funcs = len(all_funcs)
    print(f'Total number of functions extracted: {num_funcs}')

    return all_funcs

Data Loading

We'll first load the openai-python folder and extract the needed information using the functions we defined above.

# Set user root directory to the 'openai-python' repository
root_dir = Path.home()

# Assumes the 'openai-python' repository exists in the user's root directory
code_root = root_dir / 'openai-python'

# Extract all functions from the repository
all_funcs = extract_functions_from_repo(code_root)
Total number of .py files: 57
Total number of functions extracted: 118

Now that we have our content, we can pass the data to the text-embedding-ada-002 endpoint to get back our vector embeddings.

from openai.embeddings_utils import get_embedding

df = pd.DataFrame(all_funcs)
df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, engine='text-embedding-ada-002'))
df['filepath'] = df['filepath'].map(lambda x: Path(x).relative_to(code_root))
df.to_csv("data/code_search_openai-python.csv", index=False)
df.head()
code function_name filepath code_embedding
0 def _console_log_level():\n if openai.log i... _console_log_level openai/util.py [0.033906757831573486, -0.00418944051489234, 0...
1 def log_debug(message, **params):\n msg = l... log_debug openai/util.py [-0.004059609025716782, 0.004895503632724285, ...
2 def log_info(message, **params):\n msg = lo... log_info openai/util.py [0.0048639848828315735, 0.0033139237202703953,...
3 def log_warn(message, **params):\n msg = lo... log_warn openai/util.py [0.0024026145692914724, -0.010721310041844845,...
4 def logfmt(props):\n def fmt(key, val):\n ... logfmt openai/util.py [0.01664826273918152, 0.01730910874903202, 0.0...

Testing

Let's test our endpoint with some simple queries. If you're familiar with the openai-python repository, you'll see that we're able to easily find functions we're looking for only a simple English description.

We define a search_functions method that takes our data that contains our embeddings, a query string, and some other configuration options. The process of searching our database works like such:

  1. We first embed our query string (code_query) with text-embedding-ada-002. The reasoning here is that a query string like 'a function that reverses a string' and a function like 'def reverse(string): return string[::-1]' will be very similar when embedded.
  2. We then calculate the cosine similarity between our query string embedding and all data points in our database. This gives a distance between each point and our query.
  3. We finally sort all of our data points by their distance to our query string and return the number of results requested in the function parameters.
from openai.embeddings_utils import cosine_similarity

def search_functions(df, code_query, n=3, pprint=True, n_lines=7):
    embedding = get_embedding(code_query, engine='text-embedding-ada-002')
    df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding))

    res = df.sort_values('similarities', ascending=False).head(n)

    if pprint:
        for r in res.iterrows():
            print(f"{r[1].filepath}:{r[1].function_name}  score={round(r[1].similarities, 3)}")
            print("\n".join(r[1].code.split("\n")[:n_lines]))
            print('-' * 70)

    return res
res = search_functions(df, 'fine-tuning input data validation logic', n=3)
openai/validators.py:format_inferrer_validator  score=0.751
def format_inferrer_validator(df):
    """
    This validator will infer the likely fine-tuning format of the data, and display it to the user if it is classification.
    It will also suggest to use ada and explain train/validation split benefits.
    """
    ft_type = infer_task_type(df)
    immediate_msg = None
----------------------------------------------------------------------
openai/validators.py:get_validators  score=0.748
def get_validators():
    return [
        num_examples_validator,
        lambda x: necessary_column_validator(x, "prompt"),
        lambda x: necessary_column_validator(x, "completion"),
        additional_column_validator,
        non_empty_field_validator,
----------------------------------------------------------------------
openai/validators.py:infer_task_type  score=0.739
def infer_task_type(df):
    """
    Infer the likely fine-tuning task type from the data
    """
    CLASSIFICATION_THRESHOLD = 3  # min_average instances of each class
    if sum(df.prompt.str.len()) == 0:
        return "open-ended generation"
----------------------------------------------------------------------
res = search_functions(df, 'find common suffix', n=2, n_lines=10)
openai/validators.py:get_common_xfix  score=0.794
def get_common_xfix(series, xfix="suffix"):
    """
    Finds the longest common suffix or prefix of all the values in a series
    """
    common_xfix = ""
    while True:
        common_xfixes = (
            series.str[-(len(common_xfix) + 1) :]
            if xfix == "suffix"
            else series.str[: len(common_xfix) + 1]
----------------------------------------------------------------------
openai/validators.py:common_completion_suffix_validator  score=0.778
def common_completion_suffix_validator(df):
    """
    This validator will suggest to add a common suffix to the completion if one doesn't already exist in case of classification or conditional generation.
    """
    error_msg = None
    immediate_msg = None
    optional_msg = None
    optional_fn = None

    ft_type = infer_task_type(df)
----------------------------------------------------------------------
res = search_functions(df, 'Command line interface for fine-tuning', n=1, n_lines=20)
openai/cli.py:tools_register  score=0.78
def tools_register(parser):
    subparsers = parser.add_subparsers(
        title="Tools", help="Convenience client side tools"
    )

    def help(args):
        parser.print_help()

    parser.set_defaults(func=help)

    sub = subparsers.add_parser("fine_tunes.prepare_data")
    sub.add_argument(
        "-f",
        "--file",
        required=True,
        help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing prompt-completion examples to be analyzed."
        "This should be the local file path.",
    )
    sub.add_argument(
        "-q",
----------------------------------------------------------------------