139 lines
4.5 KiB
Python
139 lines
4.5 KiB
Python
"""Ollama Cloud API client helpers for AI search and chat."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from config import (
|
|
OLLAMA_API_BASE,
|
|
OLLAMA_API_KEY,
|
|
OLLAMA_DEFAULT_MODEL,
|
|
OLLAMA_REQUEST_TIMEOUT_SECONDS,
|
|
)
|
|
|
|
|
|
class OllamaCloudError(RuntimeError):
    """Raised when Ollama Cloud is unavailable or rejects a request.

    Carries an HTTP-style ``status_code`` (default 502) so API handlers can
    translate the failure directly into a response status.
    """

    def __init__(self, message: str, status_code: int = 502) -> None:
        super().__init__(message)
        # Preserved for callers that map this error onto an HTTP response.
        self.status_code = status_code
|
|
|
|
|
|
def is_ollama_configured() -> bool:
    """Report whether an Ollama Cloud API key has been configured."""
    return True if OLLAMA_API_KEY else False
|
|
|
|
|
|
def default_model() -> str:
    """Return the model name used when a caller does not pick one explicitly."""
    return OLLAMA_DEFAULT_MODEL
|
|
|
|
|
|
def _headers(require_auth: bool = False) -> dict[str, str]:
    """Build the request headers, attaching the bearer token when available.

    Raises:
        OllamaCloudError: with status 503 when *require_auth* is set but no
            API key is configured.
    """
    if OLLAMA_API_KEY:
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {OLLAMA_API_KEY}",
        }
    if require_auth:
        # Fail fast with a service-unavailable status instead of letting the
        # upstream reject an unauthenticated request.
        raise OllamaCloudError(
            "Ollama Cloud is not configured. Set OLLAMA_API_KEY before using AI responses.",
            status_code=503,
        )
    return {"Content-Type": "application/json"}
|
|
|
|
|
|
def _normalise_error(response: httpx.Response) -> str:
|
|
try:
|
|
payload = response.json()
|
|
except ValueError:
|
|
return response.text.strip() or response.reason_phrase
|
|
detail = payload.get("error") or payload.get("detail") or payload
|
|
return str(detail)
|
|
|
|
|
|
async def list_models() -> list[dict[str, Any]]:
    """Fetch the models known to the Ollama endpoint, sorted by name.

    Raises:
        OllamaCloudError: when the endpoint returns an HTTP error status.
    """
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.get(
            f"{OLLAMA_API_BASE}/api/tags",
            headers=_headers(require_auth=False),
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        models = response.json().get("models") or []

    def _sort_key(entry: dict[str, Any]) -> str:
        # Some payloads use "name", others "model"; fall back to empty string.
        return entry.get("name") or entry.get("model") or ""

    return sorted(models, key=_sort_key)
|
|
|
|
|
|
async def chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> dict[str, Any]:
    """Run a non-streaming chat completion and return the full JSON reply.

    Args:
        model: Model identifier to run.
        messages: Chat history in Ollama message format.
        think: Optional "thinking" setting; omitted when None or "off".

    Raises:
        OllamaCloudError: when unconfigured (via ``_headers``) or on an
            HTTP error response.
    """
    body: dict[str, Any] = {"model": model, "messages": messages, "stream": False}
    if think not in (None, "off"):
        body["think"] = think

    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{OLLAMA_API_BASE}/api/chat",
            headers=_headers(require_auth=True),
            json=body,
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        return response.json()
|
|
|
|
|
|
async def stream_chat(
    model: str,
    messages: list[dict[str, Any]],
    think: bool | str | None = None,
) -> AsyncIterator[dict[str, Any]]:
    """Stream a chat completion, yielding one decoded chunk per NDJSON line.

    Lines that fail to parse as JSON are wrapped as plain-content chunks so
    the consumer still receives the text.

    Raises:
        OllamaCloudError: when unconfigured (via ``_headers``) or on an
            HTTP error response.
    """
    body: dict[str, Any] = {"model": model, "messages": messages, "stream": True}
    if think not in (None, "off"):
        body["think"] = think

    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        async with client.stream(
            "POST",
            f"{OLLAMA_API_BASE}/api/chat",
            headers=_headers(require_auth=True),
            json=body,
        ) as response:
            if response.status_code >= 400:
                # Streaming responses must be read explicitly before the body
                # is available for the error message.
                raw = await response.aread()
                text = raw.decode("utf-8", errors="replace")
                raise OllamaCloudError(
                    text or response.reason_phrase,
                    status_code=response.status_code,
                )

            async for raw_line in response.aiter_lines():
                stripped = raw_line.strip()
                if not stripped:
                    continue
                try:
                    chunk = json.loads(stripped)
                except json.JSONDecodeError:
                    # Pass malformed lines through as plain message content.
                    chunk = {"message": {"content": stripped}}
                yield chunk
|
|
|
|
|
|
async def web_search(query: str, max_results: int = 5) -> list[dict[str, Any]]:
    """Run a web search via Ollama Cloud and return the result list.

    Args:
        query: Search query string.
        max_results: Desired result count; clamped to the 1..10 range.

    Raises:
        OllamaCloudError: when unconfigured (via ``_headers``) or on an
            HTTP error response.
    """
    # Keep the requested count within the range the API accepts.
    clamped = min(max(max_results, 1), 10)
    timeout = httpx.Timeout(OLLAMA_REQUEST_TIMEOUT_SECONDS)
    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(
            f"{OLLAMA_API_BASE}/api/web_search",
            headers=_headers(require_auth=True),
            json={"query": query, "max_results": clamped},
        )
        if response.status_code >= 400:
            raise OllamaCloudError(_normalise_error(response), status_code=response.status_code)
        return response.json().get("results") or []
|