Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
},
"readonly": {
"arg": Arg(
type=bool,
action="store_true",
default=False,
),
"config": {},
},
"label": {
"arg": Arg(type=str, default="unlabeled"),
"config": {},
},
"source": {
"arg": Arg(
type=BaseSource.load,
default=JSONSource,
),
"config": {},
},
if maintained is not None and maintained[0] is not None:
repo.evaluated({"maintained": str(maintained[0])})
return repo
async def __aenter__(self) -> "DemoAppSourceContext":
    """Enter the context: open a cursor on the parent source's connection."""
    cursor_cm = self.parent.db.cursor()
    # Keep the context manager itself so __aexit__ can close it later.
    self.__conn = cursor_cm
    self.conn = await cursor_cm.__aenter__()
    return self
async def __aexit__(self, exc_type, exc_value, traceback):
    """Exit the context: close the cursor, then commit the transaction."""
    await self.__conn.__aexit__(exc_type, exc_value, traceback)
    database = self.parent.db
    await database.commit()
@entry_point("demoapp")
class DemoAppSource(BaseSource):
CONTEXT = DemoAppSourceContext
async def __aenter__(self) -> "DemoAppSource":
    """Create the MySQL connection pool and check out one connection."""
    self.pool = await aiomysql.create_pool(
        host=self.config.host,
        port=self.config.port,
        user=self.config.user,
        password=self.config.password,
        db=self.config.db,
    )
    # Hold on to the acquire() context manager so __aexit__ can release
    # the connection back to the pool.
    acquired = self.pool.acquire()
    self.__db = acquired
    self.db = await acquired.__aenter__()
    return self
async def __aexit__(self, exc_type, exc_value, traceback):
if features is not None:
repo.evaluated(features)
# Get prediction
prediction = await db.execute(
"SELECT * FROM prediction WHERE " "src_url=?", (repo.src_url,)
)
prediction = await prediction.fetchone()
if prediction is not None:
repo.predicted(prediction["value"], prediction["confidence"])
return repo
async def __aexit__(self, exc_type, exc_value, traceback):
    """Commit any outstanding writes when the source context closes."""
    database = self.parent.db
    await database.commit()
