Source code for datumaro.components.project

# Copyright (C) 2019-2023 Intel Corporation
# SPDX-License-Identifier: MIT

from __future__ import annotations

import logging as log
import os
import os.path as osp
import re
import shutil
import tempfile
import unittest.mock
from contextlib import ExitStack, suppress
from enum import Enum, auto
from typing import (

from datumaro.components.config import Config
from datumaro.components.config_model import (
from datumaro.components.dataset import DEFAULT_FORMAT, Dataset, IDataset
from datumaro.components.environment import Environment
from datumaro.components.errors import (
from datumaro.components.launcher import Launcher
from datumaro.util import find, parse_json_file, parse_str_enum_value
from datumaro.util.log_utils import catch_logs, logging_disabled
from datumaro.util.os_util import (
from datumaro.util.scope import on_error_do, scope_add, scoped

    import networkx as nx

    from datumaro.util.import_util import lazy_import

    nx = lazy_import("networkx")

class ProjectSourceDataset(IDataset):
    """
    A dataset of a project data source.

    The data is imported from the given directory using the format and
    options from the source config of the project tree. Attribute reads
    and writes, which are not handled by this class, are forwarded to the
    wrapped Dataset object.
    """

    def __init__(self, path: str, tree: Tree, source: str, readonly: bool = False):
        """
        Parameters:
            path: the source data directory
            tree: the project tree, which contains the source
            source: the source name in the tree
            readonly: prohibits saving the dataset into its own directory
        """
        config = tree.sources[source]

        # Work on a copy of the options: the "path" option is consumed here.
        # Popping it from the config itself would remove it from the project
        # tree, so a repeated load of the same source would not find it.
        options = dict(config.options)

        rpath = path
        if config.path:
            rpath = osp.join(path, config.path)
        if "path" in options:
            rpath = osp.join(path, options.pop("path"))

        dataset = Dataset.import_from(rpath, env=tree.env, format=config.format, **options)

        # Using rpath won't allow to save directly with .save() when a file
        # path is specified. Dataset doesn't know the root location and if
        # it exists at all, but in a project, we do.
        dataset.bind(path, format=dataset.format, options=dataset.options)

        # __setattr__ forwards to the wrapped dataset, so the own attributes
        # are written into __dict__ directly.
        self.__dict__["_dataset"] = dataset
        self.__dict__["_config"] = config
        self.__dict__["_readonly"] = readonly
        self.__dict__["name"] = source

    def save(self, save_dir=None, **kwargs):
        """
        Saves the dataset.

        Raises:
            ReadonlyDatasetError: if the dataset is read-only and the
                destination is the dataset's own data directory
        """
        if self.readonly and (
            save_dir is None or osp.abspath(save_dir) == osp.abspath(self.data_path)
        ):
            raise ReadonlyDatasetError()
        self._dataset.save(save_dir, **kwargs)

    @property
    def readonly(self):
        # Read-only if requested, or if the dataset is not bound to a location
        return self._readonly or not self.is_bound

    @property
    def config(self):
        return self._config

    def __getattr__(self, name):
        # Called only for attributes missing in this object
        return getattr(self._dataset, name)

    def __setattr__(self, name, value):
        return setattr(self._dataset, name, value)

    def __iter__(self):
        yield from self._dataset

    def __len__(self):
        return len(self._dataset)

    def subsets(self):
        return self._dataset.subsets()

    def get_subset(self, name):
        return self._dataset.get_subset(name)

    def infos(self):
        return self._dataset.infos()

    def categories(self):
        return self._dataset.categories()

    def get(self, id, subset=None):
        return self._dataset.get(id, subset)

    def media_type(self):
        return self._dataset.media_type()

    def ann_types(self):
        return self._dataset.ann_types()
class IgnoreMode(Enum):
    """
    Defines how _update_ignore_file() treats the existing ignore file entries.
    """

    # Drop the existing entries, write only the requested paths
    rewrite = 1
    # Keep the existing entries, add the requested paths which are missing
    append = 2
    # Keep the existing entries, except for the requested paths
    remove = 3
def _update_ignore_file(
    paths: Union[str, List[str]],
    repo_root: str,
    filepath: str,
    mode: Union[None, str, IgnoreMode] = None,
):
    """
    Updates an ignore file (like .gitignore or .dvcignore) with the given paths.

    Parameters:
        paths: a path or a list of paths, relative to the repo_root.
            The paths must be inside of the repo_root.
        repo_root: the directory, relative to which the ignore file
            entries are resolved
        filepath: the ignore file path. The file is created, if missing.
        mode: how to treat the existing entries (an IgnoreMode or its name).
            The default is IgnoreMode.append.
    """

    def _make_ignored_path(path):
        path = osp.join(repo_root, osp.normpath(path))
        assert is_subpath(path, base=repo_root)

        # Prepend the '/' to match only direct childs.
        # Otherwise the rule can be in any path part.
        return "/" + osp.relpath(path, repo_root).replace("\\", "/")

    def _entry_to_path(entry: str) -> str:
        # Drop an inline comment and the leading '/' anchor, then normalize
        # exactly like the requested paths are normalized, so that the
        # lookup works regardless of the platform path separator.
        entry = entry.split("#", maxsplit=1)[0].strip()
        entry = entry.replace("\\", "/").lstrip("/")
        return osp.join(repo_root, osp.normpath(entry))

    header = "# The file is autogenerated by Datumaro"

    mode = parse_str_enum_value(mode, IgnoreMode, IgnoreMode.append)

    if isinstance(paths, str):
        paths = [paths]
    # { absolute path: ignore file entry }
    paths = {osp.join(repo_root, osp.normpath(p)): _make_ignored_path(p) for p in paths}

    openmode = "r+"
    if not osp.isfile(filepath):
        openmode = "w+"  # r+ cannot create, w truncates
    with open(filepath, openmode) as f:
        lines = []
        if mode in {IgnoreMode.append, IgnoreMode.remove}:
            lines = [line.strip() for line in f]

            # Rewind to overwrite the file contents. Otherwise the updated
            # contents would be appended after the old ones.
            f.seek(0)

        new_lines = []
        for line in lines:
            if not line or line.startswith("#"):
                new_lines.append(line)
                continue

            line_path = _entry_to_path(line)

            if mode == IgnoreMode.append:
                if line_path in paths:
                    # Already listed - don't add it again
                    paths.pop(line_path)
                new_lines.append(line)
            elif mode == IgnoreMode.remove:
                if line_path not in paths:
                    new_lines.append(line)

        if mode in {IgnoreMode.rewrite, IgnoreMode.append}:
            new_lines.extend(paths.values())

        if not new_lines or new_lines[0] != header:
            print(header, file=f)
        for line in new_lines:
            print(line, file=f)

        # Remove the leftovers of the previous contents, if they were longer
        f.truncate()


CrudEntry = TypeVar("CrudEntry")
T = TypeVar("T")
class CrudProxy(Generic[CrudEntry]):
    """
    A dict-like view over a collection of named entries.

    Subclasses provide the underlying storage by overriding the _data
    property. Iteration yields the entry names.
    """

    @property
    def _data(self) -> Dict[str, CrudEntry]:
        raise NotImplementedError()

    def __len__(self):
        storage = self._data
        return len(storage)

    def __getitem__(self, name: str) -> CrudEntry:
        storage = self._data
        return storage[name]

    def get(
        self, name: str, default: Union[None, T, CrudEntry] = None
    ) -> Union[None, T, CrudEntry]:
        """Returns the entry with the name, or the default value if there is no such entry."""
        storage = self._data
        return storage.get(name, default)

    def __iter__(self) -> Iterator[str]:
        storage = self._data
        return iter(storage.keys())

    def items(self) -> Iterable[Tuple[str, CrudEntry]]:
        """Returns an iterator over the (name, entry) pairs."""
        storage = self._data
        return iter(storage.items())

    def __contains__(self, name: str):
        storage = self._data
        return name in storage
class _DataSourceBase(CrudProxy[Source]):
    """
    A collection of named entries, stored in a field of the tree config.
    """

    def __init__(self, tree: Tree, config_field: str):
        self._tree = tree
        self._field = config_field

    @property
    def _data(self) -> Dict[str, Source]:
        tree_config = self._tree.config
        return tree_config[self._field]

    def add(self, name: str, value: Union[Dict, Config, Source]) -> Source:
        """
        Adds a new entry and returns it.

        Raises:
            SourceExistsError: if an entry with this name already exists
        """
        if name not in self:
            return self._data.set(name, value)
        raise SourceExistsError(name)

    def remove(self, name: str):
        """Removes the entry with the name."""
        self._data.remove(name)
class ProjectSources(_DataSourceBase):
    """
    The data sources of a project tree, kept in the "sources" config field.
    """

    def __init__(self, tree: Tree):
        super().__init__(tree, "sources")

    def __getitem__(self, name):
        try:
            entry = super().__getitem__(name)
        except KeyError as err:
            # Replace the bare key error with a descriptive message
            raise KeyError("Unknown source '%s'" % name) from err
        return entry
class BuildStageType(Enum):
    """
    Build pipeline stage kinds. Stage configs store the member name
    in the "type" field, and it is resolved with BuildStageType[name].
    """

    source = 1
    project = 2
    transform = 3
    filter = 4
    convert = 5
    inference = 6
    explore = 7
class Pipeline:
    """
    A build pipeline: a directed graph of named build stages.

    Each node keeps its stage config in the "config" node attribute, and
    edges go from a stage to the stages which use its output. A pipeline
    has a single head - the only node without outgoing edges. Unknown
    attributes are forwarded to the underlying networkx graph.
    """

    @staticmethod
    def _create_graph(config: PipelineConfig):
        graph = nx.DiGraph()
        for entry in config:
            target_name = entry["name"]
            parents = entry["parents"]
            target = BuildStage(entry["config"])

            graph.add_node(target_name, config=target)
            for prev_stage in parents:
                graph.add_edge(prev_stage, target_name)

        return graph

    def __init__(self, config: PipelineConfig = None):
        """
        Parameters:
            config: a serialized pipeline. When None, the pipeline is empty.

        Raises:
            MissingPipelineHeadError: if the config has no head stage
        """
        self._head = None

        if config is not None:
            self._graph = self._create_graph(config)
            if not self.head:
                raise MissingPipelineHeadError()
        else:
            self._graph = nx.DiGraph()

    def __getattr__(self, key):
        # __getattr__ is only invoked for missing attributes. If _graph itself
        # is not set yet (e.g. __init__ failed, or the object is being copied
        # or unpickled), forwarding would recurse infinitely.
        if key == "_graph":
            raise AttributeError(key)
        return getattr(self._graph, key)

    @staticmethod
    def _find_head_node(graph) -> Optional[str]:
        """
        Returns the only node without outputs, or None for an empty graph.

        Raises:
            MultiplePipelineHeadsError: if there are several such nodes
        """
        head = None
        for node in graph.nodes:
            if graph.out_degree(node) == 0:
                if head is not None:
                    raise MultiplePipelineHeadsError(
                        "A pipeline can have only one "
                        "main target, but it has at least 2: %s, %s" % (head, node)
                    )
                head = node
        return head

    @property
    def head(self) -> str:
        if self._head is None:
            self._head = self._find_head_node(self._graph)
        return self._head

    @property
    def head_node(self):
        return self._graph.nodes[self.head]

    @staticmethod
    def _serialize(graph) -> PipelineConfig:
        serialized = PipelineConfig()
        for node_name, node in graph.nodes.items():
            serialized.nodes.append(
                {
                    "name": node_name,
                    "parents": list(graph.predecessors(node_name)),
                    "config": dict(node["config"]),
                }
            )
        return serialized

    @staticmethod
    def _get_subgraph(graph, target):
        """
        Returns a subgraph with all the target dependencies and
        the target itself.
        """
        return graph.subgraph(nx.ancestors(graph, target) | {target})

    def get_slice(self, target) -> Pipeline:
        """
        Returns a new pipeline with the target stage and all its dependencies.
        The new pipeline has an independent copy of the graph.
        """
        pipeline = Pipeline()
        pipeline._graph = self._get_subgraph(self._graph, target).copy()
        return pipeline
class ProjectBuilder:
    """
    Builds datasets of a project tree by running build pipelines.

    Each stage is loaded from the working directory or the cache when
    possible. Otherwise, it is computed from its parent stages.
    """

    def __init__(self, project: Project, tree: Tree):
        self._project = project
        self._tree = tree

    def make_dataset(self, pipeline: Pipeline) -> IDataset:
        """
        Runs the pipeline and returns the dataset of its head stage.
        """
        dataset = self._get_resulting_dataset(pipeline)

        # TODO: May be need to save and load, because it can modify dataset,
        # unless we work with the internal format. For example, it can
        # add format-specific attributes. It should be needed as soon
        # format converting stages (export, convert, load) are allowed.
        #
        # TODO: If the target was rebuilt from sources, it may require saving
        # and hashing, so the resulting hash could be compared with the saved
        # one in the pipeline. This is needed to make sure the reproduced
        # version of the dataset is correct. Currently we only rely on the
        # initial source version check, which can be not enough if stages
        # produce different result (because of the library changes etc).
        #
        # save_in_cache(project, pipeline)
        # update and check hash in config!
        # dataset = load_dataset(project, pipeline)

        return dataset

    def _run_pipeline(self, pipeline: Pipeline):
        """
        Validates the pipeline, checks the missing source data and
        initializes the pipeline datasets.

        Returns:
            (the pipeline graph, the head stage name)
        """
        self._validate_pipeline(pipeline)

        missing_sources, wd_hashes = self._find_missing_sources(pipeline)
        for source_name in missing_sources:
            source = self._tree.sources[source_name]

            if wd_hashes.get(source_name):
                raise ForeignChangesError(
                    "Local source '%s' data does not "
                    "match any previous source revision. Probably, the source "
                    "was modified outside Datumaro. You can restore the "
                    "latest source revision with 'checkout' command." % source_name
                )

            if self._project.readonly:
                # Source re-downloading is prohibited in readonly projects
                # because it can seriously hurt free storage space. It must
                # be run manually, so that the user could know about this.
                log.warning(
                    "Skipping re-downloading missing source '%s', "
                    "because the project is read-only. Automatic downloading "
                    "is disabled in read-only projects.",
                    source_name,
                )
                continue

            if not source.hash:
                raise MissingSourceHashError(
                    "Unable to re-download source "
                    "'%s': the source was added with no hash information. " % source_name
                )

            with self._project._make_tmp_dir() as tmp_dir:
                obj_hash, _, _ = self._project._download_source(source.url, tmp_dir)

                if source.hash and source.hash != obj_hash:
                    raise MismatchingObjectError(
                        "Downloaded source '%s' data is different "
                        "from what is saved in the build pipeline: "
                        "'%s' vs '%s'" % (source_name, obj_hash, source.hash)
                    )

        return self._init_pipeline(pipeline, working_dir_hashes=wd_hashes)

    def _get_resulting_dataset(self, pipeline):
        graph, head = self._run_pipeline(pipeline)
        return graph.nodes[head]["dataset"]

    def _init_pipeline(self, pipeline: Pipeline, working_dir_hashes=None):
        """
        Initializes datasets in the pipeline nodes. Currently, only the head
        node will have a dataset on exit, so no extra memory is wasted
        for the intermediate nodes.

        Parameters:
            pipeline: a pipeline, which is assumed to be validated
            working_dir_hashes: known working directory hashes
                { source name: hash }. Updated in-place.

        Returns:
            (the pipeline graph, the head stage name)
        """

        def _join_parent_datasets(force=False):
            # Uses 'stage_name' of the current iteration of the traversal below
            parents = {p: graph.nodes[p] for p in graph.predecessors(stage_name)}

            if 1 < len(parents) or force:
                try:
                    dataset = Dataset.from_extractors(
                        *(p["dataset"] for p in parents.values()), env=self._tree.env
                    )
                except DatasetMergeError as e:
                    e.sources = set(parents)
                    raise e
            else:
                dataset = next(iter(parents.values()))["dataset"]

            # clear fully utilized datasets to release memory
            for p_name, p in parents.items():
                p["_use_count"] = p.get("_use_count", 0) + 1

                if p_name != head and p["_use_count"] == graph.out_degree(p_name):
                    p.pop("dataset")

            return dataset

        if working_dir_hashes is None:
            working_dir_hashes = {}

        def _try_load_from_disk(stage_name: str, stage_config: BuildStage) -> Optional[Dataset]:
            # Check if we can restore this stage from the cache or
            # from the working directory.
            #
            # If we have a hash, we have executed this stage already
            # and can have a cache entry or,
            # if this is the last stage of a target in the working tree,
            # we can use data from the working directory.
            stage_hash = stage_config.hash

            data_dir = None
            cached = False

            source_name, source_stage_name = ProjectBuildTargets.split_target_name(stage_name)
            if self._tree.is_working_tree and source_name in self._tree.sources:
                target = self._tree.build_targets[source_name]
                data_dir = self._project.source_data_dir(source_name)

                wd_hash = working_dir_hashes.get(source_name)

                if not stage_hash:
                    # Without a hash, only the head (last) stage of the
                    # source target can be taken from the working directory
                    if source_stage_name == target.stages[-1].name and osp.isdir(data_dir):
                        pass
                    else:
                        log.debug(
                            "Build: skipping loading stage '%s' from "
                            "working dir '%s', because the stage has no hash "
                            "and is not the head stage",
                            stage_name,
                            data_dir,
                        )
                        data_dir = None
                elif not wd_hash:
                    if osp.isdir(data_dir):
                        wd_hash = self._project.compute_source_hash(data_dir)
                        working_dir_hashes[source_name] = wd_hash
                    else:
                        log.debug(
                            "Build: skipping checking working dir '%s', "
                            "because it does not exist",
                            data_dir,
                        )
                        data_dir = None

                if stage_hash and stage_hash != wd_hash:
                    log.debug(
                        "Build: skipping loading stage '%s' from "
                        "working dir '%s', because hashes do not match",
                        stage_name,
                        data_dir,
                    )
                    data_dir = None

            if not data_dir and stage_hash:
                if self._project._is_cached(stage_hash):
                    data_dir = self._project.cache_path(stage_hash)
                    cached = True
                elif self._project._can_retrieve_from_vcs_cache(stage_hash):
                    data_dir = self._project._materialize_obj(stage_hash)
                    cached = True

                if not data_dir or not osp.isdir(data_dir):
                    log.debug(
                        "Build: skipping loading stage '%s' from "
                        "cache obj '%s', because it is not available",
                        stage_name,
                        stage_hash,
                    )
                    return None

            if data_dir:
                assert osp.isdir(data_dir), data_dir
                log.debug("Build: loading stage '%s' from '%s'", stage_name, data_dir)
                return ProjectSourceDataset(
                    data_dir, self._tree, source_name, readonly=cached or self._project.readonly
                )

            return None

        # Pipeline is assumed to be validated already
        graph = pipeline._graph
        head = pipeline.head

        # traverse the graph and initialize nodes from sources to the head
        to_visit = [head]
        while to_visit:
            stage_name = to_visit.pop()
            stage = graph.nodes[stage_name]
            stage_config = stage["config"]
            stage_type = BuildStageType[stage_config.type]
            stage_hash = stage_config.hash

            assert stage.get("dataset") is None

            dataset = _try_load_from_disk(stage_name, stage_config)
            if dataset is not None:
                stage["dataset"] = dataset
                continue

            uninitialized_parents = []
            for p_name in graph.predecessors(stage_name):
                parent = graph.nodes[p_name]
                if parent.get("dataset") is None:
                    uninitialized_parents.append(p_name)

            if uninitialized_parents:
                # Revisit this stage after its parents are initialized
                to_visit.append(stage_name)
                to_visit.extend(uninitialized_parents)
                continue

            if stage_type == BuildStageType.transform:
                kind = stage_config.kind
                try:
                    transform = self._tree.env.transforms[kind]
                except KeyError as e:
                    raise UnknownStageError("Unknown transform '%s'" % kind) from e

                dataset = _join_parent_datasets()
                dataset = dataset.transform(transform, **stage_config.params)

            elif stage_type == BuildStageType.filter:
                dataset = _join_parent_datasets()
                dataset = dataset.filter(**stage_config.params)

            elif stage_type == BuildStageType.inference:
                kind = stage_config.kind
                model = self._project.make_model(kind)

                dataset = _join_parent_datasets()
                dataset = dataset.run_model(model)

            elif stage_type == BuildStageType.source:
                # Stages of type "Source" cannot have inputs,
                # they are build tree inputs themselves
                assert graph.in_degree(stage_name) == 0, stage_name

                # The only valid situation we get here is that it is a
                # generated source:
                # - No cache entry
                # - No local dir data
                source_name = ProjectBuildTargets.strip_target_name(stage_name)
                source = self._tree.sources[source_name]
                if not source.is_generated:
                    # Source is missing in the cache and the working tree,
                    # and cannot be retrieved from the VCS cache.
                    # It is assumed that all the missing sources were
                    # downloaded earlier.
                    raise MissingObjectError(
                        "Failed to initialize stage '%s': "
                        "object '%s' was not found in cache" % (stage_name, stage_hash)
                    )

                # Generated sources do not require a data directory,
                # but they still can be bound to a directory
                if self._tree.is_working_tree:
                    source_dir = self._project.source_data_dir(source_name)
                else:
                    source_dir = None
                dataset = ProjectSourceDataset(
                    source_dir,
                    self._tree,
                    source_name,
                    readonly=not source_dir or self._project.readonly,
                )

            elif stage_type == BuildStageType.project:
                dataset = _join_parent_datasets(force=True)

            elif stage_type == BuildStageType.convert:
                dataset = _join_parent_datasets()

            else:
                raise UnknownStageError("Unexpected stage type '%s'" % stage_type)

            stage["dataset"] = dataset

        return graph, head

    @staticmethod
    def _validate_pipeline(pipeline: Pipeline):
        """
        Checks the pipeline structure.

        Raises:
            EmptyPipelineError: if the pipeline has no stages, or only
                the base stage of the main target
            MissingPipelineHeadError: if the pipeline has no head
            InvalidStageError: if a non-source stage has no inputs,
                a source stage has inputs, or a non-head stage has no outputs
        """
        graph = pipeline._graph
        main_base_stage = ProjectBuildTargets.make_target_name(
            ProjectBuildTargets.MAIN_TARGET, ProjectBuildTargets.BASE_STAGE
        )
        if len(graph) == 0 or (len(graph) == 1 and next(iter(graph.nodes)) == main_base_stage):
            raise EmptyPipelineError()

        head = pipeline.head
        if not head:
            raise MissingPipelineHeadError()

        for stage_name, stage in graph.nodes.items():
            stage_type = BuildStageType[stage["config"].type]

            if graph.in_degree(stage_name) == 0:
                if stage_type != BuildStageType.source:
                    raise InvalidStageError(
                        "Stage '%s' of type '%s' must have inputs"
                        % (stage_name, stage_type.name)
                    )
            else:
                if stage_type == BuildStageType.source:
                    raise InvalidStageError(
                        "Stage '%s' of type '%s' can't have inputs"
                        % (stage_name, stage_type.name)
                    )

            if graph.out_degree(stage_name) == 0:
                if stage_name != head:
                    raise InvalidStageError(
                        "Stage '%s' of type '%s' has no outputs, "
                        "but is not the head stage" % (stage_name, stage_type.name)
                    )

    def _find_missing_sources(self, pipeline: Pipeline):
        """
        Finds the non-generated sources, whose data is required to build
        the pipeline head, but is available neither in the working
        directory nor in the cache.

        Returns:
            (a set of missing source names,
            computed working directory hashes { source name: hash })
        """
        work_dir_hashes = {}

        def _can_retrieve(stage_name: str, stage_config: BuildStage):
            stage_hash = stage_config.hash

            source_name, source_stage_name = ProjectBuildTargets.split_target_name(stage_name)
            if self._tree.is_working_tree and source_name in self._tree.sources:
                target = self._tree.build_targets[source_name]
                data_dir = self._project.source_data_dir(source_name)

                if not stage_hash:
                    # Only the head (last) stage data can be in the working dir
                    return source_stage_name == target.stages[-1].name and osp.isdir(data_dir)

                wd_hash = work_dir_hashes.get(source_name)
                if not wd_hash and osp.isdir(data_dir):
                    wd_hash = self._project.compute_source_hash(data_dir)
                    work_dir_hashes[source_name] = wd_hash

                if stage_hash and stage_hash == wd_hash:
                    return True

            if stage_hash and self._project.is_obj_cached(stage_hash):
                return True

            return False

        missing_sources = set()
        checked_deps = set()
        unchecked_deps = [pipeline.head]
        while unchecked_deps:
            stage_name = unchecked_deps.pop()
            if stage_name in checked_deps:
                continue
            # Each stage is checked only once, even if several stages use it
            checked_deps.add(stage_name)

            stage_config = pipeline._graph.nodes[stage_name]["config"]
            if _can_retrieve(stage_name, stage_config):
                continue

            if pipeline._graph.in_degree(stage_name) == 0:
                assert stage_config.type == "source", stage_config.type
                source_name = self._tree.build_targets.strip_target_name(stage_name)
                source = self._tree.sources[source_name]
                if not source.is_generated:
                    missing_sources.add(source_name)
            else:
                for p in pipeline._graph.predecessors(stage_name):
                    if p not in checked_deps:
                        unchecked_deps.append(p)

        return missing_sources, work_dir_hashes
class ProjectBuildTargets(CrudProxy[BuildTarget]):
    """
    The build targets of a project tree.

    Each target is a chain of build stages. The main target ("project")
    combines the last stages of all the other targets. Stages are addressed
    as "<target>.<stage>". A bare target name means its base ("root") stage
    in split_target_name() and its last stage in make_pipeline().
    """

    MAIN_TARGET = "project"
    BASE_STAGE = "root"

    def __init__(self, tree: Tree):
        self._tree = tree

    @classmethod
    def _make_target_config(cls, base_stage_type: BuildStageType) -> Dict:
        # A config of a new target with a single base stage of the given type
        return {
            "stages": [
                BuildStage(
                    {
                        "name": cls.BASE_STAGE,
                        "type": base_stage_type.name,
                    }
                ),
            ]
        }

    @property
    def _data(self):
        # The main target and the targets of new sources are added on access
        data = self._tree.config.build_targets

        if self.MAIN_TARGET not in data:
            data[self.MAIN_TARGET] = self._make_target_config(BuildStageType.project)

        for source in self._tree.sources:
            if source not in data:
                data[source] = self._make_target_config(BuildStageType.source)

        return data

    def __contains__(self, key):
        if "." in key:
            target, stage = self.split_target_name(key)
            return target in self._data and self._data[target].find_stage(stage) is not None
        return key in self._data

    def add_target(self, name) -> BuildTarget:
        """Adds a new target with a single source base stage."""
        return self._data.set(name, self._make_target_config(BuildStageType.source))

    def add_stage(self, target, value, prev=None, name=None) -> str:
        """
        Inserts a new stage into a target after the 'prev' stage
        (by default, after the last stage of the target).

        Parameters:
            target: a target name or a "<target>.<stage>" name. In the latter
                case, the stage is used as 'prev', unless 'prev' is set.
            value: the stage config. Updated in-place with the stage name.
            prev: the name of the stage to insert the new stage after
            name: the stage name, used if 'value' has no name.
                Generated, if not set.

        Returns:
            the full name of the new stage ("<target>.<stage>")
        """
        target_name = target
        target_stage_name = None
        if "." in target:
            target_name, target_stage_name = self.split_target_name(target)

        if prev is None:
            prev = target_stage_name

        target = self._data[target_name]

        if prev:
            prev_stage = find(enumerate(target.stages), lambda e: e[1].name == prev)
            if prev_stage is None:
                raise KeyError("Can't find stage '%s'" % prev)
            prev_stage = prev_stage[0]
        else:
            prev_stage = len(target.stages) - 1

        name = value.get("name") or name
        if not name:
            name = generate_next_name(
                (s.name for s in target.stages), "stage", sep="-", default="1"
            )
        elif target.find_stage(name):
            raise VcsError("Stage '%s' already exists" % name)
        value["name"] = name

        value = BuildStage(value)
        assert value.type in BuildStageType.__members__
        target.stages.insert(prev_stage + 1, value)

        return self.make_target_name(target_name, name)

    def remove_target(self, name: str):
        assert name != self.MAIN_TARGET, "Can't remove the main target"
        self._data.remove(name)

    def remove_stage(self, target: str, name: str):
        assert name not in {self.BASE_STAGE}, "Can't remove a default stage"

        target = self._data[target]
        entry = find(enumerate(target.stages), lambda e: e[1].name == name)
        if entry is None:
            raise KeyError("Can't find stage '%s'" % name)

        # find() returns an (index, stage) pair - remove the stage by its index
        del target.stages[entry[0]]

    def add_transform_stage(
        self, target: str, transform: str, params: Optional[Dict] = None, name: Optional[str] = None
    ):
        if transform not in self._tree.env.transforms:
            raise KeyError("Unknown transform '%s'" % transform)

        return self.add_stage(
            target,
            {
                "type": BuildStageType.transform.name,
                "kind": transform,
                "params": params or {},
            },
            name=name,
        )

    def add_inference_stage(
        self, target: str, model: str, params: Optional[Dict] = None, name: Optional[str] = None
    ):
        if model not in self._tree._project.models:
            raise KeyError("Unknown model '%s'" % model)

        return self.add_stage(
            target,
            {
                "type": BuildStageType.inference.name,
                "kind": model,
                "params": params or {},
            },
            name=name,
        )

    def add_filter_stage(
        self, target: str, expr: str, params: Optional[Dict] = None, name: Optional[str] = None
    ):
        # Copy the params to avoid modifying the caller's dict
        params = dict(params or {})
        params["expr_or_filter_func"] = expr

        return self.add_stage(
            target,
            {
                "type": BuildStageType.filter.name,
                "params": params,
            },
            name=name,
        )

    def add_convert_stage(
        self, target: str, format: str, params: Optional[Dict] = None, name: Optional[str] = None
    ):
        if not self._tree.env.is_format_known(format):
            raise KeyError("Unknown format '%s'" % format)

        return self.add_stage(
            target,
            {
                "type": BuildStageType.convert.name,
                "kind": format,
                "params": params or {},
            },
            name=name,
        )

    def add_explore_stage(
        self, target: str, params: Optional[Dict] = None, name: Optional[str] = None
    ):
        return self.add_stage(
            target,
            {
                "type": BuildStageType.explore.name,
                "params": params or {},
            },
            name=name,
        )

    @staticmethod
    def make_target_name(target: str, stage: Optional[str] = None) -> str:
        if stage:
            return "%s.%s" % (target, stage)
        return target

    @classmethod
    def split_target_name(cls, name: str) -> Tuple[str, str]:
        """
        Splits a "<target>.<stage>" name. A name without a stage part
        refers to the base stage of the target.
        """
        if "." in name:
            target, stage = name.split(".", maxsplit=1)
            if not target:
                raise ValueError("Wrong build target name '%s': " "a name can't be empty" % name)
            if not stage:
                raise ValueError(
                    "Wrong build target name '%s': "
                    "expected stage name after the separator" % name
                )
        else:
            target = name
            stage = cls.BASE_STAGE
        return target, stage

    @classmethod
    def strip_target_name(cls, name: str) -> str:
        return cls.split_target_name(name)[0]

    def _make_full_pipeline(self) -> Pipeline:
        pipeline = Pipeline()
        graph = pipeline._graph

        for target_name, target in self.items():
            if target_name == self.MAIN_TARGET:
                # main target combines all the others
                prev_stages = [
                    self.make_target_name(n, t.stages[-1].name)
                    for n, t in self.items()
                    if n != self.MAIN_TARGET
                ]
            else:
                prev_stages = [
                    self.make_target_name(t, self[t].stages[-1].name) for t in target.parents
                ]

            # Stages of a target form a chain in the order of the list
            for stage in target.stages:
                stage_name = self.make_target_name(target_name, stage["name"])

                graph.add_node(stage_name, config=stage)

                for prev_stage in prev_stages:
                    graph.add_edge(prev_stage, stage_name)
                prev_stages = [stage_name]

        return pipeline

    def make_pipeline(self, target: str) -> Pipeline:
        """
        Returns a pipeline, which builds the target stage. For a bare
        target name, the last stage of the target is built.

        Raises:
            UnknownTargetError: if there is no such target or stage
        """
        if target not in self:
            raise UnknownTargetError(target)

        # a subgraph with all the target dependencies
        if "." not in target:
            target = self.make_target_name(target, self[target].stages[-1].name)

        return self._make_full_pipeline().get_slice(target)
class GitWrapper:
    """
    A thin wrapper around a GitPython repository in the project directory.
    The 'git' module is imported lazily, when it is needed.
    """

    # The length of a hex-encoded SHA-1 git object id
    HASH_LEN = 40

    @staticmethod
    def module():
        try:
            import git

            return git
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Can't import the 'git' package. "
                "Make sure GitPython is installed, or install it with "
                "'pip install datumaro[default]'."
            ) from e

    def _git_dir(self):
        return osp.join(self._project_dir, ".git")

    def __init__(self, project_dir, repo=None):
        self._project_dir = project_dir
        self.repo = repo

        if repo is None and osp.isdir(project_dir) and osp.isdir(self._git_dir()):
            self.repo = self.module().Repo(project_dir)

    @property
    def initialized(self):
        return self.repo is not None

    def init(self):
        if self.initialized:
            return

        repo = self.module().Repo.init(path=self._project_dir)
        repo.config_writer().set_value("user", "name", "User").set_value(
            "user", "email", "<>"
        ).release()
        # GitPython's init produces an incomplete repo, which becomes normal
        # only after a first commit. Unless the commit is done, some
        # GitPython's functions will throw useless errors.
        # Call "git init" directly to have the desired behaviour.
        repo.git.init()

        self.repo = repo

    def close(self):
        if self.repo:
            self.repo.close()
            self.repo = None

    def __del__(self):
        with suppress(Exception):
            self.close()

    def checkout(self, ref: str, dst_dir=None, clean=False, force=False):
        """
        Writes the files of a revision into a directory (the repo root
        by default), points HEAD to the revision and resets the index
        (but not the working tree).

        Parameters:
            ref: a revision or a branch name
            dst_dir: the output directory, must be inside of the repo
            clean: remove the output directory contents before writing
            force: skip the check for modified files

        Raises:
            UnsavedChangesError: if not forced and the output directory
                has modified files
        """
        # If user wants to navigate to a head, we need to supply its object
        # instead of just a string. Otherwise, we'll get a detached head.
        try:
            ref_obj = self.repo.heads[ref]
        except IndexError:
            ref_obj = ref

        commit = self.repo.commit(ref)
        tree = commit.tree

        if not dst_dir:
            dst_dir = self._project_dir

        repo_dir = osp.abspath(self._project_dir)
        dst_dir = osp.abspath(dst_dir)
        assert is_subpath(dst_dir, base=repo_dir)

        if not force:
            statuses = self.status(tree, base_dir=dst_dir)

            # Only modified files produce conflicts in checkout
            dst_rpath = osp.relpath(dst_dir, repo_dir)
            conflicts = [osp.join(dst_rpath, p) for p, s in statuses.items() if s == "M"]
            if conflicts:
                raise UnsavedChangesError(conflicts)

        self.repo.head.ref = ref_obj
        self.repo.head.reset(working_tree=False)

        if clean:
            rmtree(dst_dir)

        self.write_tree(tree, dst_dir)

    def add(self, paths, base=None):
        """
        Adds paths to index.
        Paths can be truncated relatively to base.
        """

        path_rewriter = None
        if base:
            base = osp.abspath(base)
            repo_root = osp.abspath(self._project_dir)
            assert is_subpath(base, base=repo_root), "Base path should be inside of the repo"
            base = osp.relpath(base, repo_root)
            path_rewriter = lambda entry: osp.relpath(entry.path, base).replace("\\", "/")

        if isinstance(paths, str):
            paths = [paths]

        # A workaround for path_rewriter incompatibility
        # with directory paths expansion
        paths_to_add = []
        for path in paths:
            if not osp.isdir(path):
                paths_to_add.append(path)
                continue

            for d, _, filenames in os.walk(path):
                for fn in filenames:
                    paths_to_add.append(osp.join(d, fn))

        self.repo.index.add(paths_to_add, path_rewriter=path_rewriter)

    def commit(self, message) -> str:
        """
        Creates a new revision from index.

        Returns: new revision hash.
        """
        return self.repo.index.commit(message).hexsha

    GitTree = NewType("GitTree", object)
    GitStatus = NewType("GitStatus", str)

    def status(
        self, paths: Union[str, GitTree, Iterable[str]] = None, base_dir: str = None
    ) -> Dict[str, GitStatus]:
        """
        Compares working directory and index.

        Parameters:
            paths: an iterable of paths to compare, a git.Tree, or None.
                When None, uses all the paths from HEAD.
            base_dir: a base path for paths. Paths will be prepended by this.
                When None or '', uses repo root. Can be useful, if index contains
                displaced paths, which needs to be mapped on real paths.

        The statuses are:
            - "A" for paths, which exist on disk, but not in the index
            - "D" for paths, which exist in the index, but not on disk
            - "M" for paths with modified data

        Returns:
            { path: status }, where the paths are the compared index paths
            (relative to base_dir). Unchanged and missing paths are omitted.
        """

        if paths is None or isinstance(paths, self.module().objects.tree.Tree):
            if paths is None:
                tree = self.repo.head.commit.tree
            else:
                tree = paths
            paths = (obj.path for obj in tree.traverse() if obj.type == "blob")
        elif isinstance(paths, str):
            paths = [paths]

        if not base_dir:
            base_dir = self._project_dir

        repo_dir = osp.abspath(self._project_dir)
        base_dir = osp.abspath(base_dir)
        assert is_subpath(base_dir, base=repo_dir)

        statuses = {}
        for obj_path in paths:
            file_path = osp.join(base_dir, obj_path)

            index_entry = self.repo.index.entries.get((obj_path, 0), None)
            file_exists = osp.isfile(file_path)
            if not file_exists and index_entry:
                status = "D"
            elif file_exists and not index_entry:
                status = "A"
            elif file_exists and index_entry:
                # '--ignore-cr-at-eol' doesn't affect '--name-status'
                # so we can't really obtain 'T'
                status = self.repo.git.diff("--ignore-cr-at-eol", index_entry.hexsha, file_path)
                if status:
                    status = "M"
                assert status in {"", "M", "T"}, status
            else:
                status = ""  # ignore missing paths

            if status:
                statuses[obj_path] = status

        return statuses

    def is_ref(self, rev):
        try:
            self.repo.commit(rev)
            return True
        except (ValueError, self.module().exc.BadName):
            return False

    def has_commits(self):
        return self.is_ref("HEAD")

    def get_tree(self, ref):
        return self.repo.tree(ref)

    def write_tree(self, tree, base_path: str, include_files: Optional[List[str]] = None):
        """
        Writes the files of a git tree into a directory.

        Parameters:
            tree: a git tree
            base_path: the output directory. Created, if missing.
            include_files: if not empty, only these object paths are written
        """
        os.makedirs(base_path, exist_ok=True)

        # A set gives O(1) membership checks for each tree object
        include_set = set(include_files) if include_files else None

        for obj in tree.traverse(visit_once=True):
            if include_set is not None and obj.path not in include_set:
                continue

            path = osp.join(base_path, obj.path)
            os.makedirs(osp.dirname(path), exist_ok=True)
            if obj.type == "blob":
                with open(path, "wb") as f:
                    obj.stream_data(f)
            elif obj.type == "tree":
                pass
            else:
                raise ValueError(
                    "Unexpected object type in a "
                    "git tree: %s (%s)" % (obj.type, obj.hexsha)
                )

    @property
    def head(self) -> str:
        return self.repo.head.commit.hexsha

    @property
    def branch(self) -> str:
        # NOTE(review): returns the active branch object from GitPython
        # (not a str), or None for a detached HEAD
        if self.repo.head.is_detached:
            return None
        return self.repo.active_branch

    def rev_parse(self, ref: str) -> Tuple[str, str]:
        """
        Expands named refs and tags.

        Returns: object type, object hash
        """
        obj = self.repo.rev_parse(ref)
        return obj.type, obj.hexsha

    def ignore(
        self,
        paths: Union[str, List[str]],
        mode: Union[None, str, IgnoreMode] = None,
        gitignore: Optional[str] = None,
    ):
        if not gitignore:
            gitignore = ".gitignore"
        repo_root = self._project_dir
        gitignore = osp.abspath(osp.join(repo_root, gitignore))
        assert is_subpath(gitignore, base=repo_root), gitignore

        _update_ignore_file(paths, repo_root=repo_root, mode=mode, filepath=gitignore)

    @classmethod
    def is_hash(cls, s: str) -> bool:
        return len(s) == cls.HASH_LEN

    def log(self, depth=10) -> List[Tuple[Any, int]]:
        """
        Returns: a list of (commit, index) pairs for the last 'depth' commits
        """

        if not self.has_commits():
            return []

        return list(zip(self.repo.iter_commits(rev="HEAD"), range(depth)))
[docs] class DvcWrapper:
[docs] @staticmethod def module(): try: import dvc import dvc.cli import dvc.env import dvc.repo return dvc except ModuleNotFoundError as e: raise ModuleNotFoundError( "Can't import the 'dvc' package. " "Make sure DVC is installed, or install it with " "'pip install datumaro[default]'." ) from e
def _dvc_dir(self): return osp.join(self._project_dir, ".dvc")
[docs] class DvcError(Exception): pass
def __init__(self, project_dir): self._project_dir = project_dir self.repo = None if osp.isdir(project_dir) and osp.isdir(self._dvc_dir()): with logging_disabled(): self.repo = self.module().repo.Repo(project_dir) @property def initialized(self): return self.repo is not None
[docs] def init(self): if self.initialized: return with logging_disabled(): self.repo = self.module().repo.Repo.init(self._project_dir) repo_dir = osp.join(self._project_dir, ".dvc") _update_ignore_file( [osp.join(repo_dir, "plots")], filepath=osp.join(repo_dir, ".gitignore"), repo_root=repo_dir, )
[docs] def close(self): if self.repo: self.repo.close() self.repo = None
def __del__(self): with suppress(Exception): self.close()
[docs] def checkout(self, targets=None): args = ["checkout"] if targets: if isinstance(targets, str): args.append(targets) else: args.extend(targets) self._exec(args)
[docs] def add(self, paths, no_commit=False): args = ["add"] if no_commit: args.append("--no-commit") if paths: if isinstance(paths, str): args.append(paths) else: args.extend(paths) self._exec(args)
def _exec(self, args, hide_output=True, answer_on_input="y"): args = ["--cd", self._project_dir] + args # Avoid calling an extra process. Improves call performance and # removes an extra console window on Windows. os.environ[self.module().env.DVC_NO_ANALYTICS] = "1" with ExitStack() as es: es.callback(os.chdir, os.getcwd()) # restore cd after DVC if answer_on_input is not None: def _input(*args): return answer_on_input es.enter_context(unittest.mock.patch("dvc.prompt.input", new=_input)) log.debug("Calling DVC main with args: %s", args) logs = es.enter_context(catch_logs("dvc")) retcode = self.module().cli.main(args) logs = logs.getvalue() if retcode != 0: raise self.DvcError(logs) if not hide_output: print(logs) return logs
def is_cached(self, obj_hash):
    """
    Checks if the object is present in the local DVC cache.

    For a directory object (a hash ending with DIR_HASH_SUFFIX), the object
    is a JSON listing. Every file referenced by its "md5" entries must be
    present too.
    """
    listing_path = self.obj_path(obj_hash)
    if not osp.isfile(listing_path):
        return False

    if not obj_hash.endswith(self.DIR_HASH_SUFFIX):
        return True

    return all(
        osp.isfile(self.obj_path(entry["md5"])) for entry in parse_json_file(listing_path)
    )
def obj_path(self, obj_hash, root=None):
    """
    Returns the path of an object in a DVC cache directory.

    Objects are stored as "<root>/<first 2 hash symbols>/<rest of the hash>".
    The default root is ".dvc/cache/files/md5" in the project directory.
    """
    assert self.is_hash(obj_hash), obj_hash

    cache_root = root or osp.join(self._project_dir, ".dvc", "cache", "files", "md5")
    return osp.join(cache_root, obj_hash[:2], obj_hash[2:])
def ignore(
    self,
    paths: Union[str, List[str]],
    mode: Union[None, str, IgnoreMode] = None,
    dvcignore: Optional[str] = None,
):
    """
    Updates a DVC ignore file of the repository with the given paths.

    Parameters:
        paths: the path or the paths to update in the ignore file
        mode: the update mode, passed to _update_ignore_file()
        dvcignore: the ignore file path. Relative paths are resolved
            against the repository root. Defaults to ".dvcignore" in the
            repository root.
    """
    repo_root = self._project_dir
    ignore_file = osp.abspath(osp.join(repo_root, dvcignore or ".dvcignore"))

    # The ignore file is required to be inside of the repository
    assert is_subpath(ignore_file, base=repo_root), ignore_file

    _update_ignore_file(paths, repo_root=repo_root, mode=mode, filepath=ignore_file)
# This ruamel parser is needed to preserve comments, # order and form (if multiple forms allowed by the standard) # of the entries in the file. It can be reused. import ruamel.yaml as yaml yaml_parser = yaml.YAML(typ="rt")
@classmethod
def get_hash_from_dvcfile(cls, path) -> str:
    """
    Reads the hash of the first output recorded in a DVC file.

    Returns:
        the "md5" value of the first "outs" entry
    """
    with open(path) as dvcfile:
        contents = cls.yaml_parser.load(dvcfile)

    first_output = contents["outs"][0]
    return first_output["md5"]
@classmethod
def is_file_hash(cls, s: str) -> bool:
    """
    Checks if the string looks like a DVC file object hash.

    Both the length and the symbols are checked. Hashes are later used to
    build file system paths, so arbitrary strings of the right length
    (e.g. containing path separators) must not be accepted.
    """
    return len(s) == cls.FILE_HASH_LEN and re.fullmatch(r"[0-9a-fA-F]+", s) is not None
@classmethod
def is_dir_hash(cls, s: str) -> bool:
    """
    Checks if the string looks like a DVC directory object hash.

    The expected form is a hex hash followed by DIR_HASH_SUFFIX. The hash
    part is validated too, because hashes are used to build file system
    paths.
    """
    suffix = cls.DIR_HASH_SUFFIX
    if len(s) != cls.DIR_HASH_LEN or not s.endswith(suffix):
        return False

    hash_part = s[: len(s) - len(suffix)]
    return re.fullmatch(r"[0-9a-fA-F]+", hash_part) is not None
@classmethod
def is_hash(cls, s: str) -> bool:
    """Checks if the string is a DVC file or directory object hash."""
    if cls.is_file_hash(s):
        return True
    return cls.is_dir_hash(s)
def write_obj(self, obj_hash, dst_dir, allow_links=True):
    """
    Writes a DVC cache object to the specified location.

    A single-file object is written to the 'dst_dir' path itself.
    Otherwise, the "<obj_hash><DIR_HASH_SUFFIX>" listing is read. Each
    listed file is written under 'dst_dir' by its "relpath".

    Parameters:
        obj_hash: the object hash, without the directory suffix
        dst_dir: the destination path
        allow_links: try to hard-link the cached files instead of copying
            them. Falls back to copying when linking is not possible.

    Raises:
        UnknownRefError: if there is no such object in the DVC cache
    """

    def _copy_obj(src, dst, link=False):
        os.makedirs(osp.dirname(dst), exist_ok=True)
        if link:
            # NOTE(review): a hard link is assumed here, the original call
            # was not readable in this copy of the code - confirm
            try:
                os.link(src, dst)
                return
            except FileExistsError:
                raise
            except OSError:
                # Hard links are not available everywhere (e.g. across
                # file systems or on some file system types)
                pass
        shutil.copy(src, dst, follow_symlinks=True)

    src = self.obj_path(obj_hash)
    if osp.isfile(src):
        _copy_obj(src, dst_dir, link=allow_links)
        return

    src += self.DIR_HASH_SUFFIX
    if not osp.isfile(src):
        raise UnknownRefError(obj_hash)

    src_meta = parse_json_file(src)
    for entry in src_meta:
        _copy_obj(
            self.obj_path(entry["md5"]), osp.join(dst_dir, entry["relpath"]), link=allow_links
        )
def remove_cache_obj(self, obj_hash: str):
    """
    Removes an object from the local DVC cache.

    For a directory object, every listed file that is present is removed,
    and then the listing itself is removed.

    Raises:
        UnknownRefError: if there is no such object in the DVC cache
    """
    file_path = self.obj_path(obj_hash)
    if osp.isfile(file_path):
        rmfile(file_path)
        return

    listing_path = file_path + self.DIR_HASH_SUFFIX
    if not osp.isfile(listing_path):
        raise UnknownRefError(obj_hash)

    for entry in parse_json_file(listing_path):
        entry_path = self.obj_path(entry["md5"])
        if osp.isfile(entry_path):
            rmfile(entry_path)

    rmfile(listing_path)
class Tree:
    """
    A project tree: the config of the project sources and build targets.

    A tree can be attached either to the project working directory
    (when there is no revision) or to a specific project revision.
    """

    def __init__(
        self,
        project: Project,
        config: Union[None, Dict, Config, TreeConfig] = None,
        rev: Union[None, Revision] = None,
    ):
        """
        Parameters:
            project: the owning project
            config: the tree config, or data to build a TreeConfig from
            rev: the revision the tree is attached to, empty for
                the working tree

        Raises:
            ValueError: if the config format version is not 2
        """
        assert isinstance(project, Project)
        assert not rev or project.is_ref(rev), rev

        tree_config = config if isinstance(config, TreeConfig) else TreeConfig(config)

        config_version = tree_config.format_version
        if config_version != 2:
            raise ValueError("Unexpected tree config version '%s', expected 2" % config_version)

        self._config = tree_config
        self._project = project
        self._rev = rev

        self._sources = ProjectSources(self)
        self._targets = ProjectBuildTargets(self)

    def save(self):
        """Writes the tree config to its 'config_path'."""
        self.dump(self._config.config_path)

    def dump(self, path):
        """Writes the tree config to the path, creating parent directories."""
        parent_dir = osp.dirname(path)
        os.makedirs(parent_dir, exist_ok=True)
        self._config.dump(path)

    def clone(self) -> Tree:
        """Returns a new tree with a copy of this tree's config."""
        config_copy = TreeConfig(self.config)
        return Tree(self._project, config_copy, self._rev)

    @property
    def sources(self) -> ProjectSources:
        """The sources of the tree."""
        return self._sources

    @property
    def build_targets(self) -> ProjectBuildTargets:
        """The build targets of the tree."""
        return self._targets

    @property
    def config(self) -> Config:
        """The tree config."""
        return self._config

    @property
    def env(self) -> Environment:
        """The environment of the owning project."""
        return self._project.env

    @property
    def rev(self) -> Union[None, Revision]:
        """The revision of the tree, empty for the working tree."""
        return self._rev

    def make_pipeline(self, target: Optional[str] = None) -> Pipeline:
        """Makes a build pipeline for the target, "project" by default."""
        return self.build_targets.make_pipeline(target or "project")

    def make_dataset(self, target: Union[None, str, Pipeline] = None) -> Dataset:
        """
        Builds a dataset from a target name or a ready pipeline.

        Raises:
            TypeError: if the target is neither a string nor a Pipeline
        """
        if not target or isinstance(target, str):
            pipeline = self.make_pipeline(target)
        elif isinstance(target, Pipeline):
            pipeline = target
        else:
            raise TypeError(f"Unexpected target type {type(target)}")

        builder = ProjectBuilder(self._project, self)
        return builder.make_dataset(pipeline)

    @property
    def is_working_tree(self) -> bool:
        """True if the tree is attached to the working directory."""
        return not self._rev

    def source_data_dir(self, source) -> str:
        """
        Returns the data directory of a source.

        For the working tree, it is the source directory in the project.
        For a revision, it is the project cache directory of the
        source head hash.
        """
        if not self.is_working_tree:
            head_hash = self.build_targets[source].head.hash
            return self._project.cache_path(head_hash)

        return self._project.source_data_dir(source)
class DiffStatus(Enum):
    """Change status of a project source, as reported by Project.status()."""

    added = 1
    modified = 2
    removed = 3
    missing = 4
    # The source was changed outside of Datumaro
    foreign_modified = 5
# A project revision: a commit hash or a named reference
Revision = NewType("Revision", str)

# A stored object: a commit hash or a data object hash
ObjectId = NewType("ObjectId", str)
[docs] class Project:
    """
    A Datumaro project: a directory with source datasets and project
    metadata in an aux directory.

    The project metadata is versioned with Git, and the source data is
    hashed and cached with DVC.
    """
@staticmethod
def find_project_dir(path: str) -> Optional[str]:
    """
    Finds the project aux directory for a path.

    Parameters:
        path: a project root directory or the aux directory itself

    Returns:
        the absolute aux directory path, or None if it doesn't exist
    """
    candidate = osp.abspath(path)
    if osp.basename(candidate) != ProjectLayout.aux_dir:
        candidate = osp.join(candidate, ProjectLayout.aux_dir)

    return candidate if osp.isdir(candidate) else None
[docs] @staticmethod
@scoped
def migrate_from_v1_to_v2(src_dir: str, dst_dir: str, skip_import_errors=False):
    """
    Converts a project in the old (v1) format to the current (v2) format.

    The old project is only read. The result is written into a new
    directory.

    Parameters:
        src_dir: the old project directory
        dst_dir: the output directory, must not exist
        skip_import_errors: if True, sources that fail to import are
            skipped with a warning instead of failing the migration

    Raises:
        FileNotFoundError: if the source project directory is not found
        FileExistsError: if the output path already exists
        MigrationError: if the paths are the same, the old project version
            is unexpected, or a source fails to migrate
            (and skip_import_errors is False)
    """
    if not osp.isdir(src_dir):
        raise FileNotFoundError("Source project is not found")

    if osp.exists(dst_dir):
        raise FileExistsError("Output path already exists")

    src_dir = osp.abspath(src_dir)
    dst_dir = osp.abspath(dst_dir)
    if src_dir == dst_dir:
        raise MigrationError(
            "Source and destination paths are the same. "
            "Project migration cannot be done inplace."
        )

    old_aux_dir = osp.join(src_dir, ".datumaro")
    old_config = Config.parse(osp.join(old_aux_dir, "config.yaml"))
    if old_config.format_version != 1:
        raise MigrationError(
            "Failed to migrate project: " "unexpected old version '%s'" % old_config.format_version
        )

    # Remove the partially migrated project on errors
    on_error_do(rmtree, dst_dir, ignore_errors=True)
    new_project = scope_add(Project.init(dst_dir))

    new_wtree_dir = osp.join(new_project._aux_dir, ProjectLayout.working_tree_dir)
    os.makedirs(new_wtree_dir, exist_ok=True)

    # Plugins and models are copied as is
    old_plugins_dir = osp.join(old_aux_dir, "plugins")
    if osp.isdir(old_plugins_dir):
        copytree(old_plugins_dir, osp.join(new_project._aux_dir, ProjectLayout.plugins_dir))
    old_models_dir = osp.join(old_aux_dir, "models")
    if osp.isdir(old_models_dir):
        copytree(old_models_dir, osp.join(new_project._aux_dir, ProjectLayout.models_dir))

    # Plugins can be required to import the sources
    new_project.env.load_plugins(osp.join(new_project._aux_dir, ProjectLayout.plugins_dir))

    new_tree_config = new_project.working_tree.config
    new_local_config = new_project.config

    if "models" in old_config:
        for name, old_model in old_config.models.items():
            new_local_config.models[name] = Model(
                {"launcher": old_model["launcher"], "options": old_model["options"]}
            )

    if "sources" in old_config:
        for name, old_source in old_config.sources.items():
            is_local = False
            source_dir = osp.join(src_dir, "sources", name)
            url = osp.abspath(osp.join(source_dir, old_source["url"]))
            rpath = None
            if osp.exists(url):
                if is_subpath(url, source_dir):
                    # The data is stored inside of the old project
                    if url != source_dir:
                        rpath = osp.relpath(url, source_dir)
                        url = source_dir
                    is_local = True
                elif osp.isfile(url):
                    url, rpath = osp.split(url)
            elif not old_source["url"]:
                url = ""

            try:
                source = new_project.import_source(
                    name,
                    url=url,
                    rpath=rpath,
                    format=old_source["format"],
                    options=old_source["options"],
                )
                if is_local:
                    # Local data has no external origin in the new project
                    source.url = ""

                # Check that the imported source can be loaded
                new_project.working_tree.make_dataset(name)
            except Exception as e:
                if not skip_import_errors:
                    raise MigrationError(f"Failed to migrate the source '{name}'") from e
                else:
                    log.warning(
                        f"Failed to migrate the source '{name}'. "
                        "Try to add this source manually with "
                        "'datum project import', once migration is finished. The "
                        "reason is: %s",
                        e,
                    )
                    new_project.remove_source(name, force=True, keep_data=False)

    old_dataset_dir = osp.join(src_dir, "dataset")
    if osp.isdir(old_dataset_dir):
        # Such source cannot be represented in v2 directly.
        # However, it can be considered a generated source with
        # working tree data.
        name = generate_next_name(
            list(new_tree_config.sources), "local_dataset", sep="-", default="1"
        )
        source = new_project.import_source(name, url=old_dataset_dir, format=DEFAULT_FORMAT)

        # Make the source generated. It can only have local data.
        source.url = ""

    new_project.close()
def __init__(self, path: Optional[str] = None, readonly=False):
    """
    Opens an existing project.

    Parameters:
        path: a path to the project root directory, or to the project
            aux directory in it. The current directory is used
            by default.
        readonly: if True, the project modifying operations that check
            the 'readonly' property raise ReadonlyProjectError

    Raises:
        ProjectNotFoundError: if there is no project at the path
        OldProjectError: if the project uses an older config format
    """
    if not path:
        path = osp.curdir
    found_path = self.find_project_dir(path)
    if not found_path:
        raise ProjectNotFoundError(path)

    # A "config.yaml" with another format version in the aux directory
    # means the project was created in an older format
    old_config_path = osp.join(found_path, "config.yaml")
    if osp.isfile(old_config_path):
        if Config.parse(old_config_path).format_version != 2:
            raise OldProjectError()

    self._aux_dir = found_path
    self._root_dir = osp.dirname(found_path)

    self._readonly = readonly

    # Force import errors on missing dependencies.
    #
    # TODO: maybe allow class use in some cases, which not require
    # Git or DVC
    GitWrapper.module()
    DvcWrapper.module()

    self._git = GitWrapper(self._root_dir)
    self._dvc = DvcWrapper(self._root_dir)

    # The trees are loaded on the first access
    self._working_tree = None
    self._head_tree = None

    local_config = osp.join(self._aux_dir, ProjectLayout.conf_file)
    if osp.isfile(local_config):
        self._config = ProjectConfig.parse(local_config)
    else:
        self._config = ProjectConfig()

    self._env = Environment()

    # Project-local plugins are added to the project environment
    plugins_dir = osp.join(self._aux_dir, ProjectLayout.plugins_dir)
    if osp.isdir(plugins_dir):
        self._env.load_plugins(plugins_dir)

def _init_vcs(self):
    """
    Initializes the Git and DVC repositories in the project root
    (only the ones not initialized yet) and makes the initial commit.
    """
    # DVC requires Git to be initialized
    if not self._git.initialized:
        self._git.init()
        self._git.ignore(
            [
                ProjectLayout.cache_dir,
            ],
            gitignore=osp.join(self._aux_dir, ".gitignore"),
        )
        self._git.ignore([])  # create the file
    if not self._dvc.initialized:
        self._dvc.init()
        # Exclude the project cache and the working tree metadata from DVC
        self._dvc.ignore(
            [
                osp.join(self._aux_dir, ProjectLayout.cache_dir),
                osp.join(self._aux_dir, ProjectLayout.working_tree_dir),
            ]
        )
        # Remove the DVC plots directory from the Git index
        self._git.repo.index.remove(
            osp.join(self._root_dir, ".dvc", "plots"), r=True, ignore_unmatch=True
        )
    self.commit("Initial commit", allow_empty=True)
@classmethod
@scoped
def init(cls, path) -> Project:
    """
    Creates a new project.

    On failure, everything created by this function is removed.

    Parameters:
        path: the project root directory, or the project aux directory
            in it

    Returns:
        the new project

    Raises:
        ProjectAlreadyExists: if there is a project at the path already
        VcsAlreadyExists: if the project directory already contains
            a Git or DVC repository
    """
    existing_project = cls.find_project_dir(path)
    if existing_project:
        raise ProjectAlreadyExists(path)

    path = osp.abspath(path)
    if osp.basename(path) != ProjectLayout.aux_dir:
        path = osp.join(path, ProjectLayout.aux_dir)

    project_dir = osp.dirname(path)

    # Refuse to take over existing repositories. This is checked before
    # anything is created, so a refusal leaves the file system untouched.
    git_dir, dvc_dir = osp.join(project_dir, ".git"), osp.join(project_dir, ".dvc")
    if osp.exists(git_dir):
        raise VcsAlreadyExists(git_dir)
    if osp.exists(dvc_dir):
        raise VcsAlreadyExists(dvc_dir)

    if not osp.isdir(project_dir):
        on_error_do(rmtree, project_dir, ignore_errors=True)
    else:
        # The aux directory can't exist here, otherwise find_project_dir()
        # would have found it. Removing it also removes the cache and
        # the temporary directories inside of it.
        on_error_do(rmtree, path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)

    os.makedirs(osp.join(path, ProjectLayout.cache_dir))
    os.makedirs(osp.join(path, ProjectLayout.tmp_dir))

    on_error_do(rmtree, git_dir, ignore_errors=True)
    on_error_do(rmtree, dvc_dir, ignore_errors=True)
    project = Project(path)

    # Release the repositories before their directories are removed on errors
    on_error_do(project.close, ignore_errors=True)
    project._init_vcs()

    return project
def close(self):
    """
    Closes the DVC and then the Git repository wrappers.

    Safe to call repeatedly.
    """
    for attr_name in ("_dvc", "_git"):
        wrapper = getattr(self, attr_name)
        if wrapper:
            wrapper.close()
            setattr(self, attr_name, None)
def __del__(self):
    # Finalizers must not raise, so any error on closing is ignored
    try:
        self.close()
    except Exception:
        pass

def __enter__(self):
    """Returns the project itself, it is closed on exiting the context."""
    return self

def __exit__(self, *args, **kwargs):
    self.close()
def save(self):
    """
    Writes the project config to the project aux directory.

    If the working tree has been loaded, its config is saved too.
    """
    self._config.dump(osp.join(self._aux_dir, ProjectLayout.conf_file))

    # NOTE(review): the working tree save call was not readable in this
    # copy of the code, restored as Tree.save() - confirm
    if self._working_tree is not None:
        self._working_tree.save()
@property
def readonly(self) -> bool:
    """Whether the project was opened in read-only mode."""
    return self._readonly

@property
def working_tree(self) -> Tree:
    """The working tree of the project. Loaded on the first access."""
    tree = self._working_tree
    if tree is None:
        tree = self.get_rev(None)
        self._working_tree = tree
    return tree

@property
def head(self) -> Tree:
    """The tree of the HEAD revision. Loaded on the first access."""
    tree = self._head_tree
    if tree is None:
        tree = self.get_rev("HEAD")
        self._head_tree = tree
    return tree

@property
def head_rev(self) -> Revision:
    """The HEAD revision, as reported by Git."""
    git = self._git
    return git.head

@property
def branch(self) -> str:
    """The current branch, as reported by Git."""
    git = self._git
    return git.branch

@property
def config(self) -> Config:
    """The project config."""
    return self._config

@property
def env(self) -> Environment:
    """The project environment."""
    return self._env

@property
def models(self) -> Dict[str, Model]:
    """A copy of the project models mapping."""
    return dict(self._config.models)
def get_rev(self, rev: Union[None, Revision]) -> Tree:
    """
    Returns the project tree at a revision.

    Reference conventions:
    - None or "" - working dir
    - "<40 symbols>" - revision hash

    A revision tree is restored into the project cache first, if needed.
    For the working tree, the config file is created if it doesn't exist.
    """
    ref_kind, ref_hash = self._parse_ref(rev)
    assert ref_kind == self._ObjectIdKind.tree, ref_kind

    if not self._is_working_tree_ref(ref_hash):
        if not self.is_rev_cached(ref_hash):
            self._materialize_rev(ref_hash)

        rev_dir = self.cache_path(ref_hash)
        rev_config = TreeConfig.parse(osp.join(rev_dir, TreeLayout.conf_file))
        rev_config.base_dir = rev_dir
        return Tree(config=rev_config, project=self, rev=ref_hash)

    wtree_config_path = osp.join(
        self._aux_dir, ProjectLayout.working_tree_dir, TreeLayout.conf_file
    )
    if osp.isfile(wtree_config_path):
        wtree_config = TreeConfig.parse(wtree_config_path)
    else:
        wtree_config = TreeConfig()
        os.makedirs(osp.dirname(wtree_config_path), exist_ok=True)
        wtree_config.dump(wtree_config_path)

    wtree_config.config_path = wtree_config_path
    wtree_config.base_dir = osp.dirname(wtree_config_path)
    return Tree(config=wtree_config, project=self, rev=ref_hash)
def is_rev_cached(self, rev: Revision) -> bool:
    """Checks if the revision tree is present in the project cache."""
    ref_kind, commit_hash = self._parse_ref(rev)
    assert ref_kind == self._ObjectIdKind.tree, ref_kind

    cached = self._is_cached(commit_hash)
    return cached
def is_obj_cached(self, obj_hash: ObjectId) -> bool:
    """
    Checks if the object is available: either present in the project
    cache, or retrievable from the DVC cache.
    """
    if self._is_cached(obj_hash):
        return True
    return self._can_retrieve_from_vcs_cache(obj_hash)
@staticmethod
def _is_working_tree_ref(ref: Union[None, Revision, ObjectId]) -> bool:
    """Tells if the reference points to the working tree (an empty ref)."""
    return not ref

class _ObjectIdKind(Enum):
    # Project revision data. Currently, a Git commit hash.
    tree = auto()

    # Source revision data. DVC directories and files.
    blob = auto()

def _parse_ref(self, ref: Union[None, Revision, ObjectId]) -> Tuple[_ObjectIdKind, ObjectId]:
    """
    Resolves the reference to an object hash.

    Git revisions are tried first. Other references are accepted only
    if they look like DVC object hashes.

    Returns:
        (object kind, object hash)

    Raises:
        UnknownRefError: if the reference can't be resolved, or a Git
            reference doesn't point to a commit
    """
    if self._is_working_tree_ref(ref):
        return self._ObjectIdKind.tree, ref

    try:
        obj_type, obj_hash = self._git.rev_parse(ref)
    except Exception:  # nosec - try_except_pass
        pass  # Ignore git errors
    else:
        if obj_type != "commit":
            raise UnknownRefError(obj_hash)

        return self._ObjectIdKind.tree, obj_hash

    # An explicit check instead of an assert: asserts are removed when
    # Python runs with optimizations ("-O"), and then any unknown
    # reference would be accepted as a DVC object hash.
    try:
        is_blob_hash = self._dvc.is_hash(ref)
    except Exception as e:
        raise UnknownRefError(ref) from e
    if not is_blob_hash:
        raise UnknownRefError(ref)

    return self._ObjectIdKind.blob, ref

def _materialize_rev(self, rev: Revision) -> str:
    """
    Restores the revision tree data in the project cache from Git.

    Returns: cache object path
    """
    # TODO: maybe avoid this operation by providing a virtual filesystem
    # object

    # Allowed to be run when readonly, because it doesn't modify project
    # data and doesn't hurt disk space.
    obj_dir = self.cache_path(rev)
    if osp.isdir(obj_dir):
        return obj_dir

    tree = self._git.get_tree(rev)
    self._git.write_tree(tree, obj_dir)
    return obj_dir

def _is_cached(self, obj_hash: ObjectId):
    """Checks if the object directory exists in the project cache."""
    return osp.isdir(self.cache_path(obj_hash))
def cache_path(self, obj_hash: ObjectId) -> str:
    """
    Returns the project cache directory path for an object.

    DVC directory hashes are stored without the directory suffix.

    Raises:
        ValueError: if the string is neither a Git nor a DVC object hash
    """
    # An explicit check instead of an assert: the result is used to build
    # (and remove) file system paths, so the check must not be removed
    # when Python runs with optimizations ("-O").
    if not (self._git.is_hash(obj_hash) or self._dvc.is_hash(obj_hash)):
        raise ValueError("Unexpected object hash '%s'" % (obj_hash,))

    if self._dvc.is_dir_hash(obj_hash):
        obj_hash = obj_hash[: self._dvc.FILE_HASH_LEN]

    return osp.join(self._aux_dir, ProjectLayout.cache_dir, obj_hash[:2], obj_hash[2:])
def _can_retrieve_from_vcs_cache(self, obj_hash: ObjectId):
    """
    Checks if the object is present in the DVC cache.

    A hash without the directory suffix can also refer to a DVC directory
    object, so the suffixed variant is checked first.
    """
    dvc = self._dvc
    if not dvc.is_dir_hash(obj_hash) and dvc.is_cached(obj_hash + dvc.DIR_HASH_SUFFIX):
        return True
    return dvc.is_cached(obj_hash)
def source_data_dir(self, name: str) -> str:
    """Returns the working directory path of a source: '<project root>/<name>'."""
    project_root = self._root_dir
    return osp.join(project_root, name)
def _source_dvcfile_path(self, name: str, root: Optional[str] = None) -> str:
    """
    Returns the path of the DVC file of a source in a tree.

    root - Path to the tree root directory. If not set,
        the working tree is used.
    """
    tree_root = root or osp.join(self._aux_dir, ProjectLayout.working_tree_dir)
    return osp.join(tree_root, TreeLayout.sources_dir, name, "source.dvc")

def _make_tmp_dir(self, suffix: Optional[str] = None):
    """
    Creates a temporary directory in the project temporary directory.

    A non-empty suffix is appended to the directory name after "_".

    Returns:
        a tempfile.TemporaryDirectory object
    """
    tmp_root = osp.join(self._aux_dir, ProjectLayout.tmp_dir)
    os.makedirs(tmp_root, exist_ok=True)

    name_suffix = "_" + suffix if suffix else suffix
    return tempfile.TemporaryDirectory(suffix=name_suffix, dir=tmp_root)
def remove_cache_obj(self, ref: Union[Revision, ObjectId]):
    """
    Removes an object from the project cache.

    For DVC objects, the data is also removed from the DVC cache.
    Revision metadata is kept in Git.

    Raises:
        ReadonlyProjectError: if the project is opened in read-only mode
        ValueError: if the reference resolves to an unexpected object kind
    """
    if self.readonly:
        raise ReadonlyProjectError()

    ref_kind, obj_hash = self._parse_ref(ref)

    if self._is_cached(obj_hash):
        rmtree(self.cache_path(obj_hash))

    if ref_kind == self._ObjectIdKind.blob:
        self._dvc.remove_cache_obj(obj_hash)
    elif ref_kind != self._ObjectIdKind.tree:
        raise ValueError("Unexpected object type '%s'" % ref_kind)
    # Revision metadata is cheap enough and needed to materialize
    # the revision, so it is kept in the Git cache.
def validate_source_name(self, name: str):
    """
    Checks that the name can be used as a source name.

    The name must be non-empty, can't contain invalid file name
    symbols, can't start with '.' and can't be a reserved name.

    Raises:
        ValueError: if the name is not valid
    """
    if not name:
        raise ValueError("Source name cannot be empty")

    wrong_symbols = re.findall(r"[^\\ \.\~\-\w]", name)
    if wrong_symbols:
        raise ValueError("Source name contains invalid symbols: %s" % wrong_symbols)

    normalized_name = make_file_name(name)
    if normalized_name != name:
        raise ValueError(
            "Source name contains invalid symbols: %s" % (set(name) - set(normalized_name))
        )

    if name.startswith("."):
        raise ValueError("Source name can't start with '.'")

    if name.lower() in {"dataset", "build", "project"}:
        raise ValueError("Source name is reserved for internal use")
@scoped
def _download_source(
    self, url: str, dst_dir: str, *, no_cache: bool = False, no_hash: bool = False
) -> Tuple[str, str, str]:
    """
    Copies the source data from the URL into '<dst_dir>/data'.

    Unless 'no_hash' is set, the data hash is computed, and a DVC file is
    written to '<dst_dir>/source.dvc'.

    Returns:
        (data hash or "", DVC file path, data directory path)

    Raises:
        UnexpectedUrlError: if the URL is neither a file nor a directory
    """
    assert url
    assert dst_dir

    data_dir = osp.join(dst_dir, "data")
    dvcfile = osp.join(dst_dir, "source.dvc")

    log.debug(f"Copying from '{url}' to '{data_dir}'")
    if osp.isdir(url):
        copytree(url, data_dir)
    elif osp.isfile(url):
        os.makedirs(data_dir, exist_ok=True)
        shutil.copy(url, data_dir)
    else:
        raise UnexpectedUrlError(url)
    on_error_do(rmtree, data_dir, ignore_errors=True)
    log.debug("Done")

    if no_hash:
        return "", dvcfile, data_dir

    obj_hash = self.compute_source_hash(data_dir, dvcfile=dvcfile, no_cache=no_cache)
    if not no_cache:
        log.debug("Data is added to DVC cache")
    log.debug("Data hash: '%s'", obj_hash)
    return obj_hash, dvcfile, data_dir

@staticmethod
def _get_source_hash(dvcfile):
    """Reads the source hash from a DVC file, without the directory suffix."""
    obj_hash = DvcWrapper.get_hash_from_dvcfile(dvcfile)

    dir_suffix = DvcWrapper.DIR_HASH_SUFFIX
    if obj_hash.endswith(dir_suffix):
        obj_hash = obj_hash[: -len(dir_suffix)]
    return obj_hash
@scoped
def compute_source_hash(
    self,
    data_dir: str,
    dvcfile: Optional[str] = None,
    no_cache: bool = True,
) -> ObjectId:
    """
    Computes the hash of a data directory with "dvc add".

    Parameters:
        data_dir: the directory to hash
        dvcfile: where to put the resulting DVC file. If not set,
            a temporary file is used.
        no_cache: don't put the data into the DVC cache

    Returns:
        the data hash, without the directory suffix
    """
    if not dvcfile:
        tmp_dir = scope_add(self._make_tmp_dir())
        dvcfile = osp.join(tmp_dir, "source.dvc")

    self._dvc.add(data_dir, no_commit=no_cache)

    # The DVC file generated by "dvc add" is expected at "<data_dir>.dvc"
    generated_dvcfile = osp.join(self._root_dir, data_dir + ".dvc")
    if osp.isfile(generated_dvcfile):
        shutil.move(generated_dvcfile, dvcfile)

    return self._get_source_hash(dvcfile)
def refresh_source_hash(self, source: str, no_cache: bool = True) -> Optional[ObjectId]:
    """
    Computes and updates the source hash in the working directory.

    The head hash of the source build target is updated. If the target
    has no stages, the source config hash is updated too.

    Parameters:
        source: the source name
        no_cache: don't put the source data into the DVC cache

    Returns:
        the new hash, or None if the source data directory doesn't exist

    Raises:
        ReadonlyProjectError: if the project is opened in read-only mode
    """
    if self.readonly:
        raise ReadonlyProjectError()

    build_target = self.working_tree.build_targets[source]
    source_dir = self.source_data_dir(source)

    if not osp.isdir(source_dir):
        return None

    dvcfile = self._source_dvcfile_path(source)
    os.makedirs(osp.dirname(dvcfile), exist_ok=True)

    obj_hash = self.compute_source_hash(source_dir, dvcfile=dvcfile, no_cache=no_cache)

    build_target.head.hash = obj_hash
    if not build_target.has_stages:
        self.working_tree.sources[source].hash = obj_hash

    return obj_hash
def _materialize_obj(self, obj_hash: ObjectId) -> str:
    """
    Restores the object data in the project cache from DVC.

    Returns: cache object path

    Raises:
        MissingObjectError: if the DVC cache has no such object
    """
    # TODO: maybe avoid this operation by providing a virtual filesystem
    # object

    # Allowed to be run when readonly, because it shouldn't hurt disk
    # space, if object is materialized with symlinks.
    if not self._can_retrieve_from_vcs_cache(obj_hash):
        raise MissingObjectError(obj_hash)

    cached_dir = self.cache_path(obj_hash)
    if not osp.isdir(cached_dir):
        self._dvc.write_obj(obj_hash, cached_dir, allow_links=True)
    return cached_dir
[docs] @scoped
def import_source(
    self,
    name: str,
    url: Optional[str],
    format: str,
    options: Optional[Dict] = None,
    *,
    no_cache: bool = True,
    no_hash: bool = True,
    rpath: Optional[str] = None,
) -> Source:
    """
    Adds a new source (dataset) to the working directory of the project.

    The data is copied from the URL into the '<project root>/<name>'
    directory. When 'rpath' is specified, will copy all the data from URL,
    but read only the specified file. Required to support subtasks and
    subsets in datasets.

    Parameters:
        name (str): Name of the new source
        url (str): URL of the new source. A path to a file or directory.
            An empty URL makes a generated source, for which no data
            is copied.
        format (str): Dataset format
        options (dict): Options for the format Extractor
        no_cache (bool): Don't put a copy of files into the project cache.
            Can be used to reduce project cache size.
        no_hash (bool): Don't compute source data hash. Implies "no_cache".
            Useful to reduce import time at the cost of disabled data
            integrity checks.
        rpath (str): Used to specify a relative path to the dataset
            inside of the directory pointed by URL.

    Returns:
        the new source config

    Raises:
        ReadonlyProjectError: if the project is opened in read-only mode
        SourceExistsError: if a source with this name already exists
        FileExistsError: if the source directory exists and is not empty
        FileNotFoundError: if the URL or the 'rpath' path doesn't exist
        SourceUrlInsideProjectError: if the URL points inside of the project
        PathOutsideSourceError: if 'rpath' points outside of the URL
    """
    if self.readonly:
        raise ReadonlyProjectError()

    self.validate_source_name(name)

    if name in self.working_tree.sources:
        raise SourceExistsError(name)

    # An empty leftover source directory is allowed and removed
    data_dir = self.source_data_dir(name)
    if osp.exists(data_dir):
        if os.listdir(data_dir):
            raise FileExistsError("Source directory '%s' already " "exists" % data_dir)
        os.rmdir(data_dir)

    if url:
        url = osp.abspath(url)
        if not osp.exists(url):
            raise FileNotFoundError(url)

        if is_subpath(url, base=self._root_dir):
            raise SourceUrlInsideProjectError()

        if rpath:
            rpath = osp.normpath(osp.join(url, rpath))

            if not osp.exists(rpath):
                raise FileNotFoundError(rpath)

            if not is_subpath(rpath, base=url):
                raise PathOutsideSourceError(
                    "Source data path is outside of the directory, "
                    "specified by source URL: '%s', '%s'" % (rpath, url)
                )

            rpath = osp.relpath(rpath, url)
        elif osp.isfile(url):
            # A single file source is read by its name in the data directory
            rpath = osp.basename(url)
    else:
        rpath = None

    # Data can only be put into the cache when it is hashed
    if no_hash:
        no_cache = True

    config = Source(
        {
            "url": (url or "").replace("\\", "/"),
            "path": (rpath or "").replace("\\", "/"),
            "format": format,
            "options": options or {},
        }
    )

    if not config.is_generated:
        dvcfile = self._source_dvcfile_path(name)
        os.makedirs(osp.dirname(dvcfile), exist_ok=True)

        # Download into a temporary directory first, then move the data
        # into the project
        with self._make_tmp_dir() as tmp_dir:
            obj_hash, tmp_dvcfile, tmp_data_dir = self._download_source(
                url, tmp_dir, no_cache=no_cache, no_hash=no_hash
            )

            shutil.move(tmp_data_dir, data_dir)
            on_error_do(rmtree, data_dir)

            if not no_hash:
                os.replace(tmp_dvcfile, dvcfile)
                config["hash"] = obj_hash

    # Source data is not tracked by Git
    self._git.ignore([data_dir])

    config = self.working_tree.sources.add(name, config)
    target = self.working_tree.build_targets.add_target(name)
    target.root.hash = config.hash

    return config
@scoped
def add_source(
    self, path: str, format: str, options: Optional[Dict] = None, *, rpath: Optional[str] = None
) -> Tuple[str, Source]:
    """
    Adds a new source (dataset) from the working directory of the project.

    Only directories from the project root can be added. The source name
    is the directory name. This command is useful after a source was
    removed and you need to re-add it, or when the dataset was copied or
    downloaded manually.

    The data is used in place, nothing is copied. When 'rpath' is
    specified, only the specified path inside of the directory is read.
    Required to support subtasks and subsets in datasets.

    Parameters:
        path (str): Path to the source directory. Must be a directory
            directly inside of the project root
        format (str): Dataset format
        options (dict): Options for the format Extractor
        rpath (str): Used to specify a relative path to the dataset
            inside of the source directory.

    Returns:
        the name and the config of the new source

    Raises:
        ReadonlyProjectError: if the project is opened in read-only mode
        ValueError: if the path is empty or gives an invalid source name
        SourceExistsError: if a source with this name already exists
        FileNotFoundError: if the directory or the 'rpath' path is not found
        UnexpectedUrlError: if the directory is not in the project root
        PathOutsideSourceError: if 'rpath' points outside of the directory
    """
    if self.readonly:
        raise ReadonlyProjectError()

    if not path:
        raise ValueError("Source path cannot be empty")

    path = osp.abspath(path)
    name = osp.basename(path)

    self.validate_source_name(name)

    if name in self.working_tree.sources:
        raise SourceExistsError(name)

    if not osp.isdir(path):
        raise FileNotFoundError("Source directory '%s' is not found" % path)

    if not (is_subpath(path, base=self._root_dir) and osp.dirname(path) == self._root_dir):
        raise UnexpectedUrlError(
            "The source path is expected to be " "a directory in the project root"
        )

    if rpath:
        rpath = osp.normpath(osp.join(path, rpath))

        if not osp.exists(rpath):
            raise FileNotFoundError(rpath)

        if not is_subpath(rpath, base=path):
            raise PathOutsideSourceError(
                "Source data path is outside of the directory, "
                "specified by source URL: '%s', '%s'" % (rpath, path)
            )

        rpath = osp.relpath(rpath, path)
    else:
        rpath = None

    # Source data is not tracked by Git
    self._git.ignore([path])

    config = self.working_tree.sources.add(
        name,
        {
            "url": (path or "").replace("\\", "/"),
            "path": (rpath or "").replace("\\", "/"),
            "format": format,
            "options": options or {},
        },
    )

    self.working_tree.build_targets.add_target(name)

    return name, config
def remove_source(self, name: str, *, force: bool = False, keep_data: bool = True):
    """
    Removes a source from the working tree.

    Options:
    - force (bool) - ignores errors and tries to wipe remaining data,
      e.g. of a source that was only partially added
    - keep_data (bool) - leaves source data untouched

    Raises:
        ReadonlyProjectError: if the project is opened in read-only mode
        UnknownSourceError: if there is no such source and 'force' is not set
    """
    if self.readonly:
        raise ReadonlyProjectError()

    if name in self.working_tree.sources:
        self.working_tree.sources.remove(name)
    elif not force:
        raise UnknownSourceError(name)

    data_dir = self.source_data_dir(name)
    if not keep_data and osp.isdir(data_dir):
        try:
            rmtree(data_dir)
        except Exception:
            if not force:
                raise

    dvcfile = self._source_dvcfile_path(name)
    if osp.isfile(dvcfile):
        try:
            rmfile(dvcfile)
        except Exception:
            if not force:
                raise

    # With 'force', the build target can be missing
    try:
        self.working_tree.build_targets.remove_target(name)
    except Exception:
        if not force:
            raise

    self._git.ignore([data_dir], mode="remove")
[docs] def commit(
    self,
    message: str,
    *,
    no_cache: bool = False,
    allow_empty: bool = False,
    allow_foreign: bool = False,
) -> Revision:
    """
    Copies tree and objects from the working dir to the cache.
    Creates a new commit. Moves the HEAD pointer to the new commit.

    Options:

    - no_cache (bool) - don't put added dataset data into cache,
      store only metainfo. Can be used to reduce storage size.
    - allow_empty (bool) - allow commits with no changes.
    - allow_foreign (bool) - allow commits with changes made not by Datumaro.

    Returns: the new commit hash

    Raises:
        ReadonlyProjectError: if the project is opened in read-only mode
        EmptyCommitError: if there are no changes and 'allow_empty' is not set
        ForeignChangesError: if a source was changed outside Datumaro
            and 'allow_foreign' is not set
    """
    if self.readonly:
        raise ReadonlyProjectError()

    statuses = self.status()

    if not allow_empty and not statuses:
        raise EmptyCommitError()

    for t, s in statuses.items():
        if s == DiffStatus.foreign_modified:
            # TODO: compute a patch and a new stage, remove allow_foreign
            if allow_foreign:
                log.warning(
                    "The source '%s' has been changed "
                    "without Datumaro. It will be saved, but it will "
                    "only be available for reproduction from the cache.",
                    t,
                )
            else:
                raise ForeignChangesError(
                    "The source '%s' is changed outside Datumaro. You can "
                    "restore the latest source revision with 'checkout' "
                    "command." % t
                )

    # Update the source hashes in the working tree config before saving it
    for s in self.working_tree.sources:
        self.refresh_source_hash(s, no_cache=no_cache)

    wtree_dir = osp.join(self._aux_dir, ProjectLayout.working_tree_dir)
    self._git.add(wtree_dir, base=wtree_dir)

    # VCS service files are committed together with the tree
    extra_files = [
        osp.join(self._root_dir, ".dvc", ".gitignore"),
        osp.join(self._root_dir, ".dvc", "config"),
        osp.join(self._root_dir, ".dvcignore"),
        osp.join(self._root_dir, ".gitignore"),
        osp.join(self._aux_dir, ".gitignore"),
    ]
    self._git.add(extra_files, base=self._root_dir)

    head = self._git.commit(message)

    # Keep a materialized copy of the committed tree in the project cache
    rev_dir = self.cache_path(head)
    copytree(wtree_dir, rev_dir)
    for p in extra_files:
        if osp.isfile(p):
            dst_path = osp.join(rev_dir, osp.relpath(p, self._root_dir))
            os.makedirs(osp.dirname(dst_path), exist_ok=True)
            shutil.copyfile(p, dst_path)

    # HEAD has changed, the cached HEAD tree is outdated
    self._head_tree = None

    return head
@staticmethod
def _move_dvc_dir(src_dir, dst_dir):
    """
    Moves the ``config`` and ``.gitignore`` files from src_dir to dst_dir.

    Uses os.replace, so an existing file at the destination is overwritten.
    """
    for filename in ("config", ".gitignore"):
        src_path = osp.join(src_dir, filename)
        dst_path = osp.join(dst_dir, filename)
        os.replace(src_path, dst_path)
def checkout(
    self,
    rev: Union[None, Revision] = None,
    sources: Union[None, str, Iterable[str]] = None,
    *,
    force: bool = False,
):
    """
    Copies tree and objects from the cache to the working tree.

    Sets HEAD to the specified revision, unless sources are specified.
    When sources are specified, only copies objects from the cache to
    the working tree. When neither a revision nor sources are specified,
    restores the sources from the current revision.

    By default, uses the current (HEAD) revision.

    Options:
        - force (bool) - ignore unsaved changes. By default, an error is raised

    Raises:
        ReadonlyProjectError - if the project is read-only
        UnknownSourceError - if a requested source is absent in the revision
    """
    if self.readonly:
        raise ReadonlyProjectError()

    if isinstance(sources, str):
        sources = {sources}
    elif sources is None:
        sources = set()
    else:
        sources = set(sources)

    rev = rev or "HEAD"

    if sources:
        rev_tree = self.get_rev(rev)

        # Validate every requested source before touching any data
        for s in sources:
            if s not in rev_tree.sources:
                raise UnknownSourceError(s)

        rev_dir = rev_tree.config.base_dir
        self._checkout_source_dvcfiles(
            {s: self._source_dvcfile_path(s, root=rev_dir) for s in sources}
        )

        self._git.ignore(sources)

        for s in sources:
            self.working_tree.config.sources[s] = rev_tree.config.sources[s]
            self.working_tree.config.build_targets[s] = rev_tree.config.build_targets[s]
    else:
        # Check working tree for unsaved changes,
        # set HEAD to the revision
        # write revision tree to working tree
        wtree_dir = osp.join(self._aux_dir, ProjectLayout.working_tree_dir)
        self._git.checkout(rev, dst_dir=wtree_dir, clean=True, force=force)
        self._move_dvc_dir(osp.join(wtree_dir, ".dvc"), osp.join(self._root_dir, ".dvc"))

        self._working_tree = None

        # Restore sources from the commit.
        # Work with the working tree instead of cache, to
        # avoid extra memory use from materializing
        # the head commit sources in the cache
        rev_tree = self.working_tree
        self._checkout_source_dvcfiles(
            {s: self._source_dvcfile_path(s) for s in rev_tree.sources}
        )

        os.replace(osp.join(wtree_dir, ".gitignore"), osp.join(self._root_dir, ".gitignore"))
        os.replace(osp.join(wtree_dir, ".dvcignore"), osp.join(self._root_dir, ".dvcignore"))

    self._working_tree = None


def _checkout_source_dvcfiles(self, dvcfiles: Dict[str, str]):
    """
    Restores source data with DVC, given a {source name: .dvc file path}
    mapping.

    Each .dvc file is copied into a temporary directory with its "wdir"
    field set to the project root. Output paths are then resolved against
    the project root instead of the .dvc file's own location, which may be
    a cached revision directory. The rewritten copies are passed to
    ``self._dvc.checkout`` in a single call.
    """
    with self._make_tmp_dir() as tmp_dir:
        tmp_dvcfiles = []
        for source_name, dvcfile in dvcfiles.items():
            with open(dvcfile) as f:
                conf = self._dvc.yaml_parser.load(f)

            conf["wdir"] = self._root_dir

            tmp_dvcfile = osp.join(tmp_dir, source_name + ".dvc")
            with open(tmp_dvcfile, "w") as f:
                self._dvc.yaml_parser.dump(conf, f)

            tmp_dvcfiles.append(tmp_dvcfile)

        self._dvc.checkout(tmp_dvcfiles)
def is_ref(self, ref: Union[None, str]) -> bool:
    """
    Tells whether ref can be resolved by the project.

    A ref is accepted if ``_is_working_tree_ref`` recognises it. Otherwise
    the answer comes from the Git wrapper's ``is_ref``.
    """
    if not self._is_working_tree_ref(ref):
        return self._git.is_ref(ref)
    return True
def has_commits(self) -> bool:
    """Reports whether the project's Git repository has commits, as told by the Git wrapper."""
    git = self._git
    return git.has_commits()
def status(self) -> Dict[str, DiffStatus]:
    """
    Compares the working tree against HEAD.

    Returns: { target_name: status } for changed build targets, skipping
    the main target:
        - foreign_modified - the target's data directory exists and its freshly
            computed hash differs from the hash recorded in the working tree
        - added / removed - the target exists only in the working tree / in HEAD
        - modified - the target differs between HEAD and the working tree
        - missing - the target is unchanged, but its data directory is absent
    If there are no commits yet, every working tree source is reported as added.
    """
    wtree = self.working_tree
    if not self.has_commits():
        return {s: DiffStatus.added for s in wtree.sources}

    head = self.head
    main_target = ProjectBuildTargets.MAIN_TARGET
    changes = {}

    # Pass 1: data changed on disk without Datumaro
    for name, wtree_target in wtree.build_targets.items():
        if name == main_target or not osp.isdir(self.source_data_dir(name)):
            continue

        recorded_hash = wtree_target.head.hash
        actual_hash = self.compute_source_hash(name, no_cache=True)
        if recorded_hash and recorded_hash != actual_hash:
            changes[name] = DiffStatus.foreign_modified

    # Pass 2: configuration differences between HEAD and the working tree
    for name in set(head.build_targets) | set(wtree.build_targets):
        if name == main_target or name in changes:
            continue

        head_target = head.build_targets.get(name)
        wtree_target = wtree.build_targets.get(name)

        if head_target is None:
            change = DiffStatus.added
        elif wtree_target is None:
            change = DiffStatus.removed
        elif head_target != wtree_target:
            change = DiffStatus.modified
        elif not osp.isdir(self.source_data_dir(name)):
            change = DiffStatus.missing
        else:
            continue

        changes[name] = change

    return changes
def history(self, max_count=10) -> List[Tuple[Revision, str]]:
    """
    Returns (commit hash, commit message) pairs in the order produced by
    the Git wrapper's log, which is queried with max_count.
    """
    entries = []
    for commit_obj, _ in self._git.log(max_count):
        entries.append((commit_obj.hexsha, commit_obj.message))
    return entries
def diff(
    self, rev_a: Union[Tree, Revision], rev_b: Union[Tree, Revision]
) -> Dict[str, DiffStatus]:
    """
    Compares 2 revision trees.

    Each argument can be a Tree or a revision string; strings are resolved
    with get_rev(). The main target is not compared.

    Returns: { target_name: status } for changed targets, where
        added - only in rev_b, removed - only in rev_a,
        modified - present in both but different
    """
    if rev_a == rev_b:
        return {}

    tree_a = self.get_rev(rev_a) if isinstance(rev_a, str) else rev_a
    tree_b = self.get_rev(rev_b) if isinstance(rev_b, str) else rev_b

    changes = {}
    for name in set(tree_a.build_targets) | set(tree_b.build_targets):
        if name == ProjectBuildTargets.MAIN_TARGET:
            continue

        target_a = tree_a.build_targets.get(name)
        target_b = tree_b.build_targets.get(name)

        if target_a is None:
            changes[name] = DiffStatus.added
        elif target_b is None:
            changes[name] = DiffStatus.removed
        elif target_a != target_b:
            changes[name] = DiffStatus.modified

    return changes
def model_data_dir(self, name: str) -> str:
    """Returns the path of the data directory for model `name`, under the project's models dir."""
    models_root = osp.join(self._aux_dir, ProjectLayout.models_dir)
    return osp.join(models_root, name)
def make_model(self, name: str) -> Launcher:
    """
    Creates a launcher for the configured model `name`.

    The model's options are forwarded to the environment's make_launcher().
    ``model_dir`` is the model data directory if it exists, otherwise None.
    """
    model_conf = self._config.models[name]
    data_dir = self.model_data_dir(name)
    return self._env.make_launcher(
        model_conf.launcher,
        **model_conf.options,
        model_dir=data_dir if osp.isdir(data_dir) else None,
    )
def add_model(
    self, name: str, launcher: str, options: Union[None, Dict[str, Any]] = None
) -> Model:
    """
    Registers a new model in the project configuration.

    Args:
        name - the model name; must be non-empty and not already used
        launcher - the name of a launcher known to the project environment
        options - launcher options; an empty dict is stored when omitted

    Returns: the result of ``self._config.models.set()`` for the new entry

    Raises:
        ReadonlyProjectError - if the project is read-only
        KeyError - if the launcher is unknown or the model already exists
        ValueError - if the model name is empty
    """
    if self.readonly:
        raise ReadonlyProjectError()

    if launcher not in self.env.launchers:
        raise KeyError("Unknown launcher '%s'" % launcher)

    if not name:
        raise ValueError("Model name can't be empty")

    if name in self.models:
        raise KeyError("Model '%s' already exists" % name)

    return self._config.models.set(name, {"launcher": launcher, "options": options or {}})
def remove_model(self, name: str):
    """
    Removes model `name` from the project configuration and deletes its
    data directory, if one exists.

    Raises:
        ReadonlyProjectError - if the project is read-only
        KeyError - if there is no such model
    """
    if self.readonly:
        raise ReadonlyProjectError()
    if name not in self.models:
        raise KeyError("Unknown model '%s'" % name)

    self._config.models.remove(name)

    model_dir = self.model_data_dir(name)
    if not osp.isdir(model_dir):
        return
    rmtree(model_dir)