major refactoring
This commit is contained in:
@ -1,22 +1,28 @@
|
||||
# Run in dev mode using:
|
||||
# fastapi dev app.py
|
||||
#
|
||||
# After starting the server, you can navigate to http://localhost:8000 to see the web interface.
|
||||
# This will automatically start both the web server AND the UDP relay (server.py).
|
||||
# After starting, navigate to http://localhost:8000 to see the web interface.
|
||||
#
|
||||
# Note: This requires the user to have the fastapi CLI tool installed.
|
||||
# The user should be in the same directory as `app.py` as well as `index.html`.
|
||||
# The user should be in the same directory as `app.py`, `server.py`, and `index.html`.
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import socket
|
||||
import subprocess
|
||||
import os
|
||||
import signal
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
clients = set()
|
||||
server_process = None # Will hold the server.py subprocess
|
||||
|
||||
# Broadcast function to notify all SSE clients
|
||||
async def notify_clients(message: str):
|
||||
@ -35,7 +41,7 @@ async def sock_recvfrom(nonblocking_sock, *pos, loop, **kw):
|
||||
finally:
|
||||
loop.remove_reader(nonblocking_sock.fileno())
|
||||
|
||||
# Background task: UDP listener
|
||||
# Background task: UDP listener for SSE clients
|
||||
async def udp_listener():
|
||||
loop = asyncio.get_running_loop()
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
@ -48,13 +54,100 @@ async def udp_listener():
|
||||
message = data.decode()
|
||||
await notify_clients(message)
|
||||
|
||||
# Background task: Tracking data listener
|
||||
async def tracking_listener():
|
||||
loop = asyncio.get_running_loop()
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("0.0.0.0", 5002)) # Different port for tracking data
|
||||
sock.setblocking(False)
|
||||
|
||||
while True:
|
||||
data, addr = await sock_recvfrom(sock, 1024 * 16, loop=loop) # Larger buffer for tracking data
|
||||
|
||||
# Only record if tracking is active
|
||||
if tracking_state["is_recording"]:
|
||||
try:
|
||||
message = data.decode()
|
||||
|
||||
# Extract source IP from tagged data (added by server.py)
|
||||
if message.startswith("SOURCE_IP:"):
|
||||
parts = message.split("|", 1)
|
||||
source_ip = parts[0].replace("SOURCE_IP:", "")
|
||||
actual_data = parts[1] if len(parts) > 1 else ""
|
||||
else:
|
||||
# Fallback if data isn't tagged (shouldn't happen normally)
|
||||
source_ip = addr[0]
|
||||
actual_data = message
|
||||
|
||||
# Only record data from the active player
|
||||
if source_ip == tracking_state["active_player_ip"]:
|
||||
parsed = parse_tracking_data(actual_data)
|
||||
|
||||
if parsed:
|
||||
import time
|
||||
# Calculate time elapsed since experiment start
|
||||
elapsed_time = time.time() - tracking_state["experiment_start_time"]
|
||||
|
||||
# Add sample with metadata
|
||||
sample = {
|
||||
"timestamp": time.time(),
|
||||
"elapsed_time": elapsed_time,
|
||||
"current_word": current_word_state["word"],
|
||||
"word_time_remaining": 0.0,
|
||||
"data": parsed
|
||||
}
|
||||
|
||||
# Calculate word time remaining if word is active
|
||||
if current_word_state["startTime"]:
|
||||
word_elapsed = time.time() - current_word_state["startTime"]
|
||||
sample["word_time_remaining"] = max(0, current_word_state["timeSeconds"] - word_elapsed)
|
||||
|
||||
tracking_state["samples"].append(sample)
|
||||
except Exception as e:
|
||||
print(f"Error processing tracking data: {e}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global server_process
|
||||
|
||||
# Start server.py subprocess
|
||||
print("Starting UDP relay server (server.py)...")
|
||||
try:
|
||||
server_process = subprocess.Popen(
|
||||
["python", "server.py"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
print(f"UDP relay server started with PID {server_process.pid}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to start server.py: {e}")
|
||||
print("You may need to start server.py manually.")
|
||||
|
||||
# Start background tasks
|
||||
asyncio.create_task(udp_listener())
|
||||
asyncio.create_task(tracking_listener())
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup: Stop server.py subprocess
|
||||
if server_process:
|
||||
print("Stopping UDP relay server...")
|
||||
server_process.terminate()
|
||||
try:
|
||||
server_process.wait(timeout=5)
|
||||
print("UDP relay server stopped")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("Forcefully killing UDP relay server...")
|
||||
server_process.kill()
|
||||
server_process.wait()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# Mount static files for serving CSS and JavaScript
|
||||
app.mount("/static", StaticFiles(directory="."), name="static")
|
||||
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
# sock.bind(("0.0.0.0", 5002))
|
||||
@ -86,6 +179,17 @@ class Word(BaseModel):
|
||||
class WordList(BaseModel):
|
||||
words: list[str]
|
||||
|
||||
class VRConfig(BaseModel):
|
||||
player1_ip: str
|
||||
player2_ip: str
|
||||
server_ip: str
|
||||
mode: str
|
||||
|
||||
class TrackingStart(BaseModel):
|
||||
group_id: str
|
||||
condition: str
|
||||
active_player_ip: str
|
||||
|
||||
# Global state for current word display
|
||||
current_word_state = {
|
||||
"word": "",
|
||||
@ -94,6 +198,97 @@ current_word_state = {
|
||||
"startTime": None
|
||||
}
|
||||
|
||||
# Global state for tracking data recording
|
||||
tracking_state = {
|
||||
"is_recording": False,
|
||||
"group_id": "",
|
||||
"condition": "",
|
||||
"active_player_ip": "",
|
||||
"experiment_start_time": None,
|
||||
"samples": [] # List of tracking data samples
|
||||
}
|
||||
|
||||
def parse_tracking_data(data_str):
|
||||
"""
|
||||
Parse tracking data from VR headset to extract camera and controller positions/rotations.
|
||||
Returns dict with center_eye, left_hand, right_hand data or None if parsing fails.
|
||||
"""
|
||||
try:
|
||||
parts = data_str.split(';')
|
||||
|
||||
# The data structure from TrackWeights.cs:
|
||||
# 0-62: Face expressions (63 values)
|
||||
# 63-66: Root orientation (4 values)
|
||||
# 67-69: Root position (3 values)
|
||||
# 70: Root scale (1 value)
|
||||
# 71: Bone rotations length (1 value)
|
||||
bone_rot_length = int(float(parts[71]))
|
||||
# 72 to 72+(bone_rot_length*4)-1: Bone rotations (4 values each)
|
||||
idx = 72 + (bone_rot_length * 4)
|
||||
# idx: IsDataValid
|
||||
idx += 1
|
||||
# idx: IsDataHighConfidence
|
||||
idx += 1
|
||||
# idx: Bone translations length
|
||||
bone_trans_length = int(float(parts[idx]))
|
||||
idx += 1
|
||||
# idx to idx+(bone_trans_length*3)-1: Bone translations (3 values each)
|
||||
idx += bone_trans_length * 3
|
||||
# idx: SkeletonChangedCount
|
||||
idx += 1
|
||||
# idx to idx+3: Left eye rotation (4 values)
|
||||
idx += 4
|
||||
# idx to idx+3: Right eye rotation (4 values)
|
||||
idx += 4
|
||||
|
||||
# Now we're at the data we want!
|
||||
# Center eye camera: position (3) + rotation (4) = 7 values
|
||||
center_eye_pos_x = float(parts[idx])
|
||||
center_eye_pos_y = float(parts[idx + 1])
|
||||
center_eye_pos_z = float(parts[idx + 2])
|
||||
center_eye_rot_w = float(parts[idx + 3])
|
||||
center_eye_rot_x = float(parts[idx + 4])
|
||||
center_eye_rot_y = float(parts[idx + 5])
|
||||
center_eye_rot_z = float(parts[idx + 6])
|
||||
idx += 7
|
||||
|
||||
# Left hand controller: position (3) + rotation (4) = 7 values
|
||||
left_hand_pos_x = float(parts[idx])
|
||||
left_hand_pos_y = float(parts[idx + 1])
|
||||
left_hand_pos_z = float(parts[idx + 2])
|
||||
left_hand_rot_w = float(parts[idx + 3])
|
||||
left_hand_rot_x = float(parts[idx + 4])
|
||||
left_hand_rot_y = float(parts[idx + 5])
|
||||
left_hand_rot_z = float(parts[idx + 6])
|
||||
idx += 7
|
||||
|
||||
# Right hand controller: position (3) + rotation (4) = 7 values
|
||||
right_hand_pos_x = float(parts[idx])
|
||||
right_hand_pos_y = float(parts[idx + 1])
|
||||
right_hand_pos_z = float(parts[idx + 2])
|
||||
right_hand_rot_w = float(parts[idx + 3])
|
||||
right_hand_rot_x = float(parts[idx + 4])
|
||||
right_hand_rot_y = float(parts[idx + 5])
|
||||
right_hand_rot_z = float(parts[idx + 6])
|
||||
|
||||
return {
|
||||
"center_eye": {
|
||||
"pos": {"x": center_eye_pos_x, "y": center_eye_pos_y, "z": center_eye_pos_z},
|
||||
"rot": {"w": center_eye_rot_w, "x": center_eye_rot_x, "y": center_eye_rot_y, "z": center_eye_rot_z}
|
||||
},
|
||||
"left_hand": {
|
||||
"pos": {"x": left_hand_pos_x, "y": left_hand_pos_y, "z": left_hand_pos_z},
|
||||
"rot": {"w": left_hand_rot_w, "x": left_hand_rot_x, "y": left_hand_rot_y, "z": left_hand_rot_z}
|
||||
},
|
||||
"right_hand": {
|
||||
"pos": {"x": right_hand_pos_x, "y": right_hand_pos_y, "z": right_hand_pos_z},
|
||||
"rot": {"w": right_hand_rot_w, "x": right_hand_rot_x, "y": right_hand_rot_y, "z": right_hand_rot_z}
|
||||
}
|
||||
}
|
||||
except (IndexError, ValueError) as e:
|
||||
print(f"Error parsing tracking data: {e}")
|
||||
return None
|
||||
|
||||
@app.post("/word")
|
||||
def read_word(word: Word):
|
||||
import time
|
||||
@ -140,6 +335,150 @@ def shuffle_words(word_list: WordList):
|
||||
random.shuffle(shuffled)
|
||||
return { "status": "ok", "shuffled_words": shuffled }
|
||||
|
||||
@app.post("/send-config")
|
||||
def send_vr_config(config: VRConfig):
|
||||
"""
|
||||
Send IP and MODE configuration to VR headsets.
|
||||
This integrates the functionality from control.py into the web interface.
|
||||
"""
|
||||
try:
|
||||
# Create UDP socket for sending commands
|
||||
cmd_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
# Send IP configuration to both players
|
||||
ip_msg = f"IP:{config.server_ip}".encode('utf-8')
|
||||
cmd_sock.sendto(ip_msg, (config.player1_ip, 5000))
|
||||
cmd_sock.sendto(ip_msg, (config.player2_ip, 5000))
|
||||
|
||||
# Send MODE configuration to both players
|
||||
mode_msg = f"MODE:{config.mode}".encode('utf-8')
|
||||
cmd_sock.sendto(mode_msg, (config.player1_ip, 5000))
|
||||
cmd_sock.sendto(mode_msg, (config.player2_ip, 5000))
|
||||
|
||||
cmd_sock.close()
|
||||
|
||||
print(f"Sent IP config: {ip_msg.decode()}")
|
||||
print(f"Sent MODE config: {mode_msg.decode()}")
|
||||
print(f"To players: {config.player1_ip}, {config.player2_ip}")
|
||||
|
||||
return { "status": "ok" }
|
||||
except Exception as e:
|
||||
print(f"Error sending VR config: {e}")
|
||||
return { "status": "error", "message": str(e) }
|
||||
|
||||
@app.post("/tracking/start")
|
||||
def start_tracking(config: TrackingStart):
|
||||
"""
|
||||
Start recording tracking data from the active player.
|
||||
"""
|
||||
import time
|
||||
|
||||
tracking_state["is_recording"] = True
|
||||
tracking_state["group_id"] = config.group_id
|
||||
tracking_state["condition"] = config.condition
|
||||
tracking_state["active_player_ip"] = config.active_player_ip
|
||||
tracking_state["experiment_start_time"] = time.time()
|
||||
tracking_state["samples"] = []
|
||||
|
||||
print(f"Started tracking: group={config.group_id}, condition={config.condition}, player={config.active_player_ip}")
|
||||
return { "status": "ok", "message": "Tracking started" }
|
||||
|
||||
@app.post("/tracking/stop")
|
||||
def stop_tracking():
|
||||
"""
|
||||
Stop recording tracking data.
|
||||
"""
|
||||
tracking_state["is_recording"] = False
|
||||
sample_count = len(tracking_state["samples"])
|
||||
print(f"Stopped tracking: {sample_count} samples recorded")
|
||||
return { "status": "ok", "message": f"Tracking stopped. {sample_count} samples recorded." }
|
||||
|
||||
@app.get("/tracking/status")
|
||||
def get_tracking_status():
|
||||
"""
|
||||
Get current tracking status.
|
||||
"""
|
||||
return {
|
||||
"is_recording": tracking_state["is_recording"],
|
||||
"sample_count": len(tracking_state["samples"]),
|
||||
"group_id": tracking_state["group_id"],
|
||||
"condition": tracking_state["condition"]
|
||||
}
|
||||
|
||||
@app.get("/tracking/download")
|
||||
def download_tracking_csv():
|
||||
"""
|
||||
Download tracking data as CSV.
|
||||
"""
|
||||
import io
|
||||
from datetime import datetime
|
||||
|
||||
if len(tracking_state["samples"]) == 0:
|
||||
return { "status": "error", "message": "No tracking data available" }
|
||||
|
||||
# Create CSV content
|
||||
output = io.StringIO()
|
||||
|
||||
# Write header
|
||||
header = [
|
||||
"timestamp", "elapsed_time", "group_id", "condition", "current_word", "word_time_remaining",
|
||||
"center_eye_pos_x", "center_eye_pos_y", "center_eye_pos_z",
|
||||
"center_eye_rot_w", "center_eye_rot_x", "center_eye_rot_y", "center_eye_rot_z",
|
||||
"left_hand_pos_x", "left_hand_pos_y", "left_hand_pos_z",
|
||||
"left_hand_rot_w", "left_hand_rot_x", "left_hand_rot_y", "left_hand_rot_z",
|
||||
"right_hand_pos_x", "right_hand_pos_y", "right_hand_pos_z",
|
||||
"right_hand_rot_w", "right_hand_rot_x", "right_hand_rot_y", "right_hand_rot_z"
|
||||
]
|
||||
output.write(";".join(header) + "\n")
|
||||
|
||||
# Write data rows
|
||||
for sample in tracking_state["samples"]:
|
||||
data = sample["data"]
|
||||
row = [
|
||||
str(sample["timestamp"]),
|
||||
f"{sample['elapsed_time']:.4f}",
|
||||
tracking_state["group_id"],
|
||||
tracking_state["condition"],
|
||||
sample["current_word"],
|
||||
f"{sample['word_time_remaining']:.4f}",
|
||||
f"{data['center_eye']['pos']['x']:.4f}",
|
||||
f"{data['center_eye']['pos']['y']:.4f}",
|
||||
f"{data['center_eye']['pos']['z']:.4f}",
|
||||
f"{data['center_eye']['rot']['w']:.4f}",
|
||||
f"{data['center_eye']['rot']['x']:.4f}",
|
||||
f"{data['center_eye']['rot']['y']:.4f}",
|
||||
f"{data['center_eye']['rot']['z']:.4f}",
|
||||
f"{data['left_hand']['pos']['x']:.4f}",
|
||||
f"{data['left_hand']['pos']['y']:.4f}",
|
||||
f"{data['left_hand']['pos']['z']:.4f}",
|
||||
f"{data['left_hand']['rot']['w']:.4f}",
|
||||
f"{data['left_hand']['rot']['x']:.4f}",
|
||||
f"{data['left_hand']['rot']['y']:.4f}",
|
||||
f"{data['left_hand']['rot']['z']:.4f}",
|
||||
f"{data['right_hand']['pos']['x']:.4f}",
|
||||
f"{data['right_hand']['pos']['y']:.4f}",
|
||||
f"{data['right_hand']['pos']['z']:.4f}",
|
||||
f"{data['right_hand']['rot']['w']:.4f}",
|
||||
f"{data['right_hand']['rot']['x']:.4f}",
|
||||
f"{data['right_hand']['rot']['y']:.4f}",
|
||||
f"{data['right_hand']['rot']['z']:.4f}"
|
||||
]
|
||||
output.write(";".join(row) + "\n")
|
||||
|
||||
csv_content = output.getvalue()
|
||||
output.close()
|
||||
|
||||
# Create filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
filename = f"{tracking_state['group_id']}_tracking_{timestamp}.csv"
|
||||
|
||||
from fastapi.responses import Response
|
||||
return Response(
|
||||
content=csv_content,
|
||||
media_type="text/csv",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||
)
|
||||
|
||||
|
||||
# SSE endpoint
|
||||
@app.get("/news")
|
||||
|
||||
Reference in New Issue
Block a user