class CustomSQLiteSource(BaseSource):
CONTEXT = CustomSQLiteSourceContext
FEATURE_COLS = ["PetalLength", "PetalWidth", "SepalLength", "SepalWidth"]
PREDICTION_COLS = ["value", "confidence"]
async def __aenter__(self) -> "BaseSourceContext":
self.__db = aiosqlite.connect(self.config.filename)
self.db = await self.__db.__aenter__()
self.db.row_factory = aiosqlite.Row
# Create table for feature data
await self.db.execute(
"CREATE TABLE IF NOT EXISTS features ("
"src_url TEXT PRIMARY KEY NOT NULL, "
+ (" REAL, ".join(self.FEATURE_COLS))
+ " REAL"
")"
from ..base import config
from .source import BaseSource
from ..util.cli.arg import Arg
from ..util.cli.cmd import CMD
from ..util.entrypoint import entry_point
@config
class FileSourceConfig:
    # Path to the file backing this source.
    filename: str
    # Label partitioning repos within the file; defaults to "unlabeled".
    label: str = "unlabeled"
    # NOTE(review): presumably suppresses writing changes back on close
    # (FileSource calls _close() on exit) — confirm against _close().
    readonly: bool = False
@entry_point("file")
class FileSource(BaseSource):
    """
    Source backed by a file: contents are loaded when the context is
    entered and written back out when it exits.
    """

    CONFIG = FileSourceConfig

    async def __aenter__(self) -> "BaseSourceContext":
        # Load the file's contents on entry.
        await self._open()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # Persist state back to the file on exit.
        await self._close()

    async def _empty_file_init(self):
        # A brand-new or empty file starts with no repo data.
        return {}
async def repos(self) -> AsyncIterator[Repo]:
    """Yield every repo currently held in the parent's in-memory mapping."""
    for stored in self.parent.mem.values():
        yield stored
async def repo(self, src_url: str) -> Repo:
    """Return the repo stored under ``src_url``, or a fresh blank Repo."""
    if src_url in self.parent.mem:
        return self.parent.mem[src_url]
    return Repo(src_url)
@config
class MemorySourceConfig:
    # Repos preloaded into the in-memory store, keyed by src_url
    # in MemorySource.__init__.
    repos: List[Repo]
@entry_point("memory")
class MemorySource(BaseSource):
    """
    Stores repos in a dict in memory
    """

    CONFIG = MemorySourceConfig
    CONTEXT = MemorySourceContext

    def __init__(self, config: MemorySourceConfig) -> None:
        super().__init__(config)
        # src_url -> Repo mapping backing this source.
        self.mem: Dict[str, Repo] = {}
        if isinstance(self.config, MemorySourceConfig):
            # Preload anything supplied via the config.
            self.mem = {
                preloaded.src_url: preloaded
                for preloaded in self.config.repos
            }
for key in self.keys:
repo = await sctx.repo(key)
pdb.set_trace()
await sctx.update(repo)
class Merge(CMD):
    """
    Merge repo data between sources
    """

    arg_dest = Arg(
        "dest", help="Sources merge repos into", type=BaseSource.load_labeled
    )
    arg_src = Arg(
        "src", help="Sources to pull repos from", type=BaseSource.load_labeled
    )

    async def run(self):
        """Copy every repo from src into dest, merging with existing data."""
        async with self.src.withconfig(
            self.extra_config
        ) as src, self.dest.withconfig(self.extra_config) as dest:
            async with src() as sctx, dest() as dctx:
                # Fix: the original loop reused the name ``src``, shadowing
                # the source bound by the enclosing ``async with`` — harmless
                # today but error prone; use a distinct name.
                async for src_repo in sctx.repos():
                    merged = Repo(src_repo.src_url)
                    merged.merge(src_repo)
                    # Pull in whatever dest already holds for this key so
                    # existing data is preserved.
                    merged.merge(await dctx.repo(merged.src_url))
                    await dctx.update(merged)
class ImportExportCMD(PortCMD, SourcesCMD):
    """Shared import export arguments"""

    async def run(self):
        # NOTE(review): pdb.set_trace() drops into the interactive debugger
        # once per key. If this is meant to let the user hand-edit ``repo``
        # before it is saved it is intentional; otherwise it is leftover
        # debug code and should be removed — confirm with the author.
        async with self.sources as sources:
            async with sources() as sctx:
                for key in self.keys:
                    repo = await sctx.repo(key)
                    pdb.set_trace()
                    await sctx.update(repo)
class Merge(CMD):
    """
    Merge repo data between sources
    """

    arg_dest = Arg(
        "dest", help="Sources merge repos into", type=BaseSource.load_labeled
    )
    arg_src = Arg(
        "src", help="Sources to pull repos from", type=BaseSource.load_labeled
    )

    async def run(self):
        # Open both labeled sources, then fold every repo from the src
        # source into the dest source, merging with what dest already has.
        async with self.src.withconfig(
            self.extra_config
        ) as src, self.dest.withconfig(self.extra_config) as dest:
            async with src() as sctx, dest() as dctx:
                async for src in sctx.repos():
                    combined = Repo(src.src_url)
                    combined.merge(src)
                    combined.merge(await dctx.repo(combined.src_url))
                    await dctx.update(combined)
class SourcesCMD(CMD):
    """
    Base command providing a ``-sources`` argument; each labeled source
    class is instantiated with config loaded from the command line.
    """

    arg_sources = Arg(
        "-sources",
        help="Sources for loading and saving",
        nargs="+",
        default=Sources(
            JSONSource(
                FileSourceConfig(
                    filename=os.path.join(
                        os.path.expanduser("~"), ".cache", "dffml.json"
                    )
                )
            )
        ),
        type=BaseSource.load_labeled,
        action=list_action(Sources),
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Go through the list of sources and instantiate them with a config
        # created from loading their arguments from cmd (self).
        # Idiom fix: enumerate instead of range(0, len(...)) indexing; we
        # still assign by index because the list is mutated in place.
        for i, source in enumerate(self.sources):
            if inspect.isclass(source):
                self.sources[i] = source.withconfig(self.extra_config)
class ModelCMD(CMD):
    """
    Set a model's model dir.
    """
async def configure_source(self, request):
source_name = request.match_info["source"]
label = request.match_info["label"]
config = await request.json()
try:
source = BaseSource.load_labeled(f"{label}={source_name}")
except EntrypointNotFound as error:
self.logger.error(
f"/configure/source/ failed to load source: {error}"
)
return web.json_response(
{"error": f"source {source_name} not found"},
status=HTTPStatus.NOT_FOUND,
)
try:
source = source.withconfig(config)
except MissingConfig as error:
self.logger.error(
f"failed to configure source {source_name}: {error}"
)
return web.json_response(
async def list_sources(self, request):
    """
    HTTP handler: respond with a JSON object mapping each registered
    source entry point label to the args it exposes (``source.args({})``).
    """
    available = {
        src.ENTRY_POINT_ORIG_LABEL: src.args({})
        for src in BaseSource.load()
    }
    return web.json_response(
        available, dumps=partial(json.dumps, cls=JSONEncoder)
    )