import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
app = FastAPI()

# Comma-separated CORS allow-list, e.g.
#   CORS_ALLOW_ORIGINS="https://app.example.com,https://admin.example.com"
# Defaults to "*" (any origin), which keeps the previous open configuration.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed CORS must not be combined with a wildcard origin (see the
    # FastAPI CORS docs; browsers reject "*" on credentialed responses), so
    # cookies/auth headers are only allowed for an explicit allow-list.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Message(BaseModel):
    """A chat message as stored in room history and broadcast to room members."""

    # Message kind; the WebSocket endpoint in this file only creates "message".
    type: str
    content: str
    sender_id: int
    # Set by the server in RealTimeSystem.send_message when the message is
    # received; None until then. Typed as optional so an explicit
    # ``timestamp=None`` validates instead of raising.
    timestamp: datetime | None = None
class Room(BaseModel):
    """Descriptive record for a chat room.

    NOTE(review): not referenced by the code visible here; RealTimeSystem
    tracks rooms by plain string id instead.
    """

    id: str  # room identifier
    name: str  # human-readable room name
    owner_id: int  # presumably the id of the user who created the room — confirm
    created_at: datetime  # creation time
class User(BaseModel):
    """Profile of a chat participant.

    NOTE(review): not referenced by the code visible here; connections are
    identified by a bare integer user_id instead.
    """

    id: int  # user identifier
    username: str  # display name
    avatar: str | None = None  # optional; presumably an image URL — confirm
class RealTimeSystem:
    """In-memory registry of chat rooms, their connected sockets and message history.

    All state lives in this object's dicts, so it is process-local: it is lost
    on restart and is not shared between separate worker processes.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, max_history: int | None = None) -> None:
        """Create an empty registry.

        Args:
            max_history: If set, keep at most this many messages per room,
                discarding the oldest first. ``None`` (the default) keeps every
                message for as long as the room exists.
        """
        # room_id -> {user_id: websocket}. A second connection with the same
        # user_id in the same room replaces the first one's entry.
        self.rooms: Dict[str, Dict[int, WebSocket]] = {}
        # Not read or written by any method of this class.
        self.users: Dict[int, dict] = {}
        # room_id -> messages in arrival order; discarded when the room empties.
        self.message_history: Dict[str, List[Message]] = {}
        self.max_history = max_history

    async def join_room(self, room_id: str, user_id: int, websocket: WebSocket) -> None:
        """Register ``websocket`` as ``user_id``'s connection in ``room_id``.

        Creates the room (with empty history) on first join, then sends a
        ``user_joined`` event to every other member.
        """
        if room_id not in self.rooms:
            self.rooms[room_id] = {}
            self.message_history[room_id] = []
        self.rooms[room_id][user_id] = websocket
        # Notify everyone else in the room about the newcomer.
        await self.broadcast(
            room_id,
            {
                "type": "user_joined",
                "user_id": user_id,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            },
            exclude=user_id,
        )

    def leave_room(self, room_id: str, user_id: int) -> None:
        """Remove ``user_id`` from ``room_id``; drop the room and its history once empty.

        Unknown rooms and users are ignored, so this is safe to call from cleanup code.
        """
        members = self.rooms.get(room_id)
        if members is None:
            return
        members.pop(user_id, None)
        if not members:
            del self.rooms[room_id]
            self.message_history.pop(room_id, None)

    async def send_message(self, room_id: str, message: Message) -> None:
        """Timestamp ``message``, append it to the room history and broadcast it to all members.

        Raises:
            HTTPException: 404 if ``room_id`` currently has no connected members.
        """
        if room_id not in self.rooms:
            raise HTTPException(status_code=404, detail="Room not found")
        message.timestamp = datetime.now(timezone.utc)
        history = self.message_history[room_id]
        history.append(message)
        if self.max_history is not None and len(history) > self.max_history:
            del history[: len(history) - self.max_history]
        await self.broadcast(
            room_id,
            # mode="json" renders the datetime as an ISO string; a plain
            # model_dump() leaves a datetime object that json.dumps rejects.
            {"type": "message", "message": message.model_dump(mode="json")},
        )

    async def broadcast(self, room_id: str, data: dict, exclude: int | None = None) -> None:
        """Send ``data`` as one JSON text frame to every member of ``room_id`` except ``exclude``.

        Delivery is best effort: a failure on one socket is logged and does not
        stop delivery to the others. Unknown or empty rooms are a no-op.

        Raises:
            TypeError: if ``data`` is not JSON-serializable.
        """
        members = self.rooms.get(room_id)
        if not members:
            return
        # Serialize once, outside the per-socket try, so an unserializable
        # payload surfaces as an error instead of being silently dropped.
        payload = json.dumps(data)
        # Iterate a snapshot: every await yields to the event loop, where other
        # sessions may join or leave and resize the dict mid-iteration.
        for user_id, websocket in list(members.items()):
            if user_id == exclude:
                continue
            try:
                await websocket.send_text(payload)
            except Exception as exc:
                # Not a bare except: that would also swallow asyncio.CancelledError.
                self._logger.warning(
                    "Failed to deliver to user %s in room %s: %s", user_id, room_id, exc
                )

    def get_history(self, room_id: str, limit: int = 50) -> List[dict]:
        """Return the last ``limit`` messages of ``room_id`` as JSON-ready dicts, oldest first.

        Unknown rooms and a non-positive ``limit`` yield an empty list.
        """
        if limit <= 0:
            # history[-0:] would be the *whole* list, not an empty one.
            return []
        messages = self.message_history.get(room_id, [])
        return [m.model_dump(mode="json") for m in messages[-limit:]]
# Module-level registry instance used by the WebSocket and HTTP endpoints.
system: RealTimeSystem = RealTimeSystem()
logger = logging.getLogger(__name__)


def _json_default(obj: object) -> str:
    """``json.dumps`` fallback: render datetimes as ISO-8601 strings."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


