
Tasks

BaseTask

Bases: ABC

Source code in src/autolabel/tasks/base.py
class BaseTask(ABC):
    ZERO_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nNow I want you to label the following example:\n{current_example}"
    FEW_SHOT_TEMPLATE = "{task_guidelines}\n\n{output_guidelines}\n\nSome examples with their output answers are provided below:\n\n{seed_examples}\n\nNow I want you to label the following example:\n{current_example}"

    ZERO_SHOT_TEMPLATE_REFUEL_LLM = """
    <s>[INST] <<SYS>>
    {task_guidelines}\n{output_guidelines}
    <</SYS>>
    {current_example}[/INST]
    """
    FEW_SHOT_TEMPLATE_REFUEL_LLM = """
    <s>[INST] <<SYS>>
    {task_guidelines}\n{output_guidelines}\n{seed_examples}
    <</SYS>>
    {current_example}[/INST]
    """

    # Downstream classes should override these
    NULL_LABEL_TOKEN = "NO_LABEL"
    DEFAULT_TASK_GUIDELINES = ""
    DEFAULT_OUTPUT_GUIDELINES = ""
    DEFAULT_DATASET_GENERATION_GUIDELINES = ""

    def __init__(self, config: AutolabelConfig) -> None:
        self.config = config
        self.image_col = self.config.image_column()

        # Update the default prompt template with the prompt template from the config
        self.task_guidelines = (
            self.config.task_guidelines() or self.DEFAULT_TASK_GUIDELINES
        )
        self.output_guidelines = (
            self.config.output_guidelines() or self.DEFAULT_OUTPUT_GUIDELINES
        )

        self.dataset_generation_guidelines = (
            self.config.dataset_generation_guidelines()
            or self.DEFAULT_DATASET_GENERATION_GUIDELINES
        )
        self._prompt_schema_init()

    def _prompt_schema_init(self) -> None:
        self.use_refuel_prompt_schema = self.config.provider() == ModelProvider.REFUEL
        if self._is_few_shot_mode():
            self.example_template = (
                self.FEW_SHOT_TEMPLATE_REFUEL_LLM
                if self.use_refuel_prompt_schema
                else self.FEW_SHOT_TEMPLATE
            )
        else:
            self.example_template = (
                self.ZERO_SHOT_TEMPLATE_REFUEL_LLM
                if self.use_refuel_prompt_schema
                else self.ZERO_SHOT_TEMPLATE
            )
        self.prompt_template = PromptTemplate(
            input_variables=get_format_variables(self.example_template),
            template=self.example_template,
        )

    def _is_few_shot_mode(self) -> bool:
        return self.config.few_shot_algorithm() in [x.value for x in FewShotAlgorithm]

    @abstractmethod
    def construct_prompt(
        self,
        input: str,
        examples: List,
        prompt_template_override: PromptTemplate = None,
        refuel_prompt_override: bool = False,
        output_guidelines_override: str = None,
        max_input_tokens: int = None,
        get_num_tokens: Optional[Callable] = None,
        **kwargs,
    ) -> str:
        pass

    def construct_confidence_prompt(self, input: str, examples: List, **kwargs) -> str:
        curr_template = (
            self.FEW_SHOT_TEMPLATE_REFUEL_LLM
            if self._is_few_shot_mode()
            else self.ZERO_SHOT_TEMPLATE_REFUEL_LLM
        )
        prompt_template = PromptTemplate(
            input_variables=get_format_variables(curr_template),
            template=curr_template,
        )
        refuel_prompt = self.construct_prompt(
            input=input,
            examples=examples,
            prompt_template_override=prompt_template,
            refuel_prompt_override=True,
            **kwargs,
        )
        return refuel_prompt

    def trim_prompt(
        self,
        prompt_template: str,
        task_guidelines: str,
        output_guidelines: str,
        current_example: str,
        seed_examples: str = None,
        max_input_tokens: int = None,
        get_num_tokens: Optional[Callable] = None,
    ) -> str:
        complete_prompt = prompt_template.format(
            task_guidelines=task_guidelines,
            output_guidelines=output_guidelines,
            seed_examples=seed_examples,
            current_example=current_example,
        )
        if not max_input_tokens or not get_num_tokens:
            return complete_prompt

        trimming_priority = [
            seed_examples,
            task_guidelines,
            output_guidelines,
            current_example,
        ]
        trimmed_elements = {key: key for key in trimming_priority if key is not None}
        for trimming_candidate in trimming_priority:
            current_prompt_length = get_num_tokens(complete_prompt)
            if current_prompt_length <= max_input_tokens:
                break
            if trimming_candidate is None:
                continue
            extra_tokens = current_prompt_length - max_input_tokens
            trimming_candidate_tokens = get_num_tokens(trimming_candidate)
            max_chars = (
                float(len(trimming_candidate))
                * (trimming_candidate_tokens - extra_tokens - 1)
                / (trimming_candidate_tokens + 1)
            )
            final_candidate_chars = int(max(0, max_chars))
            trimmed_elements[trimming_candidate] = trimming_candidate[
                :final_candidate_chars
            ]
            complete_prompt = prompt_template.format(
                task_guidelines=trimmed_elements[task_guidelines],
                output_guidelines=trimmed_elements[output_guidelines],
                seed_examples=trimmed_elements[seed_examples]
                if seed_examples is not None
                else None,
                current_example=trimmed_elements[current_example],
            )

        return complete_prompt

    @abstractmethod
    def eval(
        self,
        llm_labels: List,
        gt_labels: List,
        additional_metrics: Optional[List[BaseMetric]] = [],
    ) -> List[MetricResult]:
        pass

    @abstractmethod
    def get_explanation_prompt(self, example: Dict, include_label=True) -> str:
        raise NotImplementedError(
            "Explanation generation not implemented for this task"
        )

    @abstractmethod
    def get_generate_dataset_prompt(
        self, label: str, num_rows: int, guidelines: str = None
    ) -> str:
        raise NotImplementedError("Dataset generation not implemented for this task")

    def parse_llm_response(
        self,
        response: Union[Generation, ChatGeneration],
        curr_sample: Dict,
        prompt: str,
    ) -> LLMAnnotation:
        # The last line of the response is the label
        # This is done to handle the case where the model generates an explanation before generating the label
        error = None
        if self.config.chain_of_thought():
            try:
                explanation = response.text.strip().split("\n")[0].strip()
                completion_text = extract_valid_json_substring(
                    response.text.strip().split("\n")[-1].strip()
                )
                completion_text = json.loads(completion_text)["label"]
            except:
                completion_text = None
        else:
            completion_text = response.text.strip().split("\n")[-1].strip()
        if len(response.text.strip()) == 0:
            successfully_labeled = False
            llm_label = self.NULL_LABEL_TOKEN
            logger.warning("LLM response is empty")
            error = LabelingError(
                error_type=ErrorType.EMPTY_RESPONSE_ERROR,
                error_message="Empty response from LLM",
            )
        elif not completion_text:
            successfully_labeled = False
            llm_label = self.NULL_LABEL_TOKEN
            logger.warning(f"Error parsing LLM response: {response.text}")
            error = LabelingError(
                error_type=ErrorType.PARSING_ERROR,
                error_message=f"Error parsing LLM response: {response.text}",
            )
        else:
            llm_label = completion_text.strip()
            if self.config.task_type() in [
                TaskType.CLASSIFICATION,
                TaskType.ENTITY_MATCHING,
            ]:
                if llm_label in self.config.labels_list():
                    successfully_labeled = True
                else:
                    logger.warning(f"LLM response is not in the labels list")
                    llm_label = self.NULL_LABEL_TOKEN
                    successfully_labeled = False
                    error = LabelingError(
                        error_type=ErrorType.OUTPUT_GUIDELINES_NOT_FOLLOWED_ERROR,
                        error_message=f"LLM response is not in the labels list: {llm_label}",
                    )
            elif self.config.task_type() == TaskType.MULTILABEL_CLASSIFICATION:
                llm_multi_labels = llm_label.split(self.config.label_separator())
                llm_multi_labels = list(
                    filter(
                        lambda label: label in self.config.labels_list(),
                        llm_multi_labels,
                    )
                )
                if len(llm_multi_labels) == 0:
                    llm_label = self.NULL_LABEL_TOKEN
                    successfully_labeled = False
                    error = LabelingError(
                        error_type=ErrorType.OUTPUT_GUIDELINES_NOT_FOLLOWED_ERROR,
                        error_message=f"LLM response is not in the labels list: {llm_label}",
                    )
                else:
                    llm_label = self.config.label_separator().join(llm_multi_labels)
                    successfully_labeled = True
            else:
                successfully_labeled = True

        return LLMAnnotation(
            successfully_labeled=successfully_labeled,
            label=llm_label,
            generation_info=response.generation_info,
            raw_response=response.text,
            prompt=prompt,
            curr_sample=pickle.dumps(curr_sample),
            explanation=explanation if self.config.chain_of_thought() else "",
            error=error,
        )
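
A downstream task implements the four abstract methods above and usually overrides the default guideline constants. The sketch below is illustrative only: MyCustomTask, its guideline strings and its method bodies are not part of the library, the import paths are assumptions based on the source paths shown on this page, and only the zero-shot path is handled.

from collections import defaultdict
from typing import Dict, List

from autolabel.schema import MetricResult      # assumed import path
from autolabel.tasks.base import BaseTask      # matches src/autolabel/tasks/base.py above


class MyCustomTask(BaseTask):
    # Downstream classes override these defaults; they are used when the config
    # does not provide task_guidelines / output_guidelines.
    DEFAULT_TASK_GUIDELINES = "Your job is to label the provided input example.\n"
    DEFAULT_OUTPUT_GUIDELINES = "You will return the answer as a single label.\n"

    def construct_prompt(self, input: Dict, examples: List, **kwargs) -> str:
        # Format the current example, then let BaseTask.trim_prompt fit everything
        # into the prompt template under the token budget, if one is given.
        current_example = self.config.example_template().format_map(
            defaultdict(str, input)
        )
        return self.trim_prompt(
            self.prompt_template,
            task_guidelines=self.task_guidelines,
            output_guidelines=self.output_guidelines,
            current_example=current_example,
            max_input_tokens=kwargs.get("max_input_tokens"),
            get_num_tokens=kwargs.get("get_num_tokens"),
        )

    def eval(self, llm_labels: List, gt_labels: List, additional_metrics=[]) -> List[MetricResult]:
        return []  # compute MetricResult objects for the task here

    def get_explanation_prompt(self, example: Dict, include_label=True) -> str:
        raise NotImplementedError("Explanation generation not implemented for this task")

    def get_generate_dataset_prompt(self, label: str, num_rows: int, guidelines: str = None) -> str:
        raise NotImplementedError("Dataset generation not implemented for this task")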

ClassificationTask

Bases: BaseTask

Source code in src/autolabel/tasks/classification.py
class ClassificationTask(BaseTask):
    DEFAULT_OUTPUT_GUIDELINES = (
        'You will return the answer with just one element: "the correct label"'
    )
    DEFAULT_TASK_GUIDELINES = "Your job is to correctly label the provided input example into one of the following {num_labels} categories.\nCategories:\n{labels}\n"

    LABEL_FORMAT_IN_EXPLANATION = (
        " The last line of the explanation should be - So, the answer is <label>."
    )
    EXCLUDE_LABEL_IN_EXPLANATION = " Do not repeat the output of the task - simply provide an explanation for the provided output. The provided label was generated by you in a previous step and your job now is to only provided an explanation for the output. Your job is not verify the output but instead explain why it might have been generated, even if it is incorrect. If you think the provided output is incorrect, give an explanation of why it might have been generated anyway but don't say that the output may be incorrect or incorrectly generated.'"
    GENERATE_EXPLANATION_PROMPT = "Your job is to provide an explanation for why a specific output might have been generated for a task. \n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION\nYou will be given an input example and the corresponding output. Your job is to provide an explanation for why the output is correct for the task above.\nThink step by step and generate an explanation with at most 2 sentences.{label_format}\n{labeled_example}\nExplanation: "

    GENERATE_DATASET_TEMPLATE = "{guidelines}\n\nThe inputs must be diverse, covering a wide range of scenarios. You will not generate duplicate inputs. These inputs should be organized in rows in csv format with the columns {columns}.\n\n{label_descriptions}\n\n{format_guidelines}\n\n{output_guidelines}\n\n```csv"
    DEFAULT_DATASET_GENERATION_GUIDELINES = "You are an expert at generating plausible inputs for a given task.\n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION"
    LABEL_DESCRIPTIONS_PROMPT = "Each input should fall into one of these {num_labels} categories. These are the only categories that the inputs can belong to."
    GENERATE_DATASET_FORMAT_GUIDELINES = "Your response should be in csv format with the following columns: {columns}.\n\nHere is a template you can follow for your output:\n```csv\n{columns}\n{example_rows}\n```\n\nMake sure to replace the placeholder variables with your own values."
    GENERATE_DATASET_OUTPUT_GUIDELINES = 'Now I want you to generate {num_rows} excerpts that follow the guidelines and all belong to the "{label}" category. They should not belong to any of the other categories.'

    def __init__(self, config: AutolabelConfig) -> None:
        super().__init__(config)
        self.metrics = [
            AccuracyMetric(),
            SupportMetric(),
            CompletionRateMetric(),
            ClassificationReportMetric(),
        ]

        if self.config.confidence():
            self.metrics.append(AUROCMetric())

        for label in self.config.labels_list():
            if "\n" in label:
                logger.warning(
                    "Label contains newline character. This can have output guideline issues."
                )

    def construct_prompt(
        self,
        input: Dict,
        examples: List,
        selected_labels: List[str] = None,
        prompt_template_override: PromptTemplate = None,
        refuel_prompt_override: bool = False,
        output_guidelines_override: str = None,
        max_input_tokens: int = None,
        get_num_tokens: Optional[Callable] = None,
        **kwargs,
    ) -> str:
        # Copy over the input so that we can modify it
        input = input.copy()

        # prepare task guideline
        labels_list = (
            self.config.labels_list() if not selected_labels else selected_labels
        )
        num_labels = len(labels_list)

        fmt_task_guidelines = self.task_guidelines.format_map(
            defaultdict(str, labels="\n".join(labels_list), num_labels=num_labels)
        )

        # prepare seed examples
        example_template = self.config.example_template()
        label_column = self.config.label_column()
        fmt_examples = []
        for eg in examples:
            eg_copy = eg.copy()
            # If chain of thought is enabled
            if label_column and self.config.chain_of_thought():
                eg_copy[label_column] = json.dumps({"label": eg[label_column]})
            fmt_examples.append(example_template.format_map(defaultdict(str, eg_copy)))

        # populate the current example in the prompt
        if label_column:
            input[label_column] = ""

        # populate the explanation column with empty string for current example
        explanation_column = self.config.explanation_column()
        if explanation_column:
            input[explanation_column] = ""

        # check if all mapped keys in input are in the example template
        try:
            current_example = example_template.format(**input)
        except KeyError as e:
            current_example = example_template.format_map(defaultdict(str, input))
            logger.warn(
                f'\n\nKey {e} in the "example_template" in the given config'
                f"\n\n{example_template}\n\nis not present in the datsaset columns - {input.keys()}.\n\n"
                f"Input - {input}\n\n"
                "Continuing with the prompt as {current_example}"
            )

        # populate the current example in the prompt
        prompt_template = (
            self.prompt_template
            if prompt_template_override is None
            else prompt_template_override
        )
        output_guidelines = (
            self.output_guidelines
            if output_guidelines_override is None
            else output_guidelines_override
        )
        if self._is_few_shot_mode():
            curr_text_prompt = self.trim_prompt(
                prompt_template,
                task_guidelines=fmt_task_guidelines,
                output_guidelines=output_guidelines,
                seed_examples="\n\n".join(fmt_examples),
                current_example=current_example,
                max_input_tokens=max_input_tokens,
                get_num_tokens=get_num_tokens,
            )
        else:
            curr_text_prompt = self.trim_prompt(
                prompt_template,
                task_guidelines=fmt_task_guidelines,
                output_guidelines=output_guidelines,
                current_example=current_example,
                max_input_tokens=max_input_tokens,
                get_num_tokens=get_num_tokens,
            )
        if self.image_col is not None:
            return json.dumps(
                {"text": curr_text_prompt, "image_url": input[self.image_col]}
            )
        else:
            return curr_text_prompt

    def get_explanation_prompt(self, example: Dict, include_label=True) -> str:
        pt = PromptTemplate(
            input_variables=get_format_variables(self.GENERATE_EXPLANATION_PROMPT),
            template=self.GENERATE_EXPLANATION_PROMPT,
        )

        # prepare task guideline
        labels_list = self.config.labels_list()
        num_labels = len(labels_list)
        fmt_task_guidelines = self.task_guidelines.format(
            num_labels=num_labels, labels="\n".join(labels_list)
        )

        # prepare labeled example
        example_template = self.config.example_template()
        fmt_example = example_template.format_map(defaultdict(str, example))

        return pt.format(
            task_guidelines=fmt_task_guidelines,
            label_format=self.LABEL_FORMAT_IN_EXPLANATION
            if include_label
            else self.EXCLUDE_LABEL_IN_EXPLANATION,
            labeled_example=fmt_example,
        )

    def get_generate_dataset_prompt(self, label: str) -> str:
        pt = PromptTemplate(
            input_variables=get_format_variables(self.GENERATE_DATASET_TEMPLATE),
            template=self.GENERATE_DATASET_TEMPLATE,
        )

        # prepare task guideline
        labels_list = self.config.labels_list()
        num_labels = len(labels_list)
        fmt_task_guidelines = self.task_guidelines.format(
            num_labels=num_labels, labels="\n".join(labels_list)
        )
        fmt_guidelines = self.dataset_generation_guidelines.format(
            task_guidelines=fmt_task_guidelines
        )

        # prepare columns
        columns = get_format_variables(self.config.example_template())
        columns.remove(self.config.label_column())

        # prepare label descriptions
        fmt_label_descriptions = self.LABEL_DESCRIPTIONS_PROMPT.format(
            num_labels=num_labels
        )
        for i, l in enumerate(labels_list):
            fmt_label_descriptions += f"\n{i+1}. {l}{': ' + self.config.label_descriptions()[l] if self.config.label_descriptions() is not None and l in self.config.label_descriptions() else ''}"

        # prepare format
        example_rows = "\n".join(
            [",".join([f'"{column}_{i+1}"' for column in columns]) for i in range(3)]
        )
        fmt_format_guidelines = self.GENERATE_DATASET_FORMAT_GUIDELINES.format(
            columns=",".join(columns), example_rows=example_rows
        )

        # prepare output guidelines
        fmt_output_guidelines = self.GENERATE_DATASET_OUTPUT_GUIDELINES.format(
            num_rows=self.config.dataset_generation_num_rows(), label=label
        )

        return pt.format(
            guidelines=fmt_guidelines,
            columns=columns,
            label_descriptions=fmt_label_descriptions,
            format_guidelines=fmt_format_guidelines,
            output_guidelines=fmt_output_guidelines,
        )

    def eval(
        self,
        llm_labels: List[LLMAnnotation],
        gt_labels: List[str],
        additional_metrics: List[BaseMetric] = [],
    ) -> List[MetricResult]:
        """Evaluate the LLM generated labels by comparing them against ground truth

        Args:
            llm_labels (List[LLMAnnotation]): _description_
            gt_labels (List[str]): _description_
            additional_metrics (List[BaseMetric], optional): The additional metrics to run. Defaults to [].

        Returns:
            List[MetricResult]: list of metrics and corresponding values
        """

        eval_metrics = []

        for metric in self.metrics + additional_metrics:
            eval_metrics.extend(metric.compute(llm_labels, gt_labels))

        return eval_metrics

eval(llm_labels, gt_labels, additional_metrics=[])

Evaluate the LLM generated labels by comparing them against ground truth

Parameters:

    llm_labels (List[LLMAnnotation]): the LLM-generated labels to evaluate. Required.
    gt_labels (List[str]): the ground truth labels. Required.
    additional_metrics (List[BaseMetric], optional): The additional metrics to run. Defaults to [].

Returns:

    List[MetricResult]: list of metrics and corresponding values

Source code in src/autolabel/tasks/classification.py
def eval(
    self,
    llm_labels: List[LLMAnnotation],
    gt_labels: List[str],
    additional_metrics: List[BaseMetric] = [],
) -> List[MetricResult]:
    """Evaluate the LLM generated labels by comparing them against ground truth

    Args:
        llm_labels (List[LLMAnnotation]): _description_
        gt_labels (List[str]): _description_
        additional_metrics (List[BaseMetric], optional): The additional metrics to run. Defaults to [].

    Returns:
        List[MetricResult]: list of metrics and corresponding values
    """

    eval_metrics = []

    for metric in self.metrics + additional_metrics:
        eval_metrics.extend(metric.compute(llm_labels, gt_labels))

    return eval_metrics
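
A short end-to-end usage sketch for ClassificationTask (in normal use the LabelingAgent drives these calls). The config dictionary is illustrative; the key names follow the usual Autolabel config schema but should be checked against the version you have installed.

from autolabel.configs import AutolabelConfig
from autolabel.tasks.classification import ClassificationTask

config = AutolabelConfig({
    "task_name": "ToySentiment",
    "task_type": "classification",
    "dataset": {"label_column": "label", "delimiter": ","},
    "model": {"provider": "openai", "name": "gpt-3.5-turbo"},
    "prompt": {
        "task_guidelines": "Classify the sentiment of the given movie review.\nCategories:\n{labels}\n",
        "labels": ["positive", "negative", "neutral"],
        "example_template": "Review: {review}\nLabel: {label}",
    },
})

task = ClassificationTask(config)

# Zero-shot prompt for one input row (no seed examples passed)
prompt = task.construct_prompt(
    {"review": "A delightful surprise from start to finish."}, examples=[]
)
print(prompt)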

EntityMatchingTask

Bases: BaseTask

Source code in src/autolabel/tasks/entity_matching.py
class EntityMatchingTask(BaseTask):
    DEFAULT_OUTPUT_GUIDELINES = (
        'You will return the answer with one element: "the correct option"\n'
    )
    DEFAULT_TASK_GUIDELINES = "Your job is to tell if the two given entities are duplicates or not. You will return the answer from one of the choices. Choices:\n{labels}\n"
    LABEL_FORMAT_IN_EXPLANATION = (
        " The last line of the explanation should be - So, the answer is <label>."
    )
    EXCLUDE_LABEL_IN_EXPLANATION = " Do not repeat the output of the task - simply provide an explanation for the provided output. The provided label was generated by you in a previous step and your job now is to only provided an explanation for the output. Your job is not verify the output but instead explain why it might have been generated, even if it is incorrect. If you think the provided output is incorrect, give an explanation of why it might have been generated anyway but don't say that the output may be incorrect or incorrectly generated.'"
    GENERATE_EXPLANATION_PROMPT = "You are an expert at providing a well reasoned explanation for the output of a given task. \n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION\nYou will be given an input example and the corresponding output. Your job is to provide an explanation for why the output is correct for the task above.\nThink step by step and generate an explanation.{label_format}\n{labeled_example}\nExplanation: "

    GENERATE_DATASET_TEMPLATE = "{guidelines}\n\nThe inputs must be diverse, covering a wide range of scenarios. You will not generate duplicate inputs. These inputs should be organized in rows in csv format with the columns {columns}.\n\n{label_descriptions}\n\n{format_guidelines}\n\n{output_guidelines}\n\n```csv"
    DEFAULT_DATASET_GENERATION_GUIDELINES = "You are an expert at generating plausible inputs for a given task.\n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION"
    LABEL_DESCRIPTIONS_PROMPT = "Each input should fall into one of these {num_labels} categories. These are the only categories that the inputs can belong to."
    GENERATE_DATASET_FORMAT_GUIDELINES = "Your response should be in csv format with the following columns: {columns}.\n\nHere is a template you can follow for your output:\n```csv\n{columns}\n{example_rows}\n```\n\nMake sure to replace the placeholder variables with your own values."
    GENERATE_DATASET_OUTPUT_GUIDELINES = 'Now I want you to generate {num_rows} excerpts that follow the guidelines and all belong to the "{label}" category. They should not belong to any of the other categories.'

    def __init__(self, config: AutolabelConfig) -> None:
        super().__init__(config)
        self.metrics = [
            AccuracyMetric(),
            SupportMetric(),
            CompletionRateMetric(),
            ClassificationReportMetric(),
        ]

        if self.config.confidence():
            self.metrics.append(AUROCMetric())

        for label in self.config.labels_list():
            if "\n" in label:
                logger.warning(
                    "Label contains newline character. This can have output guideline issues."
                )

    def construct_prompt(
        self,
        input: Dict,
        examples: List[Dict],
        prompt_template_override: PromptTemplate = None,
        refuel_prompt_override: bool = False,
        output_guidelines_override: str = None,
        max_input_tokens: int = None,
        get_num_tokens: Optional[Callable] = None,
        **kwargs,
    ) -> str:
        # Copy over the input so that we can modify it
        input = input.copy()

        # prepare task guideline
        labels_list = self.config.labels_list()
        num_labels = len(labels_list)
        fmt_task_guidelines = self.task_guidelines.format_map(
            defaultdict(str, labels="\n".join(labels_list), num_labels=num_labels)
        )

        # prepare seed examples
        example_template = self.config.example_template()
        label_column = self.config.label_column()
        fmt_examples = []
        for eg in examples:
            eg_copy = eg.copy()
            # If chain of thought is enabled
            if label_column and self.config.chain_of_thought():
                eg_copy[label_column] = json.dumps({"label": eg[label_column]})
            fmt_examples.append(example_template.format_map(defaultdict(str, eg_copy)))

        # populate the current example in the prompt
        if label_column:
            input[label_column] = ""

        # populate the explanation column with empty string for current example
        explanation_column = self.config.explanation_column()
        if explanation_column:
            input[explanation_column] = ""

        # check if all mapped keys in input are in the example template
        try:
            current_example = example_template.format(**input)
        except KeyError as e:
            current_example = example_template.format_map(defaultdict(str, input))
            logger.warn(
                f'\n\nKey {e} in the "example_template" in the given config'
                f"\n\n{example_template}\n\nis not present in the datsaset columns - {input.keys()}.\n\n"
                f"Input - {input}\n\n"
                "Continuing with the prompt as {current_example}"
            )

        # populate the current example in the prompt
        prompt_template = (
            self.prompt_template
            if prompt_template_override is None
            else prompt_template_override
        )
        output_guidelines = (
            self.output_guidelines
            if output_guidelines_override is None
            else output_guidelines_override
        )
        if self._is_few_shot_mode():
            curr_text_prompt = self.trim_prompt(
                prompt_template,
                task_guidelines=fmt_task_guidelines,
                output_guidelines=output_guidelines,
                seed_examples="\n\n".join(fmt_examples),
                current_example=current_example,
                max_input_tokens=max_input_tokens,
                get_num_tokens=get_num_tokens,
            )
        else:
            curr_text_prompt = self.trim_prompt(
                prompt_template,
                task_guidelines=fmt_task_guidelines,
                output_guidelines=output_guidelines,
                current_example=current_example,
                max_input_tokens=max_input_tokens,
                get_num_tokens=get_num_tokens,
            )

        if self.image_col is not None:
            return json.dumps(
                {"text": curr_text_prompt, "image_url": input[self.image_col]}
            )
        else:
            return curr_text_prompt

    def get_explanation_prompt(self, example: Dict, include_label=True) -> str:
        pt = PromptTemplate(
            input_variables=get_format_variables(self.GENERATE_EXPLANATION_PROMPT),
            template=self.GENERATE_EXPLANATION_PROMPT,
        )

        # prepare task guideline
        labels_list = self.config.labels_list()
        num_labels = len(labels_list)
        fmt_task_guidelines = self.task_guidelines.format(
            num_labels=num_labels, labels="\n".join(labels_list)
        )

        # prepare labeled example
        example_template = self.config.example_template()
        fmt_example = example_template.format_map(defaultdict(str, example))

        return pt.format(
            task_guidelines=fmt_task_guidelines,
            label_format=self.LABEL_FORMAT_IN_EXPLANATION
            if include_label
            else self.EXCLUDE_LABEL_IN_EXPLANATION,
            labeled_example=fmt_example,
        )

    def get_generate_dataset_prompt(self, label: str) -> str:
        pt = PromptTemplate(
            input_variables=get_format_variables(self.GENERATE_DATASET_TEMPLATE),
            template=self.GENERATE_DATASET_TEMPLATE,
        )

        # prepare task guideline
        labels_list = self.config.labels_list()
        num_labels = len(labels_list)
        fmt_task_guidelines = self.task_guidelines.format(
            num_labels=num_labels, labels="\n".join(labels_list)
        )
        fmt_guidelines = self.dataset_generation_guidelines.format(
            task_guidelines=fmt_task_guidelines
        )

        # prepare columns
        columns = get_format_variables(self.config.example_template())
        columns.remove(self.config.label_column())

        # prepare label descriptions
        fmt_label_descriptions = self.LABEL_DESCRIPTIONS_PROMPT.format(
            num_labels=num_labels
        )

        for i, l in enumerate(labels_list):
            fmt_label_descriptions += f"\n{i+1}. {l}{': ' + self.config.label_descriptions()[l] if self.config.label_descriptions() is not None and l in self.config.label_descriptions() else ''}"

        # prepare format
        example_rows = "\n".join(
            [",".join([f'"{column}_{i+1}"' for column in columns]) for i in range(3)]
        )
        fmt_format_guidelines = self.GENERATE_DATASET_FORMAT_GUIDELINES.format(
            columns=",".join(columns), example_rows=example_rows
        )

        # prepare output guidelines
        fmt_output_guidelines = self.GENERATE_DATASET_OUTPUT_GUIDELINES.format(
            num_rows=self.config.dataset_generation_num_rows(), label=label
        )

        return pt.format(
            guidelines=fmt_guidelines,
            columns=columns,
            label_descriptions=fmt_label_descriptions,
            format_guidelines=fmt_format_guidelines,
            output_guidelines=fmt_output_guidelines,
        )

    def eval(
        self,
        llm_labels: List[LLMAnnotation],
        gt_labels: List[str],
        additional_metrics: List[BaseMetric] = [],
    ) -> List[MetricResult]:
        """Evaluate the LLM generated labels by comparing them against ground truth

        Args:
            llm_labels (List[LLMAnnotation]): _description_
            gt_labels (List[str]): _description_
            additional_metrics (List[BaseMetric], optional): List of additional metrics to run. Defaults to [].

        Returns:
            List[MetricResult]: list of metrics and corresponding values
        """

        eval_metrics = []

        for metric in self.metrics + additional_metrics:
            eval_metrics.extend(metric.compute(llm_labels, gt_labels))

        return eval_metrics

eval(llm_labels, gt_labels, additional_metrics=[])

Evaluate the LLM generated labels by comparing them against ground truth

Parameters:

    llm_labels (List[LLMAnnotation]): the LLM-generated labels to evaluate. Required.
    gt_labels (List[str]): the ground truth labels. Required.
    additional_metrics (List[BaseMetric], optional): List of additional metrics to run. Defaults to [].

Returns:

    List[MetricResult]: list of metrics and corresponding values

Source code in src/autolabel/tasks/entity_matching.py
def eval(
    self,
    llm_labels: List[LLMAnnotation],
    gt_labels: List[str],
    additional_metrics: List[BaseMetric] = [],
) -> List[MetricResult]:
    """Evaluate the LLM generated labels by comparing them against ground truth

    Args:
        llm_labels (List[LLMAnnotation]): _description_
        gt_labels (List[str]): _description_
        additional_metrics (List[BaseMetric], optional): List of additional metrics to run. Defaults to [].

    Returns:
        List[MetricResult]: list of metrics and corresponding values
    """

    eval_metrics = []

    for metric in self.metrics + additional_metrics:
        eval_metrics.extend(metric.compute(llm_labels, gt_labels))

    return eval_metrics
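
EntityMatchingTask mirrors ClassificationTask; the practical difference is in the config: the example template usually lays the two candidate records out side by side, and the labels are the match choices. The snippet below is only an illustration of such a prompt section (the key names and record fields are assumptions, not taken from the library):

entity_matching_prompt = {
    "task_guidelines": (
        "Your job is to tell if the two given entities are duplicates or not. "
        "You will return the answer from one of the choices. Choices:\n{labels}\n"
    ),
    "labels": ["duplicate", "not duplicate"],
    "example_template": (
        "Entity A: {entity_a}\n"
        "Entity B: {entity_b}\n"
        "Answer: {label}"
    ),
}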

QuestionAnsweringTask

Bases: BaseTask

Source code in src/autolabel/tasks/question_answering.py
class QuestionAnsweringTask(BaseTask):
    DEFAULT_OUTPUT_GUIDELINES = (
        'You will return the answer one element: "the correct label"\n'
    )
    REFUEL_LLM_DEFAULT_OUTPUT_GUIDELINES = ""
    DEFAULT_TASK_GUIDELINES = "Your job is to answer the following questions using the options provided for each question. Choose the best answer for the question.\n"
    NULL_LABEL_TOKEN = "NO_LABEL"

    LABEL_FORMAT_IN_EXPLANATION = (
        " The last line of the explanation should be - So, the answer is <label>."
    )
    EXCLUDE_LABEL_IN_EXPLANATION = " Do not repeat the output of the task - simply provide an explanation for the provided output. The provided label was generated by you in a previous step and your job now is to only provided an explanation for the output. Your job is not verify the output but instead explain why it might have been generated, even if it is incorrect. If you think the provided output is incorrect, give an explanation of why it might have been generated anyway but don't say that the output may be incorrect or incorrectly generated.'"
    GENERATE_EXPLANATION_PROMPT = "You are an expert at providing a well reasoned explanation for the output of a given task. \n\nBEGIN TASK DESCRIPTION\n{task_guidelines}\nEND TASK DESCRIPTION\nYou will be given an input example and the corresponding output. You will be given a question and an answer. Your job is to provide an explanation for why the answer is correct for the task above.\nThink step by step and generate an explanation.{label_format}\n{labeled_example}\nExplanation: "

    def __init__(self, config: AutolabelConfig) -> None:
        if config.provider() == ModelProvider.REFUEL:
            self.DEFAULT_OUTPUT_GUIDELINES = self.REFUEL_LLM_DEFAULT_OUTPUT_GUIDELINES

        super().__init__(config)
        self.metrics = [
            AccuracyMetric(),
            SupportMetric(),
            CompletionRateMetric(),
            F1Metric(
                type=F1Type.TEXT,
            ),
        ]

        if self.config.confidence():
            self.metrics.append(AUROCMetric())

    def construct_prompt(
        self,
        input: Dict,
        examples: List[Dict],
        prompt_template_override: PromptTemplate = None,
        refuel_prompt_override: bool = False,
        output_guidelines_override: str = None,
        max_input_tokens: int = None,
        get_num_tokens: Optional[Callable] = None,
        **kwargs,
    ) -> str:
        # Copy over the input so that we can modify it
        input = input.copy()

        # prepare seed examples
        example_template = self.config.example_template()
        label_column = self.config.label_column()
        fmt_examples = []
        for eg in examples:
            eg_copy = eg.copy()
            # If chain of thought is enabled
            if label_column and self.config.chain_of_thought():
                eg_copy[label_column] = json.dumps({"label": eg[label_column]})
            fmt_examples.append(example_template.format_map(defaultdict(str, eg_copy)))

        # populate the current example in the prompt
        if label_column:
            input[label_column] = ""

        # populate the explanation column with empty string for current example
        explanation_column = self.config.explanation_column()
        if explanation_column:
            input[explanation_column] = ""

            # check if all mapped keys in input are in the example template
        try:
            current_example = example_template.format(**input)
        except KeyError as e:
            current_example = example_template.format_map(defaultdict(str, input))
            logger.warn(
                f'\n\nKey {e} in the "example_template" in the given config'
                f"\n\n{example_template}\n\nis not present in the datsaset columns - {input.keys()}.\n\n"
                f"Input - {input}\n\n"
                "Continuing with the prompt as {current_example}"
            )

        # populate the current example in the prompt
        prompt_template = (
            self.prompt_template
            if prompt_template_override is None
            else prompt_template_override
        )
        output_guidelines = (
            self.output_guidelines
            if output_guidelines_override is None
            else output_guidelines_override
        )
        if self._is_few_shot_mode():
            curr_text_prompt = prompt_template.format(
                task_guidelines=self.task_guidelines,
                output_guidelines=output_guidelines,
                seed_examples="\n\n".join(fmt_examples),
                current_example=current_example,
            )
        else:
            curr_text_prompt = prompt_template.format(
                task_guidelines=self.task_guidelines,
                output_guidelines=output_guidelines,
                current_example=current_example,
            )

        if self.image_col is not None:
            return json.dumps(
                {"text": curr_text_prompt, "image_url": input[self.image_col]}
            )
        else:
            return curr_text_prompt

    def construct_confidence_prompt(self, input: str, examples: List, **kwargs) -> str:
        output_guidelines_override = (
            self.config.output_guidelines() or self.REFUEL_LLM_DEFAULT_OUTPUT_GUIDELINES
        )
        refuel_prompt = super().construct_confidence_prompt(
            input,
            examples,
            output_guidelines_override=output_guidelines_override,
            **kwargs,
        )
        return refuel_prompt

    def get_explanation_prompt(self, example: Dict, include_label=True) -> str:
        pt = PromptTemplate(
            input_variables=get_format_variables(self.GENERATE_EXPLANATION_PROMPT),
            template=self.GENERATE_EXPLANATION_PROMPT,
        )
        example_template = self.config.example_template()
        fmt_example = example_template.format_map(defaultdict(str, example))

        return pt.format(
            task_guidelines=self.task_guidelines,
            label_format=self.LABEL_FORMAT_IN_EXPLANATION
            if include_label
            else self.EXCLUDE_LABEL_IN_EXPLANATION,
            labeled_example=fmt_example,
        )

    def get_generate_dataset_prompt(
        self, label: str, num_rows: int, guidelines: str = None
    ) -> str:
        raise NotImplementedError("Dataset generation not implemented for this task")

    def eval(
        self,
        llm_labels: List[LLMAnnotation],
        gt_labels: List[str],
        additional_metrics: Optional[List[BaseMetric]] = [],
    ) -> List[MetricResult]:
        """Evaluate the LLM generated labels by comparing them against ground truth

        Args:
            llm_labels (List[LLMAnnotation]): _description_
            gt_labels (List[str]): _description_
            additional_metrics (Optional[List[BaseMetric]], optional): _description_. Defaults to [].

        Returns:
            List[MetricResult]: list of metrics and corresponding values
        """
        eval_metrics = []

        for metric in self.metrics + additional_metrics:
            eval_metrics.extend(metric.compute(llm_labels, gt_labels))

        return eval_metrics

eval(llm_labels, gt_labels, additional_metrics=[])

Evaluate the LLM generated labels by comparing them against ground truth

Parameters:

    llm_labels (List[LLMAnnotation]): the LLM-generated labels to evaluate. Required.
    gt_labels (List[str]): the ground truth labels. Required.
    additional_metrics (Optional[List[BaseMetric]], optional): additional metrics to run. Defaults to [].

Returns:

    List[MetricResult]: list of metrics and corresponding values

Source code in src/autolabel/tasks/question_answering.py
def eval(
    self,
    llm_labels: List[LLMAnnotation],
    gt_labels: List[str],
    additional_metrics: Optional[List[BaseMetric]] = [],
) -> List[MetricResult]:
    """Evaluate the LLM generated labels by comparing them against ground truth

    Args:
        llm_labels (List[LLMAnnotation]): _description_
        gt_labels (List[str]): _description_
        additional_metrics (Optional[List[BaseMetric]], optional): _description_. Defaults to [].

    Returns:
        List[MetricResult]: list of metrics and corresponding values
    """
    eval_metrics = []

    for metric in self.metrics + additional_metrics:
        eval_metrics.extend(metric.compute(llm_labels, gt_labels))

    return eval_metrics
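
For QuestionAnsweringTask the example template usually carries the question, the answer options and the answer column, with the answer column doubling as the label column. The prompt section below is an illustration only (key and column names are assumptions):

qa_prompt = {
    "task_guidelines": (
        "Your job is to answer the following questions using the options provided for each "
        "question. Choose the best answer for the question.\n"
    ),
    "example_template": "Question: {question}\nOptions: {options}\nAnswer: {answer}",
}
# paired with "dataset": {"label_column": "answer"} elsewhere in the config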

NamedEntityRecognitionTask

Bases: BaseTask

Source code in src/autolabel/tasks/named_entity_recognition.py
class NamedEntityRecognitionTask(BaseTask):
    DEFAULT_OUTPUT_GUIDELINES = "You will return the answer in CSV format, with two columns separated by the % character. First column is the extracted entity and second column is the category. Rows in the CSV are separated by new line character."
    DEFAULT_TASK_GUIDELINES = "Your job is to extract named entities mentioned in text, and classify them into one of the following {num_labels} categories.\nCategories:\n{labels}\n "
    NULL_LABEL = {}

    def __init__(self, config: AutolabelConfig) -> None:
        super().__init__(config)

    def _json_to_llm_format(self, input_label: str) -> str:
        # `label` format: {"entity type": [list of entities of this type]}
        try:
            labels = json.loads(input_label)
            rows = []
            for entity_type, detected_entites in labels.items():
                for e in detected_entites:
                    row = "%".join([e, entity_type])
                    rows.append(row)
            llm_formatted_label = "\n".join(rows)
            return llm_formatted_label
        except json.JSONDecodeError as e:
            logger.error(
                f"Could not parse label: {input_label}. Few-shot examples might be formatted incorrectly"
            )
            return input_label

    def _llm_to_json_format(self, response: str):
        split_response = response.split("\n")
        json_output = {i: [] for i in self.config.labels_list()}

        for row in split_response:
            parts = row.split("%")
            if len(parts) != 2 or parts[1] not in json_output.keys():
                logger.debug(f"Malformed LLM response: {row}")
                continue
            named_entity = parts[0]
            category = parts[1]
            json_output[category].append(named_entity)
        return json_output

    def construct_prompt(
        self,
        input: Dict,
        examples: List,
        prompt_template_override: PromptTemplate = None,
        refuel_prompt_override: bool = False,
        output_guidelines_override: str = None,
        max_input_tokens: int = None,
        get_num_tokens: Optional[Callable] = None,
        **kwargs,
    ) -> str:
        # prepare task guideline
        labels_list = self.config.labels_list()
        num_labels = len(labels_list)
        fmt_task_guidelines = self.task_guidelines.format_map(
            defaultdict(str, labels="\n".join(labels_list), num_labels=num_labels)
        )

        # prepare seed examples
        label_column = self.config.label_column()
        example_template = self.config.example_template()
        fmt_examples = []
        for eg in examples:
            eg_copy = deepcopy(eg)
            if label_column:
                eg_copy[label_column] = self._json_to_llm_format(eg_copy[label_column])
            fmt_examples.append(example_template.format_map(defaultdict(str, eg_copy)))

        # populate the current example in the prompt
        if label_column:
            input[label_column] = ""

        # populate the explanation column with empty string for current example
        explanation_column = self.config.explanation_column()
        if explanation_column:
            input[explanation_column] = ""

        # check if all mapped keys in input are in the example template
        try:
            current_example = example_template.format(**input)
        except KeyError as e:
            current_example = example_template.format_map(defaultdict(str, input))
            logger.warn(
                f'\n\nKey {e} in the "example_template" in the given config'
                f"\n\n{example_template}\n\nis not present in the datsaset columns - {input.keys()}.\n\n"
                f"Input - {input}\n\n"
                "Continuing with the prompt as {current_example}"
            )

        # populate the current example in the prompt
        prompt_template = (
            self.prompt_template
            if prompt_template_override is None
            else prompt_template_override
        )
        output_guidelines = (
            self.output_guidelines
            if output_guidelines_override is None
            else output_guidelines_override
        )
        if self._is_few_shot_mode():
            curr_text_prompt = self.trim_prompt(
                prompt_template,
                task_guidelines=fmt_task_guidelines,
                output_guidelines=output_guidelines,
                seed_examples="\n\n".join(fmt_examples),
                current_example=current_example,
                max_input_tokens=max_input_tokens,
                get_num_tokens=get_num_tokens,
            )
        else:
            curr_text_prompt = self.trim_prompt(
                prompt_template,
                task_guidelines=fmt_task_guidelines,
                output_guidelines=output_guidelines,
                current_example=current_example,
                max_input_tokens=max_input_tokens,
                get_num_tokens=get_num_tokens,
            )

        if self.image_col is not None:
            return json.dumps(
                {"text": curr_text_prompt, "image_url": input[self.image_col]}
            )
        else:
            return curr_text_prompt

    def get_explanation_prompt(self, example: Dict, include_label=True) -> str:
        raise NotImplementedError(
            "Explanation generation not implemented for this task"
        )

    def get_generate_dataset_prompt(
        self, label: str, num_rows: int, guidelines: str = None
    ) -> str:
        raise NotImplementedError("Dataset generation not implemented for this task")

    def add_text_spans(self, raw_output: dict, input: str) -> list:
        processed_output = []
        for entity_type in raw_output:
            for curr_entity in raw_output[entity_type]:
                processed_output.append({"type": entity_type, "text": curr_entity})

        # create a frequency dict of each named entity in the input to determine text spans for repeated entities
        frequency_count = {label["text"]: 0 for label in processed_output}

        for label in processed_output:
            text = label["text"]
            matches = [i.start() for i in re.finditer(text, input)]
            count = frequency_count[text]
            # if count of the named entity is greater than the number of matches, default to last found match
            if count >= len(matches):
                count = -1

            # if no occurence of named entity in input, default text span to start: -1, end: -1
            if len(matches) == 0:
                label["start"] = -1
                label["end"] = -1
            else:
                label["start"] = matches[count]
                label["end"] = matches[count] + len(text)
            frequency_count[text] += 1
        return processed_output

    def parse_llm_response(
        self,
        response: Union[Generation, ChatGeneration],
        curr_sample: Dict,
        prompt: str,
    ) -> LLMAnnotation:
        output = {}
        successfully_labeled = False
        error = None
        text_column = self.config.text_column()
        input_str = curr_sample[text_column]
        try:
            completion_text = response.text
            output = self._llm_to_json_format(completion_text.strip())
            llm_label = self.add_text_spans(output, input_str)
        except Exception as e:
            logger.error(f"Error parsing LLM response: {response.text}, Error: {e}")
            llm_label = self.NULL_LABEL
            error = LabelingError(error_type=ErrorType.PARSING_ERROR, error_msg=str(e))

        successfully_labeled = False if llm_label == self.NULL_LABEL else True

        # TODO: parse generation info correctly to fetch & transform logprobs -> score
        return LLMAnnotation(
            curr_sample=input_str,
            successfully_labeled=successfully_labeled,
            label=llm_label,
            generation_info=response.generation_info,
            raw_response=response.text,
            prompt=prompt,
            error=error,
        )

    def auroc_score_labels(
        self, gt_labels, llm_labels_with_conf
    ) -> Tuple[List[int], List[float]]:
        labels = []
        confidences = []
        for index, pred_entities in enumerate(llm_labels_with_conf):
            gt_entities = gt_labels[index]
            pred_conf = pred_entities[0]["conf"] if len(pred_entities) > 0 else 0
            for gt_entity in gt_entities:
                match_found = False
                pred_index = 0
                while not match_found and pred_index < len(pred_entities):
                    curr_match = True
                    for key in gt_entity:
                        if gt_entity[key] != pred_entities[pred_index][key]:
                            curr_match = False
                    if curr_match:
                        match_found = True
                    pred_index += 1
                labels.append(int(match_found))
                confidences.append(pred_conf)
        return labels, confidences

    def get_labels_predictions_with_threshold(self, gt_labels, llm_labels, threshold):
        answered_gt_labels, answered_llm_preds = [], []
        for index, l in enumerate(llm_labels):
            if l.successfully_labeled and (
                l.confidence_score is None or l.confidence_score >= threshold
            ):
                answered_gt_labels.append(
                    [{**entity, "label": entity["type"]} for entity in gt_labels[index]]
                )
                answered_llm_preds.append(
                    [
                        {
                            **entity,
                            "label": entity["type"],
                            "conf": l.confidence_score,
                        }
                        for entity in l.label
                    ],
                )

        return answered_gt_labels, answered_llm_preds

    def run_metrics(
        self,
        answered_gt_labels,
        answered_llm_preds,
        entity_types_set,
    ) -> List[MetricResult]:
        eval_metrics = []
        evaluator = Evaluator(
            answered_gt_labels, answered_llm_preds, tags=entity_types_set
        )

        results, _ = evaluator.evaluate()
        # f1 score for exact match
        eval_metrics.append(
            MetricResult(
                name=MetricType.F1_EXACT,
                value=results["exact"]["f1"],
            )
        )
        # f1 score for strict match
        eval_metrics.append(
            MetricResult(
                name=MetricType.F1_STRICT,
                value=results["strict"]["f1"],
            )
        )
        # f1 score for partial match
        eval_metrics.append(
            MetricResult(
                name=MetricType.F1_PARTIAL,
                value=results["partial"]["f1"],
            )
        )
        # f1 score for entity type match
        eval_metrics.append(
            MetricResult(
                name=MetricType.F1_ENT_TYPE,
                value=results["ent_type"]["f1"],
            )
        )
        # accuracy
        accuracy = (
            results.get("strict").get("correct")
            / (results.get("strict").get("possible"))
            if results.get("strict").get("possible") > 0
            else 0.0
        )
        eval_metrics.append(
            MetricResult(
                name=MetricType.ACCURACY,
                value=accuracy,
            )
        )

        if self.config.confidence():
            match, confidences = self.auroc_score_labels(
                answered_gt_labels, answered_llm_preds
            )
            auroc = roc_auc_score(match, confidences)
            eval_metrics.append(
                MetricResult(
                    name=MetricType.AUROC,
                    value=auroc,
                )
            )

        return eval_metrics

    def eval(
        self,
        llm_labels: List[LLMAnnotation],
        gt_labels: List[str],
        additional_metrics: Optional[List[BaseMetric]] = [],
    ) -> List[MetricResult]:
        """Evaluate the LLM generated labels by comparing them against ground truth

        Args:
            llm_labels (List[LLMAnnotation]): _description_
            gt_labels (List[str]): _description_

        Returns:
            List[MetricResult]: list of metrics and corresponding values
        """
        gt_labels = [
            self.add_text_spans(
                json.loads(gt_labels[index]), llm_labels[index].curr_sample.decode()
            )
            for index in range(len(gt_labels))
        ]

        (
            curr_gt_labels,
            curr_llm_labels,
        ) = self.get_labels_predictions_with_threshold(
            gt_labels, llm_labels, float("-inf")
        )

        entity_types_set = list(
            set(
                [
                    gt_entity.get("label")
                    for gt_label in curr_gt_labels
                    for gt_entity in gt_label
                ]
            )
        )

        eval_metrics = []

        eval_metrics.append(
            MetricResult(
                name=MetricType.SUPPORT,
                value=len(gt_labels),
            )
        )

        eval_metrics.append(
            MetricResult(
                name=MetricType.COMPLETION_RATE,
                value=(
                    len(curr_llm_labels) / float(len(gt_labels))
                    if len(gt_labels) > 0
                    else 0.0
                ),
            )
        )

        curr_threshold_metrics = self.run_metrics(
            curr_gt_labels,
            curr_llm_labels,
            entity_types_set,
        )

        eval_metrics.extend(curr_threshold_metrics)
        return eval_metrics

eval(llm_labels, gt_labels, additional_metrics=[])

Evaluate the LLM generated labels by comparing them against ground truth

Parameters:

    llm_labels (List[LLMAnnotation]): the LLM-generated labels to evaluate. Required.
    gt_labels (List[str]): the ground truth labels. Required.

Returns:

    List[MetricResult]: list of metrics and corresponding values

Source code in src/autolabel/tasks/named_entity_recognition.py
def eval(
    self,
    llm_labels: List[LLMAnnotation],
    gt_labels: List[str],
    additional_metrics: Optional[List[BaseMetric]] = [],
) -> List[MetricResult]:
    """Evaluate the LLM generated labels by comparing them against ground truth

    Args:
        llm_labels (List[LLMAnnotation]): _description_
        gt_labels (List[str]): _description_

    Returns:
        List[MetricResult]: list of metrics and corresponding values
    """
    gt_labels = [
        self.add_text_spans(
            json.loads(gt_labels[index]), llm_labels[index].curr_sample.decode()
        )
        for index in range(len(gt_labels))
    ]

    (
        curr_gt_labels,
        curr_llm_labels,
    ) = self.get_labels_predictions_with_threshold(
        gt_labels, llm_labels, float("-inf")
    )

    entity_types_set = list(
        set(
            [
                gt_entity.get("label")
                for gt_label in curr_gt_labels
                for gt_entity in gt_label
            ]
        )
    )

    eval_metrics = []

    eval_metrics.append(
        MetricResult(
            name=MetricType.SUPPORT,
            value=len(gt_labels),
        )
    )

    eval_metrics.append(
        MetricResult(
            name=MetricType.COMPLETION_RATE,
            value=(
                len(curr_llm_labels) / float(len(gt_labels))
                if len(gt_labels) > 0
                else 0.0
            ),
        )
    )

    curr_threshold_metrics = self.run_metrics(
        curr_gt_labels,
        curr_llm_labels,
        entity_types_set,
    )

    eval_metrics.extend(curr_threshold_metrics)
    return eval_metrics
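
A worked example of the NER output round trip, assuming task is a NamedEntityRecognitionTask whose labels list is ["Person", "Location"]. The raw LLM response uses the %-separated, one-entity-per-line format described in DEFAULT_OUTPUT_GUIDELINES; _llm_to_json_format groups it by category and add_text_spans attaches character offsets. The values below are illustrative.

raw_response = "Barack Obama%Person\nHawaii%Location"
input_text = "Barack Obama was born in Hawaii."

parsed = task._llm_to_json_format(raw_response)
# parsed == {"Person": ["Barack Obama"], "Location": ["Hawaii"]}

spans = task.add_text_spans(parsed, input_text)
# spans == [
#     {"type": "Person", "text": "Barack Obama", "start": 0, "end": 12},
#     {"type": "Location", "text": "Hawaii", "start": 25, "end": 31},
# ]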

filter_unlabeled_examples(gt_labels, llm_labels)

Filter out unlabeled examples from the ground truth and LLM generated labels. This is done by checking the ground truth labels which have nan values. The corresponding ground truth and LLM labels are removed from the filtered labels lists.

Parameters:

    gt_labels (List[str]): ground truth labels. Required.
    llm_labels (List[LLMAnnotation]): llm labels. Required.

Returns:

    Tuple[List[str], List[LLMAnnotation]]: filtered_gt_labels, filtered_llm_labels (filtered ground truth and LLM generated labels)

Source code in src/autolabel/tasks/utils.py
def filter_unlabeled_examples(
    gt_labels: List[str], llm_labels: List[LLMAnnotation]
) -> Tuple[List[str], List[LLMAnnotation]]:
    """Filter out unlabeled examples from the ground truth and LLM generated labels.
    This is done by checking the ground truth labels which have nan values.
    The corresponding ground truth and LLM labels are removed from the filtered labels lists.

    Args:
        gt_labels (List[str]): ground truth labels
        llm_labels (List[LLMAnnotation]): llm labels

    Returns:
        filtered_gt_labels, filtered_llm_labels: filtered ground truth and LLM generated labels
    """
    filtered_gt_labels = []
    filtered_llm_labels = []
    for gt_label, llm_label in zip(gt_labels, llm_labels):
        if gt_label != "nan":
            filtered_gt_labels.append(gt_label)
            filtered_llm_labels.append(llm_label)
    return filtered_gt_labels, filtered_llm_labels
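
A tiny illustration of the filtering behaviour. Note that the comparison is against the string "nan", which is how a missing ground truth label typically looks once the column has been cast to str; the second list holds stand-in strings where real code would pass LLMAnnotation objects.

gt = ["positive", "nan", "negative"]
llm = ["annotation_1", "annotation_2", "annotation_3"]
filter_unlabeled_examples(gt, llm)
# -> (["positive", "negative"], ["annotation_1", "annotation_3"])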

normalize_text(s)

Removing articles and punctuation, and standardizing whitespace are all typical text processing steps.

Source code in src/autolabel/tasks/utils.py
def normalize_text(s: str) -> str:
    """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
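
For example:

normalize_text("The  quick, brown fox!")  # -> "quick brown fox"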