mirror of
https://github.com/langgenius/dify.git
synced 2026-02-15 01:50:14 -05:00
Compare commits
8 Commits
main
...
feat/defau
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5720017e4d | ||
|
|
c21c6c3815 | ||
|
|
0e55ef7336 | ||
|
|
eb5b747a06 | ||
|
|
e643b83460 | ||
|
|
95d1913f2c | ||
|
|
0318f2ec71 | ||
|
|
68e3a1c990 |
@@ -289,6 +289,12 @@ class AccountService:
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
|
||||
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
|
||||
if getattr(dify_config, "ENTERPRISE_ENABLED", False):
|
||||
from services.enterprise.enterprise_service import try_join_default_workspace
|
||||
|
||||
try_join_default_workspace(str(account.id))
|
||||
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@@ -1407,6 +1413,12 @@ class RegisterService:
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
|
||||
if getattr(dify_config, "ENTERPRISE_ENABLED", False):
|
||||
from services.enterprise.enterprise_service import try_join_default_workspace
|
||||
|
||||
try_join_default_workspace(str(account.id))
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
db.session.rollback()
|
||||
logger.exception("Register failed")
|
||||
|
||||
@@ -39,6 +39,9 @@ class BaseRequest:
|
||||
endpoint: str,
|
||||
json: Any | None = None,
|
||||
params: Mapping[str, Any] | None = None,
|
||||
*,
|
||||
timeout: float | httpx.Timeout | None = None,
|
||||
raise_for_status: bool = False,
|
||||
) -> Any:
|
||||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
@@ -53,7 +56,16 @@ class BaseRequest:
|
||||
logger.debug("Failed to generate traceparent header", exc_info=True)
|
||||
|
||||
with httpx.Client(mounts=mounts) as client:
|
||||
response = client.request(method, url, json=json, params=params, headers=headers)
|
||||
# IMPORTANT:
|
||||
# - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default.
|
||||
# - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set.
|
||||
request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers}
|
||||
if timeout is not None:
|
||||
request_kwargs["timeout"] = timeout
|
||||
|
||||
response = client.request(method, url, **request_kwargs)
|
||||
if raise_for_status:
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
access_mode: str = Field(
|
||||
@@ -30,6 +37,52 @@ class WorkspacePermission(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class DefaultWorkspaceJoinResult(BaseModel):
|
||||
"""
|
||||
Result of ensuring an account is a member of the enterprise default workspace.
|
||||
|
||||
- joined=True is idempotent (already a member also returns True)
|
||||
- joined=False means enterprise default workspace is not configured or invalid/archived
|
||||
"""
|
||||
|
||||
# Only workspace_id can be empty when "no default workspace configured".
|
||||
workspace_id: str = ""
|
||||
|
||||
# These fields are required to avoid silently treating error payloads as "skipped".
|
||||
joined: bool
|
||||
message: str
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
def try_join_default_workspace(account_id: str) -> None:
|
||||
"""
|
||||
Enterprise-only side-effect: ensure account is a member of the default workspace.
|
||||
|
||||
This is a best-effort integration. Failures must not block user registration.
|
||||
"""
|
||||
|
||||
if not dify_config.ENTERPRISE_ENABLED:
|
||||
return
|
||||
|
||||
try:
|
||||
result = EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
if result.joined:
|
||||
logger.info(
|
||||
"Joined enterprise default workspace for account %s (workspace_id=%s)",
|
||||
account_id,
|
||||
result.workspace_id,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Skipped joining enterprise default workspace for account %s (message=%s)",
|
||||
account_id,
|
||||
result.message,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True)
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
@@ -39,6 +92,34 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
|
||||
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
||||
"""
|
||||
Call enterprise inner API to add an account to the default workspace.
|
||||
|
||||
NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix,
|
||||
so the endpoint here is `/default-workspace/members`.
|
||||
"""
|
||||
|
||||
# Ensure we are sending a UUID-shaped string (enterprise side validates too).
|
||||
try:
|
||||
uuid.UUID(account_id)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"account_id must be a valid UUID: {account_id}") from e
|
||||
|
||||
data = EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
|
||||
raise_for_status=True,
|
||||
)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Invalid response format from enterprise default workspace API")
|
||||
if "joined" not in data or "message" not in data:
|
||||
raise ValueError("Invalid response payload from enterprise default workspace API")
|
||||
return DefaultWorkspaceJoinResult.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def get_app_sso_settings_last_update_time(cls) -> datetime:
|
||||
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Unit tests for enterprise service integrations.
|
||||
|
||||
This module covers the enterprise-only default workspace auto-join behavior:
|
||||
- Enterprise mode disabled: no external calls
|
||||
- Successful join / skipped join: no errors
|
||||
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.enterprise.enterprise_service import (
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
try_join_default_workspace,
|
||||
)
|
||||
|
||||
|
||||
class TestJoinDefaultWorkspace:
|
||||
def test_join_default_workspace_success(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
|
||||
|
||||
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||
mock_send_request.return_value = response
|
||||
|
||||
result = EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
|
||||
assert isinstance(result, DefaultWorkspaceJoinResult)
|
||||
assert result.workspace_id == response["workspace_id"]
|
||||
assert result.joined is True
|
||||
assert result.message == "ok"
|
||||
|
||||
mock_send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=1.0,
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
def test_join_default_workspace_invalid_response_format_raises(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||
mock_send_request.return_value = "not-a-dict"
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid response format"):
|
||||
EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
|
||||
def test_join_default_workspace_invalid_account_id_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
|
||||
|
||||
def test_join_default_workspace_missing_required_fields_raises(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
response = {"workspace_id": "", "message": "ok"} # missing "joined"
|
||||
|
||||
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
|
||||
mock_send_request.return_value = response
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid response payload"):
|
||||
EnterpriseService.join_default_workspace(account_id=account_id)
|
||||
|
||||
|
||||
class TestTryJoinDefaultWorkspace:
|
||||
def test_try_join_default_workspace_enterprise_disabled_noop(self):
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = False
|
||||
|
||||
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
|
||||
|
||||
mock_join.assert_not_called()
|
||||
|
||||
def test_try_join_default_workspace_successful_join_does_not_raise(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_join.return_value = DefaultWorkspaceJoinResult(
|
||||
workspace_id="22222222-2222-2222-2222-222222222222",
|
||||
joined=True,
|
||||
message="ok",
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
try_join_default_workspace(account_id)
|
||||
|
||||
mock_join.assert_called_once_with(account_id=account_id)
|
||||
|
||||
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_join.return_value = DefaultWorkspaceJoinResult(
|
||||
workspace_id="",
|
||||
joined=False,
|
||||
message="no default workspace configured",
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
try_join_default_workspace(account_id)
|
||||
|
||||
mock_join.assert_called_once_with(account_id=account_id)
|
||||
|
||||
def test_try_join_default_workspace_api_failure_soft_fails(self):
|
||||
account_id = "11111111-1111-1111-1111-111111111111"
|
||||
|
||||
with (
|
||||
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
|
||||
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
|
||||
):
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_join.side_effect = Exception("network failure")
|
||||
|
||||
# Should not raise
|
||||
try_join_default_workspace(account_id)
|
||||
|
||||
mock_join.assert_called_once_with(account_id=account_id)
|
||||
|
||||
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
|
||||
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Should not raise even though UUID parsing fails inside join_default_workspace
|
||||
try_join_default_workspace("not-a-uuid")
|
||||
@@ -1064,6 +1064,67 @@ class TestRegisterService:
|
||||
|
||||
# ==================== Registration Tests ====================
|
||||
|
||||
def test_create_account_and_tenant_calls_default_workspace_join_when_enterprise_enabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should be invoked when ENTERPRISE_ENABLED is True."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
result = AccountService.create_account_and_tenant(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password=None,
|
||||
)
|
||||
|
||||
assert result == mock_account
|
||||
mock_create_workspace.assert_called_once_with(account=mock_account)
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
|
||||
def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
AccountService.create_account_and_tenant(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password=None,
|
||||
)
|
||||
|
||||
mock_create_workspace.assert_called_once_with(account=mock_account)
|
||||
mock_join_default_workspace.assert_not_called()
|
||||
|
||||
def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies):
|
||||
"""Test successful account registration."""
|
||||
# Setup mocks
|
||||
@@ -1115,6 +1176,65 @@ class TestRegisterService:
|
||||
mock_event.send.assert_called_once_with(mock_tenant)
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_register_calls_default_workspace_join_when_enterprise_enabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should be invoked after successful register commit."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
result = RegisterService.register(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
create_workspace_required=False,
|
||||
)
|
||||
|
||||
assert result == mock_account
|
||||
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
|
||||
|
||||
def test_register_does_not_call_default_workspace_join_when_enterprise_disabled(
|
||||
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
|
||||
):
|
||||
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
|
||||
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
|
||||
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
account_id="11111111-1111-1111-1111-111111111111"
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.account_service.AccountService.create_account") as mock_create_account,
|
||||
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
|
||||
):
|
||||
mock_create_account.return_value = mock_account
|
||||
|
||||
RegisterService.register(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password="password123",
|
||||
language="en-US",
|
||||
create_workspace_required=False,
|
||||
)
|
||||
|
||||
mock_join_default_workspace.assert_not_called()
|
||||
|
||||
def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies):
|
||||
"""Test account registration with OAuth integration."""
|
||||
# Setup mocks
|
||||
|
||||
Reference in New Issue
Block a user