@app.websocket("/ws/{room_id}/{user_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    room_id: str,
    user_id: int,
):
    """Run one client's real-time session in ``room_id``.

    Protocol (JSON text frames):
      * server -> client, once after joining:
        ``{"type": "history", "messages": [...]}``
      * client -> server ``{"type": "message", "content": <str>}``: stored and
        broadcast to the whole room, sender included.
      * client -> server ``{"type": "typing"}``: relayed to everyone except the sender.

    Malformed frames are logged and skipped; other frame types are ignored.
    The user is always removed from the room when the session ends.
    """
    # NOTE(review): user_id comes straight from the URL with no authentication,
    # so a client can connect as any user — confirm this is acceptable.
    await websocket.accept()
    # Snapshot history *before* joining: join_room registers the socket and
    # then awaits, and any message sent meanwhile is delivered live. A snapshot
    # taken after the join would deliver such a message a second time.
    history = system.get_history(room_id)
    try:
        # Everything after the join runs inside the try so the finally below
        # always removes the socket from the room, even if a send fails.
        await system.join_room(room_id, user_id, websocket)
        await websocket.send_text(
            json.dumps({"type": "history", "messages": history}, default=_json_default)
        )
        while True:
            raw = await websocket.receive_text()
            try:
                frame = json.loads(raw)
            except ValueError:
                logger.warning(
                    "Ignoring non-JSON frame from user %s in room %s", user_id, room_id
                )
                continue
            if not isinstance(frame, dict):
                logger.warning(
                    "Ignoring non-object frame from user %s in room %s", user_id, room_id
                )
                continue

            kind = frame.get("type")
            if kind == "message":
                content = frame.get("content")
                if not isinstance(content, str):
                    logger.warning(
                        "Ignoring message without string content from user %s in room %s",
                        user_id,
                        room_id,
                    )
                    continue
                await system.send_message(
                    room_id,
                    Message(type="message", content=content, sender_id=user_id),
                )
            elif kind == "typing":
                await system.broadcast(
                    room_id, {"type": "typing", "user_id": user_id}, exclude=user_id
                )
    except WebSocketDisconnect:
        # Normal client disconnect; not an error.
        pass
    except Exception:
        logger.exception(
            "WebSocket session failed for user %s in room %s", user_id, room_id
        )
    finally:
        system.leave_room(room_id, user_id)
@app.get("/api/rooms/{room_id}/history")
async def get_history(room_id: str, limit: int = 50):
    """HTTP view of a room's recent messages.

    Delegates to ``system.get_history``; a room with no history yields an
    empty list rather than a 404.
    """
    recent_messages = system.get_history(room_id, limit)
    return recent_messages