Compare commits
3 Commits
297e20ce8d
...
f340d0fa3e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f340d0fa3e | ||
|
|
edc53cb6eb | ||
|
|
725cece5c1 |
@@ -55,12 +55,15 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||||
|
# block local development when no Stripe subscription exists.
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str = result.scalar_one_or_none() or "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
|
|||||||
@@ -190,8 +190,11 @@ async def trigger_agent_run(
|
|||||||
enabled=True,
|
enabled=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
|
stable_agent_id = body.agent_id or config.id
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=config.id,
|
agent_id=stable_agent_id,
|
||||||
agent_type="local",
|
agent_type="local",
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
@@ -200,8 +203,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
run_context = {
|
||||||
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, config, run_log, device_manager)
|
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -81,16 +81,18 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'free'`` when no subscription row exists.
|
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||||
|
when no subscription row exists.
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "free"
|
return "power" if settings.ENV == "dev" else "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -188,12 +188,18 @@ def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
|||||||
def _make_agent_executor(
|
def _make_agent_executor(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
|
run_context: dict | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Create a WS callback for ``set_client_executor()`` so that all tools
|
"""Create a WS callback for ``set_client_executor()`` so that all tools
|
||||||
can use ``execute_on_client()`` during an agent run.
|
can use ``execute_on_client()`` during an agent run.
|
||||||
|
|
||||||
|
If *run_context* is provided it is attached to every ``tool_call`` frame
|
||||||
|
so the Electron client can attribute actions to the correct agent run.
|
||||||
"""
|
"""
|
||||||
async def _executor(payload: dict) -> dict:
|
async def _executor(payload: dict) -> dict:
|
||||||
payload["type"] = "tool_call"
|
payload["type"] = "tool_call"
|
||||||
|
if run_context:
|
||||||
|
payload["run_context"] = run_context
|
||||||
call_id = payload["id"]
|
call_id = payload["id"]
|
||||||
fut = device_mgr.create_pending_call(user_id, call_id)
|
fut = device_mgr.create_pending_call(user_id, call_id)
|
||||||
await device_mgr.send_frame(user_id, payload)
|
await device_mgr.send_frame(user_id, payload)
|
||||||
@@ -328,6 +334,7 @@ async def run_local_agent(
|
|||||||
config: LocalAgentConfig,
|
config: LocalAgentConfig,
|
||||||
run_log: AgentRunLog,
|
run_log: AgentRunLog,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
|
run_context: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a local directory agent run using two-phase LLM-with-tools.
|
"""Execute a local directory agent run using two-phase LLM-with-tools.
|
||||||
|
|
||||||
@@ -363,7 +370,7 @@ async def run_local_agent(
|
|||||||
return
|
return
|
||||||
|
|
||||||
# ── Set up WS executor for tools ────────────────────────────────
|
# ── Set up WS executor for tools ────────────────────────────────
|
||||||
executor = _make_agent_executor(user_id, device_mgr)
|
executor = _make_agent_executor(user_id, device_mgr, run_context)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
|
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
@@ -508,6 +515,18 @@ async def run_local_agent(
|
|||||||
len(errors),
|
len(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Notify the Electron client that the run is complete so it can close
|
||||||
|
# the run record in its local SQLite.
|
||||||
|
if run_context and device_mgr.is_online(user_id):
|
||||||
|
try:
|
||||||
|
await device_mgr.send_frame(user_id, {
|
||||||
|
"type": "run_complete",
|
||||||
|
"run_context": run_context,
|
||||||
|
"status": final_status,
|
||||||
|
})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_id, exc)
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -295,6 +295,7 @@ class AgentCreationCheckResponse(BaseModel):
|
|||||||
class AgentTriggerRequest(BaseModel):
|
class AgentTriggerRequest(BaseModel):
|
||||||
directory: str = Field(min_length=1)
|
directory: str = Field(min_length=1)
|
||||||
device_id: str = Field(default="")
|
device_id: str = Field(default="")
|
||||||
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
what_to_extract: list[str] = Field(min_length=1)
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
actions_by_type: dict[str, list[str]] | None = None
|
actions_by_type: dict[str, list[str]] | None = None
|
||||||
batch_interval: str = Field(min_length=1)
|
batch_interval: str = Field(min_length=1)
|
||||||
|
|||||||
Reference in New Issue
Block a user