Data Models
The data model classes save the progress of AutoLabel jobs to a SQL database.
Saved data is stored in the .autolabel.db file.
Every data model class implements its own get and create methods for accessing this saved data.
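All of the methods below take an open SQLAlchemy session as their db argument. As a minimal sketch (not part of the public API), such a session could be opened like this, mirroring what create_db_engine at the end of this page does:

# Hypothetical setup sketch: open a session against the .autolabel.db file
# with plain SQLAlchemy, mirroring create_db_engine() shown at the end of
# this page.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite:///.autolabel.db")
Session = sessionmaker(bind=engine)
db = Session()  # passed as the `db` argument to the model methods below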
AnnotationModel
Bases: Base
Source code in src/autolabel/data_models/annotation.py
class AnnotationModel(Base):
    __tablename__ = "annotations"

    id = Column(Integer, primary_key=True, autoincrement=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    index = Column(Integer)
    llm_annotation = Column(TEXT)
    task_run_id = Column(Integer, ForeignKey("task_runs.id"))
    task_runs = relationship("TaskRunModel", back_populates="annotations")

    def __repr__(self):
        return f"<AnnotationModel(id={self.id}, index={self.index}, annotation={self.llm_annotation})>"

    @classmethod
    def create_from_llm_annotation(
        cls, db, llm_annotation: LLMAnnotation, index: int, task_run_id: int
    ):
        db_object = cls(
            llm_annotation=pickle.dumps(llm_annotation),
            index=index,
            task_run_id=task_run_id,
        )
        db.add(db_object)
        db.commit()
        db_object = db.query(cls).order_by(cls.id.desc()).first()
        logger.debug(f"created new annotation: {db_object}")
        return db_object

    @classmethod
    def get_annotations_by_task_run_id(cls, db, task_run_id: int):
        annotations = (
            db.query(cls)
            .filter(cls.task_run_id == task_run_id)
            .order_by(cls.index)
            .all()
        )
        # Deduplicate: keep only the first annotation stored for each index
        filtered_annotations = []
        ids = {}
        for annotation in annotations:
            if annotation.index not in ids:
                ids[annotation.index] = True
                filtered_annotations.append(annotation)
        return filtered_annotations

    @classmethod
    def from_pydantic(cls, annotation: BaseModel):
        return cls(**json.loads(annotation.json()))

    def delete(self, db):
        db.delete(self)
        db.commit()
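For illustration, a hedged sketch of how these methods might be called; it assumes db is an open session, annotation is an LLMAnnotation produced by a labeling run, and the task_run_id value is made up:

# Hypothetical usage; `db`, `annotation`, and the task_run_id are assumptions.
row = AnnotationModel.create_from_llm_annotation(
    db, llm_annotation=annotation, index=0, task_run_id=7
)

# Annotations come back ordered and deduplicated by index; the
# llm_annotation column holds a pickled LLMAnnotation, so unpickle to read it.
rows = AnnotationModel.get_annotations_by_task_run_id(db, task_run_id=7)
first = pickle.loads(rows[0].llm_annotation)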
GenerationCacheEntryModel
Bases: Base
An SQLAlchemy-based cache system for storing and retrieving CacheEntries.
Source code in src/autolabel/data_models/generation_cache.py
class GenerationCacheEntryModel(Base):
    """An SQLAlchemy-based cache system for storing and retrieving CacheEntries"""

    __tablename__ = "generation_cache"

    id = Column(Integer, primary_key=True)
    model_name = Column(String(50))
    prompt = Column(Text)
    model_params = Column(Text)
    generations = Column(JSON)
    creation_time_ms = Column(Integer)
    ttl_ms = Column(Integer)

    def __repr__(self):
        return f"<Cache(model_name={self.model_name},prompt={self.prompt},model_params={self.model_params},generations={self.generations})>"

    @classmethod
    def get(cls, db, cache_entry: GenerationCacheEntry):
        looked_up_entry = (
            db.query(cls)
            .filter(
                cls.model_name == cache_entry.model_name,
                cls.prompt == cache_entry.prompt,
                cls.model_params == cache_entry.model_params,
            )
            .first()
        )
        if not looked_up_entry:
            return None
        generations = json.loads(looked_up_entry.generations)["generations"]
        generations = [
            Generation(**gen) if gen["type"] == "Generation" else ChatGeneration(**gen)
            for gen in generations
        ]
        entry = GenerationCacheEntry(
            model_name=looked_up_entry.model_name,
            prompt=looked_up_entry.prompt,
            model_params=looked_up_entry.model_params,
            generations=generations,
            creation_time_ms=looked_up_entry.creation_time_ms,
            ttl_ms=looked_up_entry.ttl_ms,
        )
        return entry

    @classmethod
    def insert(cls, db, cache_entry: BaseModel):
        generations = {"generations": [gen.dict() for gen in cache_entry.generations]}
        db_object = cls(
            model_name=cache_entry.model_name,
            prompt=cache_entry.prompt,
            model_params=cache_entry.model_params,
            generations=json.dumps(generations),
            creation_time_ms=int(time.time() * 1000),
            ttl_ms=cache_entry.ttl_ms,
        )
        db.add(db_object)
        db.commit()
        return cache_entry

    @classmethod
    def clear(cls, db, use_ttl: bool = True) -> None:
        # Delete expired entries only (TTL-based), or everything if use_ttl is False
        if use_ttl:
            current_time_ms = int(time.time() * 1000)
            db.query(cls).filter(
                current_time_ms - cls.creation_time_ms > cls.ttl_ms
            ).delete()
        else:
            db.query(cls).delete()
        db.commit()
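A sketch of a cache round trip, assuming db is an open session and that the GenerationCacheEntry constructor accepts the fields stored above (the concrete values are invented):

# Hypothetical round trip; field values are made up for illustration.
entry = GenerationCacheEntry(
    model_name="gpt-3.5-turbo",
    prompt="Classify the sentiment: ...",
    model_params='{"temperature": 0}',
    generations=[Generation(text="positive")],
    ttl_ms=86400000,  # one day
)
GenerationCacheEntryModel.insert(db, entry)

# Lookups match on (model_name, prompt, model_params) and return None on a miss.
hit = GenerationCacheEntryModel.get(db, entry)
if hit is not None:
    print(hit.generations[0].text)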
TransformCacheEntryModel
Bases: Base
An SQLAlchemy-based cache system for storing and retrieving CacheEntries.
Source code in src/autolabel/data_models/transform_cache.py
class TransformCacheEntryModel(Base):
    """An SQLAlchemy-based cache system for storing and retrieving CacheEntries"""

    __tablename__ = "transform_cache"

    id = Column(String, primary_key=True)
    transform_name = Column(String(50))
    transform_params = Column(TEXT)
    input = Column(TEXT)
    output = Column(TEXT)
    creation_time_ms = Column(Integer)
    ttl_ms = Column(Integer)

    def __repr__(self):
        return f"<TransformCache(id={self.id},transform_name={self.transform_name},transform_params={self.transform_params},input={self.input},output={self.output})>"

    @classmethod
    def get(cls, db, cache_entry: TransformCacheEntry) -> TransformCacheEntry:
        id = cache_entry.get_id()
        looked_up_entry = db.query(cls).filter(cls.id == id).first()
        if not looked_up_entry:
            return None
        entry = TransformCacheEntry(
            transform_name=looked_up_entry.transform_name,
            transform_params=pickle.loads(looked_up_entry.transform_params),
            input=pickle.loads(looked_up_entry.input),
            output=pickle.loads(looked_up_entry.output),
            creation_time_ms=looked_up_entry.creation_time_ms,
            ttl_ms=looked_up_entry.ttl_ms,
        )
        return entry

    @classmethod
    def insert(cls, db, cache_entry: TransformCacheEntry):
        db_object = cls(
            id=cache_entry.get_id(),
            transform_name=cache_entry.transform_name,
            transform_params=pickle.dumps(cache_entry.transform_params),
            input=pickle.dumps(cache_entry.input),
            output=pickle.dumps(cache_entry.output),
            creation_time_ms=int(time.time() * 1000),
            ttl_ms=cache_entry.ttl_ms,
        )
        db.add(db_object)
        db.commit()
        return db_object

    @classmethod
    def clear(cls, db, use_ttl: bool = True) -> None:
        # Delete expired entries only (TTL-based), or everything if use_ttl is False
        if use_ttl:
            current_time_ms = int(time.time() * 1000)
            db.query(cls).filter(
                current_time_ms - cls.creation_time_ms > cls.ttl_ms
            ).delete()
        else:
            db.query(cls).delete()
        db.commit()
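Similarly, a hedged sketch for the transform cache; the TransformCacheEntry field values are assumptions based on the columns above, and entries are keyed by entry.get_id():

# Hypothetical usage; field values are invented for illustration.
entry = TransformCacheEntry(
    transform_name="webpage_transform",
    transform_params={"timeout_sec": 10},
    input={"url": "https://example.com"},
    output={"content": "..."},
    ttl_ms=86400000,
)
TransformCacheEntryModel.insert(db, entry)
cached = TransformCacheEntryModel.get(db, entry)  # keyed by entry.get_id(); None on a miss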
DatasetModel
Bases: Base
Source code in src/autolabel/data_models/dataset.py
class DatasetModel(Base):
    __tablename__ = "datasets"

    id = Column(String(32), primary_key=True)
    input_file = Column(String(50))
    start_index = Column(Integer)
    end_index = Column(Integer)
    task_runs = relationship("TaskRunModel", back_populates="dataset")

    def __repr__(self):
        return f"<DatasetModel(id={self.id}, input_file={self.input_file}, start_index={self.start_index}, end_index={self.end_index})>"

    @classmethod
    def create(cls, db, dataset: BaseModel):
        db_object = cls(**json.loads(dataset.json()))
        db.add(db_object)
        db.commit()
        return db_object

    @classmethod
    def get_by_id(cls, db, id: str):
        return db.query(cls).filter(cls.id == id).first()

    @classmethod
    def get_by_input_file(cls, db, input_file: str):
        return db.query(cls).filter(cls.input_file == input_file).first()

    def delete(self, db):
        db.delete(self)
        db.commit()
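As a sketch, assuming db is an open session and dataset is a pydantic Dataset (see StateManager.initialize_dataset below for how one is constructed; the file path is made up):

# Hypothetical usage; the input_file path is an example value.
row = DatasetModel.create(db, dataset)
same = DatasetModel.get_by_id(db, row.id)
by_file = DatasetModel.get_by_input_file(db, "data/reviews.csv")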
TaskModel
Bases: Base
Source code in src/autolabel/data_models/task.py
class TaskModel(Base):
    __tablename__ = "tasks"

    id = Column(String(32), primary_key=True)
    task_type = Column(String(50))
    provider = Column(String(50))
    model_name = Column(String(50))
    config = Column(Text)
    task_runs = relationship("TaskRunModel", back_populates="task")

    def __repr__(self):
        return f"<TaskModel(id={self.id}, task_type={self.task_type}, provider={self.provider}, model_name={self.model_name})>"

    @classmethod
    def create(cls, db, task: BaseModel):
        db_object = cls(**json.loads(task.json()))
        db.add(db_object)
        db.commit()
        return db_object

    @classmethod
    def get_by_id(cls, db, id: str):
        return db.query(cls).filter(cls.id == id).first()

    def delete(self, db):
        db.delete(self)
        db.commit()
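A parallel sketch, assuming task is a pydantic Task (see StateManager.initialize_task below):

# Hypothetical usage, mirroring the DatasetModel sketch above.
row = TaskModel.create(db, task)
fetched = TaskModel.get_by_id(db, row.id)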
TaskRunModel
Bases: Base
Source code in src/autolabel/data_models/task_run.py
class TaskRunModel(Base):
    __tablename__ = "task_runs"

    id = Column(
        Integer,
        default=lambda: uuid.uuid4().int >> (128 - 32),
        primary_key=True,
    )
    task_id = Column(String(32), ForeignKey("tasks.id"))
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    dataset_id = Column(String(32), ForeignKey("datasets.id"))
    current_index = Column(Integer)
    error = Column(String(256))
    metrics = Column(Text)
    output_file = Column(String(50))
    status = Column(String(50))
    task = relationship("TaskModel", back_populates="task_runs")
    dataset = relationship("DatasetModel", back_populates="task_runs")
    annotations = relationship("AnnotationModel", back_populates="task_runs")

    def __repr__(self):
        return f"<TaskRunModel(id={self.id}, created_at={str(self.created_at)}, task_id={self.task_id}, dataset_id={self.dataset_id}, output_file={self.output_file}, current_index={self.current_index}, status={self.status}, error={self.error}, metrics={self.metrics})>"

    @classmethod
    def create(cls, db, task_run: BaseModel):
        logger.debug(f"creating new task: {task_run}")
        db_object = cls(**task_run.dict())
        db.add(db_object)
        db.commit()
        db.refresh(db_object)
        logger.debug(f"created new task: {db_object}")
        return db_object

    @classmethod
    def get(cls, db, task_id: str, dataset_id: str):
        return (
            db.query(cls)
            .filter(cls.task_id == task_id, cls.dataset_id == dataset_id)
            .first()
        )

    @classmethod
    def from_pydantic(cls, task_run: BaseModel):
        return cls(**json.loads(task_run.json()))

    @classmethod
    def update(cls, db, task_run: BaseModel):
        task_run_id = task_run.id
        task_run_orm = db.query(cls).filter(cls.id == task_run_id).first()
        logger.debug(f"updating task_run: {task_run}")
        for key, value in task_run.dict().items():
            setattr(task_run_orm, key, value)
        db.commit()
        logger.debug(f"task_run updated: {task_run}")
        return TaskRun.from_orm(task_run_orm)

    @classmethod
    def delete_by_id(cls, db, id: int):
        db.query(cls).filter(cls.id == id).delete()

    def delete(self, db):
        db.delete(self)
        db.commit()
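A sketch of the create/get/update cycle, assuming db is an open session and task_run is a pydantic TaskRun as built in StateManager.create_task_run below (the progress value is made up):

# Hypothetical usage; the current_index value is an example.
row = TaskRunModel.create(db, task_run)

run_orm = TaskRunModel.get(db, task_id=task_run.task_id, dataset_id=task_run.dataset_id)
run = TaskRun.from_orm(run_orm)
run.current_index = 100  # record progress, then persist it
TaskRunModel.update(db, run)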
StateManager
Source code in src/autolabel/database/state_manager.py
class StateManager:
    def __init__(self):
        self.engine = create_db_engine()
        self.base = Base
        self.session = None

    def initialize(self):
        self.base.metadata.create_all(self.engine)
        self.session = sessionmaker(bind=self.engine)()

    def initialize_dataset(
        self,
        dataset: Union[str, pd.DataFrame],
        config: AutolabelConfig,
        start_index: int = 0,
        max_items: Optional[int] = None,
    ):
        # TODO: Check if this works for max_items = None
        dataset_id = Dataset.create_id(dataset, config, start_index, max_items)
        dataset_orm = DatasetModel.get_by_id(self.session, dataset_id)
        if dataset_orm:
            return Dataset.from_orm(dataset_orm)
        dataset = Dataset(
            id=dataset_id,
            input_file=dataset if isinstance(dataset, str) else "",
            start_index=start_index,
            end_index=start_index + max_items if max_items else -1,
        )
        return Dataset.from_orm(DatasetModel.create(self.session, dataset))

    def initialize_task(self, config: AutolabelConfig):
        task_id = Task.create_id(config)
        task_orm = TaskModel.get_by_id(self.session, task_id)
        if task_orm:
            return Task.from_orm(task_orm)
        task = Task(
            id=task_id,
            config=config.to_json(),
            task_type=config.task_type(),
            provider=config.provider(),
            model_name=config.model_name(),
        )
        return Task.from_orm(TaskModel.create(self.session, task))

    def get_task_run(self, task_id: str, dataset_id: str):
        task_run_orm = TaskRunModel.get(self.session, task_id, dataset_id)
        if task_run_orm:
            return TaskRun.from_orm(task_run_orm)
        else:
            return None

    def create_task_run(
        self, output_file: str, task_id: str, dataset_id: str
    ) -> TaskRun:
        logger.debug("creating new task_run")
        new_task_run = TaskRun(
            task_id=task_id,
            dataset_id=dataset_id,
            status=TaskStatus.ACTIVE,
            current_index=0,
            output_file=output_file,
            created_at=datetime.now(),
        )
        task_run_orm = TaskRunModel.create(self.session, new_task_run)
        return TaskRun.from_orm(task_run_orm)
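Putting the pieces together, a hedged end-to-end sketch; config is assumed to be a valid AutolabelConfig, and the file paths are made up:

# Hypothetical usage of StateManager to resume or start a labeling run.
manager = StateManager()
manager.initialize()  # create tables and open a session

task = manager.initialize_task(config)
dataset = manager.initialize_dataset("data/reviews.csv", config, start_index=0)

task_run = manager.get_task_run(task.id, dataset.id)
if task_run is None:  # no prior run to resume, so start a new one
    task_run = manager.create_task_run("output.csv", task.id, dataset.id)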
create_db_engine
Source code in src/autolabel/database/engine.py
def create_db_engine(db_path: Optional[str] = DB_PATH) -> Engine:
    # Memoize the engine in a module-level global so all callers share it
    global DB_ENGINE
    if DB_ENGINE is None:
        DB_ENGINE = create_engine(f"sqlite:///{db_path}", pool_size=0)
    return DB_ENGINE
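Because the engine is memoized in the module-level DB_ENGINE global, every caller shares a single engine, and the first caller's db_path wins:

# Repeated calls return the same Engine object (the first db_path wins).
engine_a = create_db_engine()
engine_b = create_db_engine("elsewhere.db")  # ignored: engine already exists
assert engine_a is engine_b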